From e0c090f227e9b64e595b47d4d1f96f8a2fff5bf7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 2 Jan 2018 09:19:18 +0800 Subject: [PATCH 001/774] [SPARK-22932][SQL] Refactor AnalysisContext ## What changes were proposed in this pull request? Add a `reset` function to ensure the state in `AnalysisContext ` is per-query. ## How was this patch tested? The existing test cases Author: gatorsmile Closes #20127 from gatorsmile/refactorAnalysisContext. --- .../sql/catalyst/analysis/Analyzer.scala | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6d294d48c0ee7..35b35110e491f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -52,6 +52,7 @@ object SimpleAnalyzer extends Analyzer( /** * Provides a way to keep state during the analysis, this enables us to decouple the concerns * of analysis environment from the catalog. + * The state that is kept here is per-query. * * Note this is thread local. * @@ -70,6 +71,8 @@ object AnalysisContext { } def get: AnalysisContext = value.get() + def reset(): Unit = value.remove() + private def set(context: AnalysisContext): Unit = value.set(context) def withAnalysisContext[A](database: Option[String])(f: => A): A = { @@ -95,6 +98,17 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } + override def execute(plan: LogicalPlan): LogicalPlan = { + AnalysisContext.reset() + try { + executeSameContext(plan) + } finally { + AnalysisContext.reset() + } + } + + private def executeSameContext(plan: LogicalPlan): LogicalPlan = super.execute(plan) + def resolver: Resolver = conf.resolver protected val fixedPoint = FixedPoint(maxIterations) @@ -176,7 +190,7 @@ class Analyzer( case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => - resolved :+ name -> execute(substituteCTE(relation, resolved)) + resolved :+ name -> executeSameContext(substituteCTE(relation, resolved)) }) case other => other } @@ -600,7 +614,7 @@ class Analyzer( "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + "aroud this.") } - execute(child) + executeSameContext(child) } view.copy(child = newChild) case p @ SubqueryAlias(_, view: View) => @@ -1269,7 +1283,7 @@ class Analyzer( do { // Try to resolve the subquery plan using the regular analyzer. previous = current - current = execute(current) + current = executeSameContext(current) // Use the outer references to resolve the subquery plan if it isn't resolved yet. val i = plans.iterator @@ -1392,7 +1406,7 @@ class Analyzer( grouping, Alias(cond, "havingCondition")() :: Nil, child) - val resolvedOperator = execute(aggregatedCondition) + val resolvedOperator = executeSameContext(aggregatedCondition) def resolvedAggregateFilter = resolvedOperator .asInstanceOf[Aggregate] @@ -1450,7 +1464,8 @@ class Analyzer( val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) - val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + val resolvedAggregate: Aggregate = + executeSameContext(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] From a6fc300e91273230e7134ac6db95ccb4436c6f8f Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Tue, 2 Jan 2018 23:30:38 +0800 Subject: [PATCH 002/774] [SPARK-22897][CORE] Expose stageAttemptId in TaskContext ## What changes were proposed in this pull request? stageAttemptId added in TaskContext and corresponding construction modification ## How was this patch tested? Added a new test in TaskContextSuite, two cases are tested: 1. Normal case without failure 2. Exception case with resubmitted stages Link to [SPARK-22897](https://issues.apache.org/jira/browse/SPARK-22897) Author: Xianjin YE Closes #20082 from advancedxy/SPARK-22897. --- .../scala/org/apache/spark/TaskContext.scala | 9 +++++- .../org/apache/spark/TaskContextImpl.scala | 5 ++-- .../org/apache/spark/scheduler/Task.scala | 1 + .../spark/JavaTaskContextCompileCheck.java | 2 ++ .../scala/org/apache/spark/ShuffleSuite.scala | 6 ++-- .../spark/memory/MemoryTestingUtils.scala | 1 + .../spark/scheduler/TaskContextSuite.scala | 29 +++++++++++++++++-- .../spark/storage/BlockInfoManagerSuite.scala | 2 +- project/MimaExcludes.scala | 3 ++ .../UnsafeFixedWidthAggregationMapSuite.scala | 1 + .../UnsafeKVExternalSorterSuite.scala | 1 + .../execution/UnsafeRowSerializerSuite.scala | 2 +- .../SortBasedAggregationStoreSuite.scala | 3 +- 13 files changed, 54 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 0b87cd503d4fa..69739745aa6cf 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -66,7 +66,7 @@ object TaskContext { * An empty task context that does not represent an actual task. This is only used in tests. */ private[spark] def empty(): TaskContextImpl = { - new TaskContextImpl(0, 0, 0, 0, null, new Properties, null) + new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null) } } @@ -150,6 +150,13 @@ abstract class TaskContext extends Serializable { */ def stageId(): Int + /** + * How many times the stage that this task belongs to has been attempted. The first stage attempt + * will be assigned stageAttemptNumber = 0, and subsequent attempts will have increasing attempt + * numbers. + */ + def stageAttemptNumber(): Int + /** * The ID of the RDD partition that is computed by this task. */ diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 01d8973e1bb06..cccd3ea457ba4 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -41,8 +41,9 @@ import org.apache.spark.util._ * `TaskMetrics` & `MetricsSystem` objects are not thread safe. */ private[spark] class TaskContextImpl( - val stageId: Int, - val partitionId: Int, + override val stageId: Int, + override val stageAttemptNumber: Int, + override val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7767ef1803a06..f536fc2a5f0a1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -79,6 +79,7 @@ private[spark] abstract class Task[T]( SparkEnv.get.blockManager.registerTask(taskAttemptId) context = new TaskContextImpl( stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal partitionId, taskAttemptId, attemptNumber, diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java index 94f5805853e1e..f8e233a05a447 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -38,6 +38,7 @@ public static void test() { tc.attemptNumber(); tc.partitionId(); tc.stageId(); + tc.stageAttemptNumber(); tc.taskAttemptId(); } @@ -51,6 +52,7 @@ public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); context.stageId(); + context.stageAttemptNumber(); context.partitionId(); context.addTaskCompletionListener(this); } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 3931d53b4ae0a..ced5a06516f75 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -363,14 +363,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently // depending on what gets spilled, what gets combined, etc. val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) val data2 = (11 to 20).map { x => x -> x} // interleave writes of both attempts -- we want to test that both attempts can occur @@ -398,7 +398,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, - new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 362cd861cc248..dcf89e4f75acf 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -29,6 +29,7 @@ object MemoryTestingUtils { val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0) new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 0, attemptNumber = 0, diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index a1d9085fa085d..aa9c36c0aaacb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { @@ -158,6 +159,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) } + test("TaskContext.stageAttemptNumber getter") { + sc = new SparkContext("local[1,2]", "test") + + // Check stageAttemptNumbers are 0 for initial stage + val stageAttemptNumbers = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ => + Seq(TaskContext.get().stageAttemptNumber()).iterator + }.collect() + assert(stageAttemptNumbers.toSet === Set(0)) + + // Check stageAttemptNumbers that are resubmitted when tasks have FetchFailedException + val stageAttemptNumbersWithFailedStage = + sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ => + val stageAttemptNumber = TaskContext.get().stageAttemptNumber() + if (stageAttemptNumber < 2) { + // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception + // will only trigger task resubmission in the same stage. + throw new FetchFailedException(null, 0, 0, 0, "Fake") + } + Seq(stageAttemptNumber).iterator + }.collect() + + assert(stageAttemptNumbersWithFailedStage.toSet === Set(2)) + } + test("accumulators are updated on exception failures") { // This means use 1 core and 4 max task failures sc = new SparkContext("local[1,4]", "test") @@ -190,7 +215,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. val taskMetrics = TaskMetrics.empty val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, @@ -213,7 +238,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. val taskMetrics = TaskMetrics.registered val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 917db766f7f11..9c0699bc981f8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { private def withTaskId[T](taskAttemptId: Long)(block: => T): T = { try { TaskContext.setTaskContext( - new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null)) + new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null)) block } finally { TaskContext.unset() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 81584af6813ea..3b452f35c5ec1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // [SPARK-22897] Expose stageAttemptId in TaskContext + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"), + // SPARK-22789: Map-only continuous processing execution ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 232c1beae7998..3e31d22e15c0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -70,6 +70,7 @@ class UnsafeFixedWidthAggregationMapSuite TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = Random.nextInt(10000), attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 604502f2a57d0..6af9f8b77f8d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -116,6 +116,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 98456, attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index dff88ce7f1b9a..a3ae93810aa3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { (i, converter(Row(i))) } val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) + val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( taskContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index 10f1ee279bedf..3fad7dfddadcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -35,7 +35,8 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte val conf = new SparkConf() sc = new SparkContext("local[2, 4]", "test", conf) val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0) - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, new Properties, null)) + TaskContext.setTaskContext( + new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null)) } override def afterAll(): Unit = TaskContext.unset() From 247a08939d58405aef39b2a4e7773aa45474ad12 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 3 Jan 2018 21:40:51 +0800 Subject: [PATCH 003/774] [SPARK-22938] Assert that SQLConf.get is accessed only on the driver. ## What changes were proposed in this pull request? Assert if code tries to access SQLConf.get on executor. This can lead to hard to detect bugs, where the executor will read fallbackConf, falling back to default config values, ignoring potentially changed non-default configs. If a config is to be passed to executor code, it needs to be read on the driver, and passed explicitly. ## How was this patch tested? Check in existing tests. Author: Juliusz Sompolski Closes #20136 from juliuszsompolski/SPARK-22938. --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4f77c54a7af57..80cdc61484c0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,11 +27,13 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -70,7 +72,7 @@ object SQLConf { * Default config. Only used when there is no active SparkSession for the thread. * See [[get]] for more information. */ - private val fallbackConf = new ThreadLocal[SQLConf] { + private lazy val fallbackConf = new ThreadLocal[SQLConf] { override def initialValue: SQLConf = new SQLConf } @@ -1087,6 +1089,12 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) From 1a87a1609c4d2c9027a2cf669ea3337b89f61fb6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 3 Jan 2018 22:09:30 +0800 Subject: [PATCH 004/774] [SPARK-22934][SQL] Make optional clauses order insensitive for CREATE TABLE SQL statement ## What changes were proposed in this pull request? Currently, our CREATE TABLE syntax require the EXACT order of clauses. It is pretty hard to remember the exact order. Thus, this PR is to make optional clauses order insensitive for `CREATE TABLE` SQL statement. ``` CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name [(col_name1 col_type1 [COMMENT col_comment1], ...)] USING datasource [OPTIONS (key1=val1, key2=val2, ...)] [PARTITIONED BY (col_name1, col_name2, ...)] [CLUSTERED BY (col_name3, col_name4, ...) INTO num_buckets BUCKETS] [LOCATION path] [COMMENT table_comment] [TBLPROPERTIES (key1=val1, key2=val2, ...)] [AS select_statement] ``` The proposal is to make the following clauses order insensitive. ``` [OPTIONS (key1=val1, key2=val2, ...)] [PARTITIONED BY (col_name1, col_name2, ...)] [CLUSTERED BY (col_name3, col_name4, ...) INTO num_buckets BUCKETS] [LOCATION path] [COMMENT table_comment] [TBLPROPERTIES (key1=val1, key2=val2, ...)] ``` The same idea is also applicable to Create Hive Table. ``` CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name [(col_name1[:] col_type1 [COMMENT col_comment1], ...)] [COMMENT table_comment] [PARTITIONED BY (col_name2[:] col_type2 [COMMENT col_comment2], ...)] [ROW FORMAT row_format] [STORED AS file_format] [LOCATION path] [TBLPROPERTIES (key1=val1, key2=val2, ...)] [AS select_statement] ``` The proposal is to make the following clauses order insensitive. ``` [COMMENT table_comment] [PARTITIONED BY (col_name2[:] col_type2 [COMMENT col_comment2], ...)] [ROW FORMAT row_format] [STORED AS file_format] [LOCATION path] [TBLPROPERTIES (key1=val1, key2=val2, ...)] ``` ## How was this patch tested? Added test cases Author: gatorsmile Closes #20133 from gatorsmile/createDataSourceTableDDL. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 24 +- .../sql/catalyst/parser/ParserUtils.scala | 9 + .../spark/sql/execution/SparkSqlParser.scala | 81 +++++-- .../execution/command/DDLParserSuite.scala | 220 ++++++++++++++---- .../sql/execution/command/DDLSuite.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 13 +- .../sql/hive/execution/SQLQuerySuite.scala | 124 +++++----- 7 files changed, 335 insertions(+), 138 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6fe995f650d55..6daf01d98426c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -73,18 +73,22 @@ statement | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase | createTableHeader ('(' colTypeList ')')? tableProvider - (OPTIONS options=tablePropertyList)? - (PARTITIONED BY partitionColumnNames=identifierList)? - bucketSpec? locationSpec? - (COMMENT comment=STRING)? - (TBLPROPERTIES tableProps=tablePropertyList)? + ((OPTIONS options=tablePropertyList) | + (PARTITIONED BY partitionColumnNames=identifierList) | + bucketSpec | + locationSpec | + (COMMENT comment=STRING) | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createTable | createTableHeader ('(' columns=colTypeList ')')? - (COMMENT comment=STRING)? - (PARTITIONED BY '(' partitionColumns=colTypeList ')')? - bucketSpec? skewSpec? - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? + ((COMMENT comment=STRING) | + (PARTITIONED BY '(' partitionColumns=colTypeList ')') | + bucketSpec | + skewSpec | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createHiveTable | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier LIKE source=tableIdentifier locationSpec? #createTableLike diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 9b127f91648e6..89347f4b1f7bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.catalyst.parser +import java.util + import scala.collection.mutable.StringBuilder import org.antlr.v4.runtime.{ParserRuleContext, Token} @@ -39,6 +41,13 @@ object ParserUtils { throw new ParseException(s"Operation not allowed: $message", ctx) } + def checkDuplicateClauses[T]( + nodes: util.List[T], clauseName: String, ctx: ParserRuleContext): Unit = { + if (nodes.size() > 1) { + throw new ParseException(s"Found duplicate clauses: $clauseName", ctx) + } + } + /** Check if duplicate keys exist in a set of key-value pairs. */ def checkDuplicateKeys[T](keyPairs: Seq[(String, T)], ctx: ParserRuleContext): Unit = { keyPairs.groupBy(_._1).filter(_._2.size > 1).foreach { case (key, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 29b584b55972c..d3cfd2a1ffbf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -383,16 +383,19 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name * USING table_provider - * [OPTIONS table_property_list] - * [PARTITIONED BY (col_name, col_name, ...)] - * [CLUSTERED BY (col_name, col_name, ...) - * [SORTED BY (col_name [ASC|DESC], ...)] - * INTO num_buckets BUCKETS - * ] - * [LOCATION path] - * [COMMENT table_comment] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [[AS] select_statement]; + * + * create_table_clauses (order insensitive): + * [OPTIONS table_property_list] + * [PARTITIONED BY (col_name, col_name, ...)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { @@ -400,6 +403,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { if (external) { operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx) } + + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) val provider = ctx.tableProvider.qualifiedName.getText val schema = Option(ctx.colTypeList()).map(createSchema) @@ -408,9 +419,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { .map(visitIdentifierList(_).toArray) .getOrElse(Array.empty[String]) val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) val storage = DataSource.buildStorageFormatFromOptions(options) if (location.isDefined && storage.locationUri.isDefined) { @@ -1087,13 +1098,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name * [(col1[:] data_type [COMMENT col_comment], ...)] - * [COMMENT table_comment] - * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] - * [ROW FORMAT row_format] - * [STORED AS file_format] - * [LOCATION path] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [AS select_statement]; + * + * create_table_clauses (order insensitive): + * [COMMENT table_comment] + * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] + * [ROW FORMAT row_format] + * [STORED AS file_format] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) { @@ -1104,15 +1118,23 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { "CREATE TEMPORARY TABLE is not supported yet. " + "Please use CREATE TEMPORARY VIEW as an alternative.", ctx) } - if (ctx.skewSpec != null) { + if (ctx.skewSpec.size > 0) { operationNotAllowed("CREATE TABLE ... SKEWED BY", ctx) } + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx) + checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val dataCols = Option(ctx.columns).map(visitColTypeList).getOrElse(Nil) val partitionCols = Option(ctx.partitionColumns).map(visitColTypeList).getOrElse(Nil) - val properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) val selectQuery = Option(ctx.query).map(plan) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) // Note: Hive requires partition columns to be distinct from the schema, so we need // to include the partition columns here explicitly @@ -1120,12 +1142,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { // Storage format val defaultStorage = HiveSerDe.getDefaultStorage(conf) - validateRowFormatFileFormat(ctx.rowFormat, ctx.createFileFormat, ctx) - val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) + validateRowFormatFileFormat(ctx.rowFormat.asScala, ctx.createFileFormat.asScala, ctx) + val fileStorage = ctx.createFileFormat.asScala.headOption.map(visitCreateFileFormat) .getOrElse(CatalogStorageFormat.empty) - val rowStorage = Option(ctx.rowFormat).map(visitRowFormat) + val rowStorage = ctx.rowFormat.asScala.headOption.map(visitRowFormat) .getOrElse(CatalogStorageFormat.empty) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) // If we are creating an EXTERNAL table, then the LOCATION field is required if (external && location.isEmpty) { operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx) @@ -1180,7 +1202,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ctx) } - val hasStorageProperties = (ctx.createFileFormat != null) || (ctx.rowFormat != null) + val hasStorageProperties = (ctx.createFileFormat.size != 0) || (ctx.rowFormat.size != 0) if (conf.convertCTAS && !hasStorageProperties) { // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties // are empty Maps. @@ -1366,6 +1388,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } } + private def validateRowFormatFileFormat( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) { + validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx) + } + } + /** * Create or replace a view. This creates a [[CreateViewCommand]] command. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index eb7c33590b602..2b1aea08b1223 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -54,6 +54,13 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + private def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parser.parsePlan(sqlCommand)).getMessage + messages.foreach { message => + assert(e.contains(message)) + } + } + private def parseAs[T: ClassTag](query: String): T = { parser.parsePlan(query) match { case t: T => t @@ -494,6 +501,37 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Duplicate clauses - create table") { + def createTableHeader(duplicateClause: String, isNative: Boolean): String = { + val fileFormat = if (isNative) "USING parquet" else "STORED AS parquet" + s"CREATE TABLE my_tab(a INT, b STRING) $fileFormat $duplicateClause $duplicateClause" + } + + Seq(true, false).foreach { isNative => + intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')", isNative), + "Found duplicate clauses: TBLPROPERTIES") + intercept(createTableHeader("LOCATION '/tmp/file'", isNative), + "Found duplicate clauses: LOCATION") + intercept(createTableHeader("COMMENT 'a table'", isNative), + "Found duplicate clauses: COMMENT") + intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS", isNative), + "Found duplicate clauses: CLUSTERED BY") + } + + // Only for native data source tables + intercept(createTableHeader("PARTITIONED BY (b)", isNative = true), + "Found duplicate clauses: PARTITIONED BY") + + // Only for Hive serde tables + intercept(createTableHeader("PARTITIONED BY (k int)", isNative = false), + "Found duplicate clauses: PARTITIONED BY") + intercept(createTableHeader("STORED AS parquet", isNative = false), + "Found duplicate clauses: STORED AS/BY") + intercept( + createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'", isNative = false), + "Found duplicate clauses: ROW FORMAT") + } + test("create table - with location") { val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" @@ -1153,38 +1191,119 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Test CTAS against data source tables") { + val s1 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s2 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |LOCATION '/user/external/page_view' + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.provider == Some("parquet")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + } + test("Test CTAS #1") { val s1 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |STORED AS RCFILE |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s1) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - assert(desc.comment == Some("This is the staging page view table")) - // TODO will be SQLText - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == - Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + // TODO will be SQLText + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == + Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #2") { - val s2 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + val s1 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' | STORED AS @@ -1192,26 +1311,45 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s2) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - // TODO will be SQLText - assert(desc.comment == Some("This is the staging page view table")) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.properties == Map()) - assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) - assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) - assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + // TODO will be SQLText + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #3") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index fdb9b2f51f9cb..591510c1d8283 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1971,8 +1971,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { s""" |CREATE TABLE t(a int, b int, c int, d int) |USING parquet - |PARTITIONED BY(a, b) |LOCATION "${dir.toURI}" + |PARTITIONED BY(a, b) """.stripMargin) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index f2e0c695ca38b..65be244418670 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -875,12 +875,13 @@ class HiveDDLSuite test("desc table for Hive table - bucketed + sorted table") { withTable("tbl") { - sql(s""" - CREATE TABLE tbl (id int, name string) - PARTITIONED BY (ds string) - CLUSTERED BY(id) - SORTED BY(id, name) INTO 1024 BUCKETS - """) + sql( + s""" + |CREATE TABLE tbl (id int, name string) + |CLUSTERED BY(id) + |SORTED BY(id, name) INTO 1024 BUCKETS + |PARTITIONED BY (ds string) + """.stripMargin) val x = sql("DESC FORMATTED tbl").collect() assert(x.containsSlice( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 07ae3ae945848..47adc77a52d51 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -461,51 +461,55 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("CTAS without serde without location") { - val originalConf = sessionState.conf.convertCTAS - - setConf(SQLConf.CONVERT_CTAS, true) - - val defaultDataSource = sessionState.conf.defaultDataSourceName - try { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - val message = intercept[AnalysisException] { + withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { + val defaultDataSource = sessionState.conf.defaultDataSourceName + withTable("ctas1") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert(message.contains("already exists")) - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + val message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("already exists")) + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } // Specifying database name for query can be converted to data source write path // is not allowed right now. - sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } - sql("CREATE TABLE ctas1 stored as textfile" + + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as textfile" + " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text") - sql("DROP TABLE ctas1") + checkRelation("ctas1", isDataSourceTable = false, "text") + } - sql("CREATE TABLE ctas1 stored as sequencefile" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as sequencefile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "sequence") + } - sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "rcfile") + } - sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "orc") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "orc") + } - sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "parquet") - sql("DROP TABLE ctas1") - } finally { - setConf(SQLConf.CONVERT_CTAS, originalConf) - sql("DROP TABLE IF EXISTS ctas1") + withTable("ctas1") { + sql( + """ + |CREATE TABLE ctas1 stored as parquet + |AS SELECT key k, value FROM src ORDER BY k, value + """.stripMargin) + checkRelation("ctas1", isDataSourceTable = false, "parquet") + } } } @@ -539,30 +543,40 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val defaultDataSource = sessionState.conf.defaultDataSourceName val tempLocation = dir.toURI.getPath.stripSuffix("/") - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c1")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c1")) + } - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c2")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c2")) + } - sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text", Some(s"file:$tempLocation/c3")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "text", Some(s"file:$tempLocation/c3")) + } - sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence", Some(s"file:$tempLocation/c4")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "sequence", Some(s"file:$tempLocation/c4")) + } - sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile", Some(s"file:$tempLocation/c5")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "rcfile", Some(s"file:$tempLocation/c5")) + } } } } From a66fe36cee9363b01ee70e469f1c968f633c5713 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 3 Jan 2018 22:18:13 +0800 Subject: [PATCH 005/774] [SPARK-20236][SQL] dynamic partition overwrite ## What changes were proposed in this pull request? When overwriting a partitioned table with dynamic partition columns, the behavior is different between data source and hive tables. data source table: delete all partition directories that match the static partition values provided in the insert statement. hive table: only delete partition directories which have data written into it This PR adds a new config to make users be able to choose hive's behavior. ## How was this patch tested? new tests Author: Wenchen Fan Closes #18714 from cloud-fan/overwrite-partition. --- .../internal/io/FileCommitProtocol.scala | 25 ++++-- .../io/HadoopMapReduceCommitProtocol.scala | 75 ++++++++++++++---- .../apache/spark/sql/internal/SQLConf.scala | 21 +++++ .../InsertIntoHadoopFsRelationCommand.scala | 20 ++++- .../SQLHadoopMapReduceCommitProtocol.scala | 10 ++- .../spark/sql/sources/InsertSuite.scala | 78 +++++++++++++++++++ 6 files changed, 200 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 50f51e1af4530..6d0059b6a0272 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -28,8 +28,9 @@ import org.apache.spark.util.Utils * * 1. Implementations must be serializable, as the committer instance instantiated on the driver * will be used for tasks on executors. - * 2. Implementations should have a constructor with 2 arguments: - * (jobId: String, path: String) + * 2. Implementations should have a constructor with 2 or 3 arguments: + * (jobId: String, path: String) or + * (jobId: String, path: String, dynamicPartitionOverwrite: Boolean) * 3. A committer should not be reused across multiple Spark jobs. * * The proper call sequence is: @@ -139,10 +140,22 @@ object FileCommitProtocol { /** * Instantiates a FileCommitProtocol using the given className. */ - def instantiate(className: String, jobId: String, outputPath: String) - : FileCommitProtocol = { + def instantiate( + className: String, + jobId: String, + outputPath: String, + dynamicPartitionOverwrite: Boolean = false): FileCommitProtocol = { val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] - val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) - ctor.newInstance(jobId, outputPath) + // First try the constructor with arguments (jobId: String, outputPath: String, + // dynamicPartitionOverwrite: Boolean). + // If that doesn't exist, try the one with (jobId: string, outputPath: String). + try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + ctor.newInstance(jobId, outputPath, dynamicPartitionOverwrite.asInstanceOf[java.lang.Boolean]) + } catch { + case _: NoSuchMethodException => + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) + ctor.newInstance(jobId, outputPath) + } } } diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 95c99d29c3a9c..6d20ef1f98a3c 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -39,8 +39,19 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil * * @param jobId the job's or stage's id * @param path the job's output path, or null if committer acts as a noop + * @param dynamicPartitionOverwrite If true, Spark will overwrite partition directories at runtime + * dynamically, i.e., we first write files under a staging + * directory with partition path, e.g. + * /path/to/staging/a=1/b=1/xxx.parquet. When committing the job, + * we first clean up the corresponding partition directories at + * destination path, e.g. /path/to/destination/a=1/b=1, and move + * files from staging directory to the corresponding partition + * directories under destination path. */ -class HadoopMapReduceCommitProtocol(jobId: String, path: String) +class HadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) extends FileCommitProtocol with Serializable with Logging { import FileCommitProtocol._ @@ -67,9 +78,17 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) @transient private var addedAbsPathFiles: mutable.Map[String, String] = null /** - * The staging directory for all files committed with absolute output paths. + * Tracks partitions with default path that have new files written into them by this task, + * e.g. a=1/b=2. Files under these partitions will be saved into staging directory and moved to + * destination directory at the end, if `dynamicPartitionOverwrite` is true. */ - private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + @transient private var partitionPaths: mutable.Set[String] = null + + /** + * The staging directory of this write job. Spark uses it to deal with files with absolute output + * path, or writing data into partitioned directory with dynamicPartitionOverwrite=true. + */ + private def stagingDir = new Path(path, ".spark-staging-" + jobId) protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() @@ -85,11 +104,16 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { val filename = getFilename(taskContext, ext) - val stagingDir: String = committer match { + val stagingDir: Path = committer match { + case _ if dynamicPartitionOverwrite => + assert(dir.isDefined, + "The dataset to be written must be partitioned when dynamicPartitionOverwrite is true.") + partitionPaths += dir.get + this.stagingDir // For FileOutputCommitter it has its own staging path called "work path". case f: FileOutputCommitter => - Option(f.getWorkPath).map(_.toString).getOrElse(path) - case _ => path + new Path(Option(f.getWorkPath).map(_.toString).getOrElse(path)) + case _ => new Path(path) } dir.map { d => @@ -106,8 +130,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) // Include a UUID here to prevent file collisions for one task writing to different dirs. // In principle we could include hash(absoluteDir) instead but this is simpler. - val tmpOutputPath = new Path( - absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + val tmpOutputPath = new Path(stagingDir, UUID.randomUUID().toString() + "-" + filename).toString addedAbsPathFiles(tmpOutputPath) = absOutputPath tmpOutputPath @@ -141,23 +164,42 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { committer.commitJob(jobContext) - val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) - .foldLeft(Map[String, String]())(_ ++ _) - logDebug(s"Committing files staged for absolute locations $filesToMove") + if (hasValidPath) { - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + val (allAbsPathFiles, allPartitionPaths) = + taskCommits.map(_.obj.asInstanceOf[(Map[String, String], Set[String])]).unzip + val fs = stagingDir.getFileSystem(jobContext.getConfiguration) + + val filesToMove = allAbsPathFiles.foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + if (dynamicPartitionOverwrite) { + val absPartitionPaths = filesToMove.values.map(new Path(_).getParent).toSet + logDebug(s"Clean up absolute partition directories for overwriting: $absPartitionPaths") + absPartitionPaths.foreach(fs.delete(_, true)) + } for ((src, dst) <- filesToMove) { fs.rename(new Path(src), new Path(dst)) } - fs.delete(absPathStagingDir, true) + + if (dynamicPartitionOverwrite) { + val partitionPaths = allPartitionPaths.foldLeft(Set[String]())(_ ++ _) + logDebug(s"Clean up default partition directories for overwriting: $partitionPaths") + for (part <- partitionPaths) { + val finalPartPath = new Path(path, part) + fs.delete(finalPartPath, true) + fs.rename(new Path(stagingDir, part), finalPartPath) + } + } + + fs.delete(stagingDir, true) } } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) if (hasValidPath) { - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - fs.delete(absPathStagingDir, true) + val fs = stagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(stagingDir, true) } } @@ -165,13 +207,14 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) committer = setupCommitter(taskContext) committer.setupTask(taskContext) addedAbsPathFiles = mutable.Map[String, String]() + partitionPaths = mutable.Set[String]() } override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { val attemptId = taskContext.getTaskAttemptID SparkHadoopMapRedUtil.commitTask( committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) - new TaskCommitMessage(addedAbsPathFiles.toMap) + new TaskCommitMessage(addedAbsPathFiles.toMap -> partitionPaths.toSet) } override def abortTask(taskContext: TaskAttemptContext): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80cdc61484c0f..5d6edf6b8abec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1068,6 +1068,24 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(100) + object PartitionOverwriteMode extends Enumeration { + val STATIC, DYNAMIC = Value + } + + val PARTITION_OVERWRITE_MODE = + buildConf("spark.sql.sources.partitionOverwriteMode") + .doc("When INSERT OVERWRITE a partitioned data source table, we currently support 2 modes: " + + "static and dynamic. In static mode, Spark deletes all the partitions that match the " + + "partition specification(e.g. PARTITION(a=1,b)) in the INSERT statement, before " + + "overwriting. In dynamic mode, Spark doesn't delete partitions ahead, and only overwrite " + + "those partitions that have data written into it at runtime. By default we use static " + + "mode to keep the same behavior of Spark prior to 2.3. Note that this config doesn't " + + "affect Hive serde tables, as they are always overwritten with dynamic mode.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(PartitionOverwriteMode.values.map(_.toString)) + .createWithDefault(PartitionOverwriteMode.STATIC.toString) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1394,6 +1412,9 @@ class SQLConf extends Serializable with Logging { def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + def partitionOverwriteMode: PartitionOverwriteMode.Value = + PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index ad24e280d942a..dd7ef0d15c140 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.util.SchemaUtils /** @@ -89,13 +90,19 @@ case class InsertIntoHadoopFsRelationCommand( } val pathExists = fs.exists(qualifiedOutputPath) - // If we are appending data to an existing dir. - val isAppend = pathExists && (mode == SaveMode.Append) + + val enableDynamicOverwrite = + sparkSession.sessionState.conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + // This config only makes sense when we are overwriting a partitioned dataset with dynamic + // partition columns. + val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite && + staticPartitions.size < partitionColumns.length val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, jobId = java.util.UUID.randomUUID().toString, - outputPath = outputPath.toString) + outputPath = outputPath.toString, + dynamicPartitionOverwrite = dynamicPartitionOverwrite) val doInsertion = (mode, pathExists) match { case (SaveMode.ErrorIfExists, true) => @@ -103,6 +110,9 @@ case class InsertIntoHadoopFsRelationCommand( case (SaveMode.Overwrite, true) => if (ifPartitionNotExists && matchingPartitions.nonEmpty) { false + } else if (dynamicPartitionOverwrite) { + // For dynamic partition overwrite, do not delete partition directories ahead. + true } else { deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) true @@ -126,7 +136,9 @@ case class InsertIntoHadoopFsRelationCommand( catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), ifNotExists = true).run(sparkSession) } - if (mode == SaveMode.Overwrite) { + // For dynamic partition overwrite, we never remove partitions but only update existing + // ones. + if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) { val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions if (deletedPartitions.nonEmpty) { AlterTableDropPartitionCommand( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala index 40825a1f724b1..39c594a9bc618 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala @@ -29,11 +29,15 @@ import org.apache.spark.sql.internal.SQLConf * A variant of [[HadoopMapReduceCommitProtocol]] that allows specifying the actual * Hadoop output committer using an option specified in SQLConf. */ -class SQLHadoopMapReduceCommitProtocol(jobId: String, path: String) - extends HadoopMapReduceCommitProtocol(jobId, path) with Serializable with Logging { +class SQLHadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) + extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite) + with Serializable with Logging { override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { - var committer = context.getOutputFormatClass.newInstance().getOutputCommitter(context) + var committer = super.setupCommitter(context) val configuration = context.getConfiguration val clazz = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 8b7e2e5f45946..fef01c860db6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -21,6 +21,8 @@ import java.io.File import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -442,4 +444,80 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { assert(e.contains("Only Data Sources providing FileFormat are supported")) } } + + test("SPARK-20236: dynamic partition overwrite without catalog table") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTempPath { path => + Seq((1, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 1, 1)) + + Seq((2, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1)) + + Seq((2, 2, 2)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + """ + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + sql("insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite with customer partition path") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + """ + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + val path1 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=1) location '$path1'") + sql(s"insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + val path2 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=2) location '$path2'") + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } } From 9a2b65a3c0c36316aae0a53aa0f61c5044c2ceff Mon Sep 17 00:00:00 2001 From: chetkhatri Date: Wed, 3 Jan 2018 11:31:32 -0600 Subject: [PATCH 006/774] [SPARK-22896] Improvement in String interpolation ## What changes were proposed in this pull request? * String interpolation in ml pipeline example has been corrected as per scala standard. ## How was this patch tested? * manually tested. Author: chetkhatri Closes #20070 from chetkhatri/mllib-chetan-contrib. --- .../spark/examples/ml/JavaQuantileDiscretizerExample.java | 2 +- .../apache/spark/examples/SimpleSkewedGroupByTest.scala | 4 ---- .../org/apache/spark/examples/graphx/Analytics.scala | 6 ++++-- .../org/apache/spark/examples/graphx/SynthBenchmark.scala | 6 +++--- .../apache/spark/examples/ml/ChiSquareTestExample.scala | 6 +++--- .../org/apache/spark/examples/ml/CorrelationExample.scala | 4 ++-- .../org/apache/spark/examples/ml/DataFrameExample.scala | 4 ++-- .../examples/ml/DecisionTreeClassificationExample.scala | 4 ++-- .../spark/examples/ml/DecisionTreeRegressionExample.scala | 4 ++-- .../apache/spark/examples/ml/DeveloperApiExample.scala | 6 +++--- .../examples/ml/EstimatorTransformerParamExample.scala | 6 +++--- .../ml/GradientBoostedTreeClassifierExample.scala | 4 ++-- .../examples/ml/GradientBoostedTreeRegressorExample.scala | 4 ++-- ...ulticlassLogisticRegressionWithElasticNetExample.scala | 2 +- .../ml/MultilayerPerceptronClassifierExample.scala | 2 +- .../org/apache/spark/examples/ml/NaiveBayesExample.scala | 2 +- .../spark/examples/ml/QuantileDiscretizerExample.scala | 4 ++-- .../spark/examples/ml/RandomForestClassifierExample.scala | 4 ++-- .../spark/examples/ml/RandomForestRegressorExample.scala | 4 ++-- .../apache/spark/examples/ml/VectorIndexerExample.scala | 4 ++-- .../spark/examples/mllib/AssociationRulesExample.scala | 6 +++--- .../mllib/BinaryClassificationMetricsExample.scala | 4 ++-- .../mllib/DecisionTreeClassificationExample.scala | 4 ++-- .../examples/mllib/DecisionTreeRegressionExample.scala | 4 ++-- .../org/apache/spark/examples/mllib/FPGrowthExample.scala | 2 +- .../mllib/GradientBoostingClassificationExample.scala | 4 ++-- .../mllib/GradientBoostingRegressionExample.scala | 4 ++-- .../spark/examples/mllib/HypothesisTestingExample.scala | 2 +- .../spark/examples/mllib/IsotonicRegressionExample.scala | 2 +- .../org/apache/spark/examples/mllib/KMeansExample.scala | 2 +- .../org/apache/spark/examples/mllib/LBFGSExample.scala | 2 +- .../examples/mllib/LatentDirichletAllocationExample.scala | 8 +++++--- .../examples/mllib/LinearRegressionWithSGDExample.scala | 2 +- .../org/apache/spark/examples/mllib/PCAExample.scala | 4 ++-- .../spark/examples/mllib/PMMLModelExportExample.scala | 2 +- .../apache/spark/examples/mllib/PrefixSpanExample.scala | 4 ++-- .../mllib/RandomForestClassificationExample.scala | 4 ++-- .../examples/mllib/RandomForestRegressionExample.scala | 4 ++-- .../spark/examples/mllib/RecommendationExample.scala | 2 +- .../apache/spark/examples/mllib/SVMWithSGDExample.scala | 2 +- .../org/apache/spark/examples/mllib/SimpleFPGrowth.scala | 8 +++----- .../spark/examples/mllib/StratifiedSamplingExample.scala | 4 ++-- .../org/apache/spark/examples/mllib/TallSkinnyPCA.scala | 2 +- .../org/apache/spark/examples/mllib/TallSkinnySVD.scala | 2 +- .../apache/spark/examples/streaming/CustomReceiver.scala | 6 +++--- .../apache/spark/examples/streaming/RawNetworkGrep.scala | 2 +- .../examples/streaming/RecoverableNetworkWordCount.scala | 8 ++++---- .../streaming/clickstream/PageViewGenerator.scala | 4 ++-- .../examples/streaming/clickstream/PageViewStream.scala | 4 ++-- 49 files changed, 94 insertions(+), 96 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java index dd20cac621102..43cc30c1a899b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java @@ -66,7 +66,7 @@ public static void main(String[] args) { .setNumBuckets(3); Dataset result = discretizer.fit(df).transform(df); - result.show(); + result.show(false); // $example off$ spark.stop(); } diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index e64dcbd182d94..2332a661f26a0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -60,10 +60,6 @@ object SimpleSkewedGroupByTest { pairs1.count println(s"RESULT: ${pairs1.groupByKey(numReducers).count}") - // Print how many keys each reducer got (for debugging) - // println("RESULT: " + pairs1.groupByKey(numReducers) - // .map{case (k,v) => (k, v.size)} - // .collectAsMap) spark.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 92936bd30dbc0..815404d1218b7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -145,9 +145,11 @@ object Analytics extends Logging { // TriangleCount requires the graph to be partitioned .partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache() val triangles = TriangleCount.run(graph) - println("Triangles: " + triangles.vertices.map { + val triangleTypes = triangles.vertices.map { case (vid, data) => data.toLong - }.reduce(_ + _) / 3) + }.reduce(_ + _) / 3 + + println(s"Triangles: ${triangleTypes}") sc.stop() case _ => diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 6d2228c8742aa..57b2edf992208 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -52,7 +52,7 @@ object SynthBenchmark { arg => arg.dropWhile(_ == '-').split('=') match { case Array(opt, v) => (opt -> v) - case _ => throw new IllegalArgumentException("Invalid argument: " + arg) + case _ => throw new IllegalArgumentException(s"Invalid argument: $arg") } } @@ -76,7 +76,7 @@ object SynthBenchmark { case ("sigma", v) => sigma = v.toDouble case ("degFile", v) => degFile = v case ("seed", v) => seed = v.toInt - case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) + case (opt, _) => throw new IllegalArgumentException(s"Invalid option: $opt") } val conf = new SparkConf() @@ -86,7 +86,7 @@ object SynthBenchmark { val sc = new SparkContext(conf) // Create the graph - println(s"Creating graph...") + println("Creating graph...") val unpartitionedGraph = GraphGenerators.logNormalGraph(sc, numVertices, numEPart.getOrElse(sc.defaultParallelism), mu, sigma, seed) // Repartition the graph diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala index dcee1e427ce58..5146fd0316467 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala @@ -52,9 +52,9 @@ object ChiSquareTestExample { val df = data.toDF("label", "features") val chi = ChiSquareTest.test(df, "features", "label").head - println("pValues = " + chi.getAs[Vector](0)) - println("degreesOfFreedom = " + chi.getSeq[Int](1).mkString("[", ",", "]")) - println("statistics = " + chi.getAs[Vector](2)) + println(s"pValues = ${chi.getAs[Vector](0)}") + println(s"degreesOfFreedom ${chi.getSeq[Int](1).mkString("[", ",", "]")}") + println(s"statistics ${chi.getAs[Vector](2)}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala index 3f57dc342eb00..d7f1fc8ed74d7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala @@ -51,10 +51,10 @@ object CorrelationExample { val df = data.map(Tuple1.apply).toDF("features") val Row(coeff1: Matrix) = Correlation.corr(df, "features").head - println("Pearson correlation matrix:\n" + coeff1.toString) + println(s"Pearson correlation matrix:\n $coeff1") val Row(coeff2: Matrix) = Correlation.corr(df, "features", "spearman").head - println("Spearman correlation matrix:\n" + coeff2.toString) + println(s"Spearman correlation matrix:\n $coeff2") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index 0658bddf16961..ee4469faab3a0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -47,7 +47,7 @@ object DataFrameExample { val parser = new OptionParser[Params]("DataFrameExample") { head("DataFrameExample: an example app using DataFrame for ML.") opt[String]("input") - .text(s"input path to dataframe") + .text("input path to dataframe") .action((x, c) => c.copy(input = x)) checkConfig { params => success @@ -93,7 +93,7 @@ object DataFrameExample { // Load the records back. println(s"Loading Parquet file with UDT from $outputDir.") val newDF = spark.read.parquet(outputDir) - println(s"Schema from Parquet:") + println("Schema from Parquet:") newDF.printSchema() spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index bc6d3275933ea..276cedab11abc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -83,10 +83,10 @@ object DecisionTreeClassificationExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${(1.0 - accuracy)}") val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] - println("Learned classification tree model:\n" + treeModel.toDebugString) + println(s"Learned classification tree model:\n ${treeModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index ee61200ad1d0c..aaaecaea47081 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -73,10 +73,10 @@ object DecisionTreeRegressionExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] - println("Learned regression tree model:\n" + treeModel.toDebugString) + println(s"Learned regression tree model:\n ${treeModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index d94d837d10e96..2dc11b07d88ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -53,7 +53,7 @@ object DeveloperApiExample { // Create a LogisticRegression instance. This instance is an Estimator. val lr = new MyLogisticRegression() // Print out the parameters, documentation, and any default values. - println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n") + println(s"MyLogisticRegression parameters:\n ${lr.explainParams()}") // We may set parameters using setter methods. lr.setMaxIter(10) @@ -169,10 +169,10 @@ private class MyLogisticRegressionModel( Vectors.dense(-margin, margin) } - /** Number of classes the label can take. 2 indicates binary classification. */ + // Number of classes the label can take. 2 indicates binary classification. override val numClasses: Int = 2 - /** Number of features the model was trained on. */ + // Number of features the model was trained on. override val numFeatures: Int = coefficients.size /** diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala index f18d86e1a6921..e5d91f132a3f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala @@ -46,7 +46,7 @@ object EstimatorTransformerParamExample { // Create a LogisticRegression instance. This instance is an Estimator. val lr = new LogisticRegression() // Print out the parameters, documentation, and any default values. - println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + println(s"LogisticRegression parameters:\n ${lr.explainParams()}\n") // We may set parameters using setter methods. lr.setMaxIter(10) @@ -58,7 +58,7 @@ object EstimatorTransformerParamExample { // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. - println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) + println(s"Model 1 was fit using parameters: ${model1.parent.extractParamMap}") // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. @@ -73,7 +73,7 @@ object EstimatorTransformerParamExample { // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. val model2 = lr.fit(training, paramMapCombined) - println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) + println(s"Model 2 was fit using parameters: ${model2.parent.extractParamMap}") // Prepare test data. val test = spark.createDataFrame(Seq( diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala index 3656773c8b817..ef78c0a1145ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -86,10 +86,10 @@ object GradientBoostedTreeClassifierExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${1.0 - accuracy}") val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] - println("Learned classification GBT model:\n" + gbtModel.toDebugString) + println(s"Learned classification GBT model:\n ${gbtModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala index e53aab7f326d3..3feb2343f6a85 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala @@ -73,10 +73,10 @@ object GradientBoostedTreeRegressorExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] - println("Learned regression GBT model:\n" + gbtModel.toDebugString) + println(s"Learned regression GBT model:\n ${gbtModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala index 42f0ace7a353d..3e61dbe628c20 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala @@ -48,7 +48,7 @@ object MulticlassLogisticRegressionWithElasticNetExample { // Print the coefficients and intercept for multinomial logistic regression println(s"Coefficients: \n${lrModel.coefficientMatrix}") - println(s"Intercepts: ${lrModel.interceptVector}") + println(s"Intercepts: \n${lrModel.interceptVector}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index 6fce82d294f8d..646f46a925062 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -66,7 +66,7 @@ object MultilayerPerceptronClassifierExample { val evaluator = new MulticlassClassificationEvaluator() .setMetricName("accuracy") - println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels)) + println(s"Test set accuracy = ${evaluator.evaluate(predictionAndLabels)}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala index bd9fcc420a66c..50c70c626b128 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala @@ -52,7 +52,7 @@ object NaiveBayesExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test set accuracy = " + accuracy) + println(s"Test set accuracy = $accuracy") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala index aedb9e7d3bb70..0fe16fb6dfa9f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -36,7 +36,7 @@ object QuantileDiscretizerExample { // Output of QuantileDiscretizer for such small datasets can depend on the number of // partitions. Here we force a single partition to ensure consistent results. // Note this is not necessary for normal use cases - .repartition(1) + .repartition(1) // $example on$ val discretizer = new QuantileDiscretizer() @@ -45,7 +45,7 @@ object QuantileDiscretizerExample { .setNumBuckets(3) val result = discretizer.fit(df).transform(df) - result.show() + result.show(false) // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala index 5eafda8ce4285..6265f83902528 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala @@ -85,10 +85,10 @@ object RandomForestClassifierExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${(1.0 - accuracy)}") val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] - println("Learned classification forest model:\n" + rfModel.toDebugString) + println(s"Learned classification forest model:\n ${rfModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala index 9a0a001c26ef5..2679fcb353a8a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala @@ -72,10 +72,10 @@ object RandomForestRegressorExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel] - println("Learned regression forest model:\n" + rfModel.toDebugString) + println(s"Learned regression forest model:\n ${rfModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala index afa761aee0b98..96bb8ea2338af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -41,8 +41,8 @@ object VectorIndexerExample { val indexerModel = indexer.fit(data) val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet - println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) + println(s"Chose ${categoricalFeatures.size} " + + s"categorical features: ${categoricalFeatures.mkString(", ")}") // Create new column "indexed" with categorical values transformed to indices val indexedData = indexerModel.transform(data) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala index ff44de56839e5..a07535bb5a38d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala @@ -42,9 +42,8 @@ object AssociationRulesExample { val results = ar.run(freqItemsets) results.collect().foreach { rule => - println("[" + rule.antecedent.mkString(",") - + "=>" - + rule.consequent.mkString(",") + "]," + rule.confidence) + println(s"[${rule.antecedent.mkString(",")}=>${rule.consequent.mkString(",")} ]" + + s" ${rule.confidence}") } // $example off$ @@ -53,3 +52,4 @@ object AssociationRulesExample { } // scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala index b9263ac6fcff6..c6312d71cc912 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -86,7 +86,7 @@ object BinaryClassificationMetricsExample { // AUPRC val auPRC = metrics.areaUnderPR - println("Area under precision-recall curve = " + auPRC) + println(s"Area under precision-recall curve = $auPRC") // Compute thresholds used in ROC and PR curves val thresholds = precision.map(_._1) @@ -96,7 +96,7 @@ object BinaryClassificationMetricsExample { // AUROC val auROC = metrics.areaUnderROC - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // $example off$ sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala index b50b4592777ce..c2f89b72c9a2e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala @@ -55,8 +55,8 @@ object DecisionTreeClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification tree model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification tree model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myDecisionTreeClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala index 2af45afae3d5b..1ecf6426e1f95 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala @@ -54,8 +54,8 @@ object DecisionTreeRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression tree model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression tree model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myDecisionTreeRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 6435abc127752..f724ee1030f04 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -74,7 +74,7 @@ object FPGrowthExample { println(s"Number of frequent itemsets: ${model.freqItemsets.count()}") model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + println(s"${itemset.items.mkString("[", ",", "]")}, ${itemset.freq}") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala index 00bb3348d2a36..3c56e1941aeca 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala @@ -54,8 +54,8 @@ object GradientBoostingClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification GBT model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification GBT model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myGradientBoostingClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala index d8c263460839b..c288bf29bf255 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala @@ -53,8 +53,8 @@ object GradientBoostingRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression GBT model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression GBT model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myGradientBoostingRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala index 0d391a3637c07..add1719739539 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala @@ -68,7 +68,7 @@ object HypothesisTestingExample { // against the label. val featureTestResults: Array[ChiSqTestResult] = Statistics.chiSqTest(obs) featureTestResults.zipWithIndex.foreach { case (k, v) => - println("Column " + (v + 1).toString + ":") + println(s"Column ${(v + 1)} :") println(k) } // summary of the test // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala index 4aee951f5b04c..a10d6f0dda880 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -56,7 +56,7 @@ object IsotonicRegressionExample { // Calculate mean squared error between predicted and real labels. val meanSquaredError = predictionAndLabel.map { case (p, l) => math.pow((p - l), 2) }.mean() - println("Mean Squared Error = " + meanSquaredError) + println(s"Mean Squared Error = $meanSquaredError") // Save and load model model.save(sc, "target/tmp/myIsotonicRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala index c4d71d862f375..b0a6f1671a898 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala @@ -43,7 +43,7 @@ object KMeansExample { // Evaluate clustering by computing Within Set Sum of Squared Errors val WSSSE = clusters.computeCost(parsedData) - println("Within Set Sum of Squared Errors = " + WSSSE) + println(s"Within Set Sum of Squared Errors = $WSSSE") // Save and load model clusters.save(sc, "target/org/apache/spark/KMeansExample/KMeansModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala index fedcefa098381..123782fa6b9cf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala @@ -82,7 +82,7 @@ object LBFGSExample { println("Loss of each step in training process") loss.foreach(println) - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // $example off$ sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala index f2c8ec01439f1..d25962c5500ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala @@ -42,11 +42,13 @@ object LatentDirichletAllocationExample { val ldaModel = new LDA().setK(3).run(corpus) // Output topics. Each is a distribution over words (matching word count vectors) - println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):") + println(s"Learned topics (as distributions over vocab of ${ldaModel.vocabSize} words):") val topics = ldaModel.topicsMatrix for (topic <- Range(0, 3)) { - print("Topic " + topic + ":") - for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } + print(s"Topic $topic :") + for (word <- Range(0, ldaModel.vocabSize)) { + print(s"${topics(word, topic)}") + } println() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala index d399618094487..449b725d1d173 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala @@ -52,7 +52,7 @@ object LinearRegressionWithSGDExample { (point.label, prediction) } val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2) }.mean() - println("training Mean Squared Error = " + MSE) + println(s"training Mean Squared Error $MSE") // Save and load model model.save(sc, "target/tmp/scalaLinearRegressionWithSGDModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala index eb36697d94ba1..eff2393cc3abe 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala @@ -65,8 +65,8 @@ object PCAExample { val MSE = valuesAndPreds.map { case (v, p) => math.pow((v - p), 2) }.mean() val MSE_pca = valuesAndPreds_pca.map { case (v, p) => math.pow((v - p), 2) }.mean() - println("Mean Squared Error = " + MSE) - println("PCA Mean Squared Error = " + MSE_pca) + println(s"Mean Squared Error = $MSE") + println(s"PCA Mean Squared Error = $MSE_pca") // $example off$ sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala index d74d74a37fb11..96deafd469bc7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala @@ -41,7 +41,7 @@ object PMMLModelExportExample { val clusters = KMeans.train(parsedData, numClusters, numIterations) // Export to PMML to a String in PMML format - println("PMML Model:\n" + clusters.toPMML) + println(s"PMML Model:\n ${clusters.toPMML}") // Export the model to a local file in PMML format clusters.toPMML("/tmp/kmeans.xml") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala index 69c72c4336576..8b789277774af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala @@ -42,8 +42,8 @@ object PrefixSpanExample { val model = prefixSpan.run(sequences) model.freqSequences.collect().foreach { freqSequence => println( - freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]") + - ", " + freqSequence.freq) + s"${freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}," + + s" ${freqSequence.freq}") } // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala index f1ebdf1a733ed..246e71de25615 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala @@ -55,8 +55,8 @@ object RandomForestClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification forest model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification forest model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myRandomForestClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala index 11d612e651b4b..770e30276bc30 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala @@ -55,8 +55,8 @@ object RandomForestRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression forest model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression forest model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myRandomForestRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala index 6df742d737e70..0bb2b8c8c2b43 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala @@ -56,7 +56,7 @@ object RecommendationExample { val err = (r1 - r2) err * err }.mean() - println("Mean Squared Error = " + MSE) + println(s"Mean Squared Error = $MSE") // Save and load model model.save(sc, "target/tmp/myCollaborativeFilter") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala index b73fe9b2b3faa..285e2ce512639 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala @@ -57,7 +57,7 @@ object SVMWithSGDExample { val metrics = new BinaryClassificationMetrics(scoreAndLabels) val auROC = metrics.areaUnderROC() - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // Save and load model model.save(sc, "target/tmp/scalaSVMWithSGDModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala index b5c3033bcba09..694c3bb18b045 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala @@ -42,15 +42,13 @@ object SimpleFPGrowth { val model = fpg.run(transactions) model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + println(s"${itemset.items.mkString("[", ",", "]")},${itemset.freq}") } val minConfidence = 0.8 model.generateAssociationRules(minConfidence).collect().foreach { rule => - println( - rule.antecedent.mkString("[", ",", "]") - + " => " + rule.consequent .mkString("[", ",", "]") - + ", " + rule.confidence) + println(s"${rule.antecedent.mkString("[", ",", "]")}=> " + + s"${rule.consequent .mkString("[", ",", "]")},${rule.confidence}") } // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala index 16b074ef60699..3d41bef0af88c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala @@ -41,10 +41,10 @@ object StratifiedSamplingExample { val exactSample = data.sampleByKeyExact(withReplacement = false, fractions = fractions) // $example off$ - println("approxSample size is " + approxSample.collect().size.toString) + println(s"approxSample size is ${approxSample.collect().size}") approxSample.collect().foreach(println) - println("exactSample its size is " + exactSample.collect().size.toString) + println(s"exactSample its size is ${exactSample.collect().size}") exactSample.collect().foreach(println) sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 03bc675299c5a..071d341b81614 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -54,7 +54,7 @@ object TallSkinnyPCA { // Compute principal components. val pc = mat.computePrincipalComponents(mat.numCols().toInt) - println("Principal components are:\n" + pc) + println(s"Principal components are:\n $pc") sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 067e49b9599e7..8ae6de16d80e7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -54,7 +54,7 @@ object TallSkinnySVD { // Compute SVD. val svd = mat.computeSVD(mat.numCols().toInt) - println("Singular values are " + svd.s) + println(s"Singular values are ${svd.s}") sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 43044d01b1204..25c7bf2871972 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -82,9 +82,9 @@ class CustomReceiver(host: String, port: Int) var socket: Socket = null var userInput: String = null try { - logInfo("Connecting to " + host + ":" + port) + logInfo(s"Connecting to $host : $port") socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) + logInfo(s"Connected to $host : $port") val reader = new BufferedReader( new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)) userInput = reader.readLine() @@ -98,7 +98,7 @@ class CustomReceiver(host: String, port: Int) restart("Trying to connect again") } catch { case e: java.net.ConnectException => - restart("Error connecting to " + host + ":" + port, e) + restart(s"Error connecting to $host : $port", e) case t: Throwable => restart("Error receiving data", t) } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index 5322929d177b4..437ccf0898d7c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -54,7 +54,7 @@ object RawNetworkGrep { ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray val union = ssc.union(rawStreams) union.filter(_.contains("the")).count().foreachRDD(r => - println("Grep count: " + r.collect().mkString)) + println(s"Grep count: ${r.collect().mkString}")) ssc.start() ssc.awaitTermination() } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 49c0427321133..f018f3a26d2e9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -130,10 +130,10 @@ object RecoverableNetworkWordCount { true } }.collect().mkString("[", ", ", "]") - val output = "Counts at time " + time + " " + counts + val output = s"Counts at time $time $counts" println(output) - println("Dropped " + droppedWordsCounter.value + " word(s) totally") - println("Appending to " + outputFile.getAbsolutePath) + println(s"Dropped ${droppedWordsCounter.value} word(s) totally") + println(s"Appending to ${outputFile.getAbsolutePath}") Files.append(output + "\n", outputFile, Charset.defaultCharset()) } ssc @@ -141,7 +141,7 @@ object RecoverableNetworkWordCount { def main(args: Array[String]) { if (args.length != 4) { - System.err.println("Your arguments were " + args.mkString("[", ", ", "]")) + System.err.println(s"Your arguments were ${args.mkString("[", ", ", "]")}") System.err.println( """ |Usage: RecoverableNetworkWordCount diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 0ddd065f0db2b..2108bc63edea2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -90,13 +90,13 @@ object PageViewGenerator { val viewsPerSecond = args(1).toFloat val sleepDelayMs = (1000.0 / viewsPerSecond).toInt val listener = new ServerSocket(port) - println("Listening on port: " + port) + println(s"Listening on port: $port") while (true) { val socket = listener.accept() new Thread() { override def run(): Unit = { - println("Got client connected from: " + socket.getInetAddress) + println(s"Got client connected from: ${socket.getInetAddress}") val out = new PrintWriter(socket.getOutputStream(), true) while (true) { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 1ba093f57b32c..b8e7c7e9e9152 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -104,8 +104,8 @@ object PageViewStream { .foreachRDD((rdd, time) => rdd.join(userList) .map(_._2._2) .take(10) - .foreach(u => println("Saw user %s at time %s".format(u, time)))) - case _ => println("Invalid metric entered: " + metric) + .foreach(u => println(s"Saw user $u at time $time"))) + case _ => println(s"Invalid metric entered: $metric") } ssc.start() From b297029130735316e1ac1144dee44761a12bfba7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 07:28:53 +0800 Subject: [PATCH 007/774] [SPARK-20960][SQL] make ColumnVector public ## What changes were proposed in this pull request? move `ColumnVector` and related classes to `org.apache.spark.sql.vectorized`, and improve the document. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20116 from cloud-fan/column-vector. --- .../VectorizedParquetRecordReader.java | 7 ++- .../vectorized/ColumnVectorUtils.java | 2 + .../vectorized/MutableColumnarRow.java | 4 ++ .../vectorized/WritableColumnVector.java | 7 ++- .../vectorized/ArrowColumnVector.java | 62 +------------------ .../vectorized/ColumnVector.java | 31 ++++++---- .../vectorized/ColumnarArray.java | 7 +-- .../vectorized/ColumnarBatch.java | 34 +++------- .../vectorized/ColumnarRow.java | 7 +-- .../sql/execution/ColumnarBatchScan.scala | 4 +- .../aggregate/HashAggregateExec.scala | 2 +- .../VectorizedHashMapGenerator.scala | 3 +- .../sql/execution/arrow/ArrowConverters.scala | 2 +- .../columnar/InMemoryTableScanExec.scala | 1 + .../execution/datasources/FileScanRDD.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 2 +- .../execution/arrow/ArrowWriterSuite.scala | 2 +- .../vectorized/ArrowColumnVectorSuite.scala | 1 + .../vectorized/ColumnVectorSuite.scala | 2 +- .../vectorized/ColumnarBatchSuite.scala | 6 +- 20 files changed, 63 insertions(+), 125 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ArrowColumnVector.java (94%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnVector.java (79%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarArray.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarBatch.java (73%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarRow.java (96%) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 6c157e85d411f..cd745b1f0e4e3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,10 +31,10 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; -import org.apache.spark.sql.execution.vectorized.ColumnarBatch; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -248,7 +248,10 @@ public void enableReturningBatches() { * Advances to the next batch of rows. Returns false if there are no more. */ public boolean nextBatch() throws IOException { - columnarBatch.reset(); + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + columnarBatch.setNumRows(0); if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index bc62bc43484e5..b5cbe8e2839ba 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -28,6 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 06602c147dfe9..70057a9def6c0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -23,6 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5f6f125976e12..d2ae32b06f83b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -585,11 +586,11 @@ public final int appendArray(int length) { public final int appendStruct(boolean isNull) { if (isNull) { appendNull(); - for (ColumnVector c: childColumns) { + for (WritableColumnVector c: childColumns) { if (c.type instanceof StructType) { - ((WritableColumnVector) c).appendStruct(true); + c.appendStruct(true); } else { - ((WritableColumnVector) c).appendNull(); + c.appendNull(); } } } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index af5673e26a501..708333213f3f1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; @@ -34,11 +34,7 @@ public final class ArrowColumnVector extends ColumnVector { private ArrowColumnVector[] childColumns; private void ensureAccessible(int index) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index >= valueCount) { - throw new IndexOutOfBoundsException( - String.format("index: %d, valueCount: %d", index, valueCount)); - } + ensureAccessible(index, 1); } private void ensureAccessible(int index, int count) { @@ -64,20 +60,12 @@ public void close() { accessor.close(); } - // - // APIs dealing with nulls - // - @Override public boolean isNullAt(int rowId) { ensureAccessible(rowId); return accessor.isNullAt(rowId); } - // - // APIs dealing with Booleans - // - @Override public boolean getBoolean(int rowId) { ensureAccessible(rowId); @@ -94,10 +82,6 @@ public boolean[] getBooleans(int rowId, int count) { return array; } - // - // APIs dealing with Bytes - // - @Override public byte getByte(int rowId) { ensureAccessible(rowId); @@ -114,10 +98,6 @@ public byte[] getBytes(int rowId, int count) { return array; } - // - // APIs dealing with Shorts - // - @Override public short getShort(int rowId) { ensureAccessible(rowId); @@ -134,10 +114,6 @@ public short[] getShorts(int rowId, int count) { return array; } - // - // APIs dealing with Ints - // - @Override public int getInt(int rowId) { ensureAccessible(rowId); @@ -154,10 +130,6 @@ public int[] getInts(int rowId, int count) { return array; } - // - // APIs dealing with Longs - // - @Override public long getLong(int rowId) { ensureAccessible(rowId); @@ -174,10 +146,6 @@ public long[] getLongs(int rowId, int count) { return array; } - // - // APIs dealing with floats - // - @Override public float getFloat(int rowId) { ensureAccessible(rowId); @@ -194,10 +162,6 @@ public float[] getFloats(int rowId, int count) { return array; } - // - // APIs dealing with doubles - // - @Override public double getDouble(int rowId) { ensureAccessible(rowId); @@ -214,10 +178,6 @@ public double[] getDoubles(int rowId, int count) { return array; } - // - // APIs dealing with Arrays - // - @Override public int getArrayLength(int rowId) { ensureAccessible(rowId); @@ -230,45 +190,27 @@ public int getArrayOffset(int rowId) { return accessor.getArrayOffset(rowId); } - // - // APIs dealing with Decimals - // - @Override public Decimal getDecimal(int rowId, int precision, int scale) { ensureAccessible(rowId); return accessor.getDecimal(rowId, precision, scale); } - // - // APIs dealing with UTF8Strings - // - @Override public UTF8String getUTF8String(int rowId) { ensureAccessible(rowId); return accessor.getUTF8String(rowId); } - // - // APIs dealing with Binaries - // - @Override public byte[] getBinary(int rowId) { ensureAccessible(rowId); return accessor.getBinary(rowId); } - /** - * Returns the data for the underlying array. - */ @Override public ArrowColumnVector arrayData() { return childColumns[0]; } - /** - * Returns the ordinal's child data column. - */ @Override public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java similarity index 79% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index dc7c1269bedd9..d1196e1299fee 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; @@ -22,24 +22,31 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * This class represents in-memory values of a column and provides the main APIs to access the data. - * It supports all the types and contains get APIs as well as their batched versions. The batched - * versions are considered to be faster and preferable whenever possible. + * An interface representing in-memory columnar data in Spark. This interface defines the main APIs + * to access the data, as well as their batched versions. The batched versions are considered to be + * faster and preferable whenever possible. * - * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these - * columns have child columns. All of the data are stored in the child columns and the parent column - * only contains nullability. In the case of Arrays, the lengths and offsets are saved in the child - * column and are encoded identically to INTs. + * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values + * in this ColumnVector. * - * Maps are just a special case of a two field struct. + * ColumnVector supports all the data types including nested types. To handle nested types, + * ColumnVector can have children and is a tree structure. For struct type, it stores the actual + * data of each field in the corresponding child ColumnVector, and only stores null information in + * the parent ColumnVector. For array type, it stores the actual array elements in the child + * ColumnVector, and stores null information, array offsets and lengths in the parent ColumnVector. * - * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values - * in the current batch. + * ColumnVector is expected to be reused during the entire data loading process, to avoid allocating + * memory again and again. + * + * ColumnVector is meant to maximize CPU efficiency but not to minimize storage footprint. + * Implementations should prefer computing efficiency over storage efficiency when design the + * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage + * footprint is negligible. */ public abstract class ColumnVector implements AutoCloseable { /** - * Returns the data type of this column. + * Returns the data type of this column vector. */ public final DataType dataType() { return type; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index cbc39d1d0aec2..0d89a52e7a4fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; @@ -23,8 +23,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * Array abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Array abstraction in {@link ColumnVector}. */ public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from @@ -33,7 +32,7 @@ public final class ColumnarArray extends ArrayData { private final int offset; private final int length; - ColumnarArray(ColumnVector data, int offset, int length) { + public ColumnarArray(ColumnVector data, int offset, int length) { this.data = data; this.offset = offset; this.length = length; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java similarity index 73% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index a9d09aa679726..9ae1c6d9993f0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -14,26 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import java.util.*; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; import org.apache.spark.sql.types.StructType; /** - * This class is the in memory representation of rows as they are streamed through operators. It - * is designed to maximize CPU efficiency and not storage footprint. Since it is expected that - * each operator allocates one of these objects, the storage footprint on the task is negligible. - * - * The layout is a columnar with values encoded in their native format. Each RowBatch contains - * a horizontal partitioning of the data, split into columns. - * - * The ColumnarBatch supports either on heap or offheap modes with (mostly) the identical API. - * - * TODO: - * - There are many TODOs for the existing APIs. They should throw a not implemented exception. - * - Compaction: The batch and columns should be able to compact based on a selection vector. + * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this + * batch so that Spark can access the data row by row. Instance of it is meant to be reused during + * the entire data loading process. */ public final class ColumnarBatch { public static final int DEFAULT_BATCH_SIZE = 4 * 1024; @@ -57,7 +49,7 @@ public void close() { } /** - * Returns an iterator over the rows in this batch. This skips rows that are filtered out. + * Returns an iterator over the rows in this batch. */ public Iterator rowIterator() { final int maxRows = numRows; @@ -87,19 +79,7 @@ public void remove() { } /** - * Resets the batch for writing. - */ - public void reset() { - for (int i = 0; i < numCols(); ++i) { - if (columns[i] instanceof WritableColumnVector) { - ((WritableColumnVector) columns[i]).reset(); - } - } - this.numRows = 0; - } - - /** - * Sets the number of rows that are valid. + * Sets the number of rows in this batch. */ public void setNumRows(int numRows) { assert(numRows <= this.capacity); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 8bb33ed5b78c0..3c6656dec77cd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; @@ -24,8 +24,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * Row abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Row abstraction in {@link ColumnVector}. */ public final class ColumnarRow extends InternalRow { // The data for this row. @@ -34,7 +33,7 @@ public final class ColumnarRow extends InternalRow { private final int rowId; private final int numFields; - ColumnarRow(ColumnVector data, int rowId) { + public ColumnarRow(ColumnVector data, int rowId) { assert (data.dataType() instanceof StructType); this.data = data; this.rowId = rowId; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 782cec5e292ba..5617046e1396e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** * Helper trait for abstracting scan functionality using - * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]es. + * [[ColumnarBatch]]es. */ private[sql] trait ColumnarBatchScan extends CodegenSupport { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9a6f1c6dfa6a9..ce3c68810f3b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.{ColumnarRow, MutableColumnarRow} +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0380ee8b09d63..0cf9b53ce1d5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, MutableColumnarRow, OnHeapColumnVector} +import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch /** * This is a helper class to generate an append-only vectorized hash map that can act as a 'cache' diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index bcfc412430263..bcd1aa0890ba3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -32,8 +32,8 @@ import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 3e73393b12850..933b9753faa61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} case class InMemoryTableScanExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 8731ee88f87f2..835ce98462477 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.ColumnarBatch +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.NextIterator /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5cc8ed3535654..dc5ba96e69aec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -30,8 +30,8 @@ import org.apache.spark._ import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index 508c116aae92e..c42bc60a59d67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.arrow import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.execution.vectorized.ArrowColumnVector import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 03490ad15a655..7304803a092c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -23,6 +23,7 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowColumnVectorSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 54b31cee031f6..944240f3bade5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow -import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.columnar.ColumnAccessor import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarArray import org.apache.spark.unsafe.types.UTF8String class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 7848ebdcab6d0..675f06b31b970 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval @@ -918,10 +919,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(it.hasNext == false) // Reset and add 3 rows - batch.reset() - assert(batch.numRows() == 0) - assert(batch.rowIterator().hasNext == false) - + columns.foreach(_.reset()) // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] columns(0).putNull(0) columns(1).putDouble(0, 2.2) From 7d045c5f00e2c7c67011830e2169a4e130c3ace8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 13:14:52 +0800 Subject: [PATCH 008/774] [SPARK-22944][SQL] improve FoldablePropagation ## What changes were proposed in this pull request? `FoldablePropagation` is a little tricky as it needs to handle attributes that are miss-derived from children, e.g. outer join outputs. This rule does a kind of stop-able tree transform, to skip to apply this rule when hit a node which may have miss-derived attributes. Logically we should be able to apply this rule above the unsupported nodes, by just treating the unsupported nodes as leaf nodes. This PR improves this rule to not stop the tree transformation, but reduce the foldable expressions that we want to propagate. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #20139 from cloud-fan/foldable. --- .../sql/catalyst/optimizer/expressions.scala | 65 +++++++++++-------- .../optimizer/FoldablePropagationSuite.scala | 23 ++++++- 2 files changed, 58 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 7d830bbb7dc32..1c0b7bd806801 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -506,18 +506,21 @@ object NullPropagation extends Rule[LogicalPlan] { /** - * Propagate foldable expressions: * Replace attributes with aliases of the original foldable expressions if possible. - * Other optimizations will take advantage of the propagated foldable expressions. - * + * Other optimizations will take advantage of the propagated foldable expressions. For example, + * this rule can optimize * {{{ * SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3 - * ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() * }}} + * to + * {{{ + * SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() + * }}} + * and other rules can further optimize it and remove the ORDER BY operator. */ object FoldablePropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - val foldableMap = AttributeMap(plan.flatMap { + var foldableMap = AttributeMap(plan.flatMap { case Project(projectList, _) => projectList.collect { case a: Alias if a.child.foldable => (a.toAttribute, a) } @@ -530,38 +533,44 @@ object FoldablePropagation extends Rule[LogicalPlan] { if (foldableMap.isEmpty) { plan } else { - var stop = false CleanupAliases(plan.transformUp { - // A leaf node should not stop the folding process (note that we are traversing up the - // tree, starting at the leaf nodes); so we are allowing it. - case l: LeafNode => - l - // We can only propagate foldables for a subset of unary nodes. - case u: UnaryNode if !stop && canPropagateFoldables(u) => + case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) => u.transformExpressions(replaceFoldable) - // Allow inner joins. We do not allow outer join, although its output attributes are - // derived from its children, they are actually different attributes: the output of outer - // join is not always picked from its children, but can also be null. + // Join derives the output attributes from its child while they are actually not the + // same attributes. For example, the output of outer join is not always picked from its + // children, but can also be null. We should exclude these miss-derived attributes when + // propagating the foldable expressions. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j @ Join(_, _, Inner, _) if !stop => - j.transformExpressions(replaceFoldable) - - // We can fold the projections an expand holds. However expand changes the output columns - // and often reuses the underlying attributes; so we cannot assume that a column is still - // foldable after the expand has been applied. - // TODO(hvanhovell): Expand should use new attributes as the output attributes. - case expand: Expand if !stop => - val newExpand = expand.copy(projections = expand.projections.map { projection => + case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty => + val newJoin = j.transformExpressions(replaceFoldable) + val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match { + case _: InnerLike | LeftExistence(_) => Nil + case LeftOuter => right.output + case RightOuter => left.output + case FullOuter => left.output ++ right.output + }) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => missDerivedAttrsSet.contains(attr) + }.toSeq) + newJoin + + // We can not replace the attributes in `Expand.output`. If there are other non-leaf + // operators that have the `output` field, we should put them here too. + case expand: Expand if foldableMap.nonEmpty => + expand.copy(projections = expand.projections.map { projection => projection.map(_.transform(replaceFoldable)) }) - stop = true - newExpand - case other => - stop = true + // For other plans, they are not safe to apply foldable propagation, and they should not + // propagate foldable expressions from children. + case other if foldableMap.nonEmpty => + val childrenOutputSet = AttributeSet(other.children.flatMap(_.output)) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => childrenOutputSet.contains(attr) + }.toSeq) other }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index dccb32f0379a8..c28844642aed0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -147,8 +147,8 @@ class FoldablePropagationSuite extends PlanTest { test("Propagate in expand") { val c1 = Literal(1).as('a) val c2 = Literal(2).as('b) - val a1 = c1.toAttribute.withNullability(true) - val a2 = c2.toAttribute.withNullability(true) + val a1 = c1.toAttribute.newInstance().withNullability(true) + val a2 = c2.toAttribute.newInstance().withNullability(true) val expand = Expand( Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))), Seq(a1, a2), @@ -161,4 +161,23 @@ class FoldablePropagationSuite extends PlanTest { val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze comparePlans(optimized, correctAnswer) } + + test("Propagate above outer join") { + val left = LocalRelation('a.int).select('a, Literal(1).as('b)) + val right = LocalRelation('c.int).select('c, Literal(1).as('d)) + + val join = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && 'b === 'd)) + val query = join.select(('b + 3).as('res)).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && Literal(1) === Literal(1))) + .select((Literal(1) + 3).as('res)).analyze + comparePlans(optimized, correctAnswer) + } } From df95a908baf78800556636a76d58bba9b3dd943f Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 3 Jan 2018 21:43:14 -0800 Subject: [PATCH 009/774] [SPARK-22933][SPARKR] R Structured Streaming API for withWatermark, trigger, partitionBy ## What changes were proposed in this pull request? R Structured Streaming API for withWatermark, trigger, partitionBy ## How was this patch tested? manual, unit tests Author: Felix Cheung Closes #20129 from felixcheung/rwater. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 96 +++++++++++++++- R/pkg/R/SQLContext.R | 4 +- R/pkg/R/generics.R | 6 + R/pkg/tests/fulltests/test_streaming.R | 107 ++++++++++++++++++ python/pyspark/sql/streaming.py | 4 + .../sql/execution/streaming/Triggers.scala | 2 +- 7 files changed, 214 insertions(+), 6 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 3219c6f0cc47b..c51eb0f39c4b1 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -179,6 +179,7 @@ exportMethods("arrange", "with", "withColumn", "withColumnRenamed", + "withWatermark", "write.df", "write.jdbc", "write.json", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fe238f6dd4eb0..9956f7eda91e6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3661,7 +3661,8 @@ setMethod("getNumPartitions", #' isStreaming #' #' Returns TRUE if this SparkDataFrame contains one or more sources that continuously return data -#' as it arrives. +#' as it arrives. A dataset that reads data from a streaming source must be executed as a +#' \code{StreamingQuery} using \code{write.stream}. #' #' @param x A SparkDataFrame #' @return TRUE if this SparkDataFrame is from a streaming source @@ -3707,7 +3708,17 @@ setMethod("isStreaming", #' @param df a streaming SparkDataFrame. #' @param source a name for external data source. #' @param outputMode one of 'append', 'complete', 'update'. -#' @param ... additional argument(s) passed to the method. +#' @param partitionBy a name or a list of names of columns to partition the output by on the file +#' system. If specified, the output is laid out on the file system similar to Hive's +#' partitioning scheme. +#' @param trigger.processingTime a processing time interval as a string, e.g. '5 seconds', +#' '1 minute'. This is a trigger that runs a query periodically based on the processing +#' time. If value is '0 seconds', the query will run as fast as possible, this is the +#' default. Only one trigger can be set. +#' @param trigger.once a logical, must be set to \code{TRUE}. This is a trigger that processes only +#' one batch of data in a streaming query then terminates the query. Only one trigger can be +#' set. +#' @param ... additional external data source specific named options. #' #' @family SparkDataFrame functions #' @seealso \link{read.stream} @@ -3725,7 +3736,8 @@ setMethod("isStreaming", #' # console #' q <- write.stream(wordCounts, "console", outputMode = "complete") #' # text stream -#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp" +#' partitionBy = c("year", "month"), trigger.processingTime = "30 seconds") #' # memory stream #' q <- write.stream(wordCounts, "memory", queryName = "outs", outputMode = "complete") #' head(sql("SELECT * from outs")) @@ -3737,7 +3749,8 @@ setMethod("isStreaming", #' @note experimental setMethod("write.stream", signature(df = "SparkDataFrame"), - function(df, source = NULL, outputMode = NULL, ...) { + function(df, source = NULL, outputMode = NULL, partitionBy = NULL, + trigger.processingTime = NULL, trigger.once = NULL, ...) { if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the data source specified ", "in 'spark.sql.sources.default' configuration by default.") @@ -3748,12 +3761,43 @@ setMethod("write.stream", if (is.null(source)) { source <- getDefaultSqlSource() } + cols <- NULL + if (!is.null(partitionBy)) { + if (!all(sapply(partitionBy, function(c) { is.character(c) }))) { + stop("All partitionBy column names should be characters.") + } + cols <- as.list(partitionBy) + } + jtrigger <- NULL + if (!is.null(trigger.processingTime) && !is.na(trigger.processingTime)) { + if (!is.null(trigger.once)) { + stop("Multiple triggers not allowed.") + } + interval <- as.character(trigger.processingTime) + if (nchar(interval) == 0) { + stop("Value for trigger.processingTime must be a non-empty string.") + } + jtrigger <- handledCallJStatic("org.apache.spark.sql.streaming.Trigger", + "ProcessingTime", + interval) + } else if (!is.null(trigger.once) && !is.na(trigger.once)) { + if (!is.logical(trigger.once) || !trigger.once) { + stop("Value for trigger.once must be TRUE.") + } + jtrigger <- callJStatic("org.apache.spark.sql.streaming.Trigger", "Once") + } options <- varargsToStrEnv(...) write <- handledCallJMethod(df@sdf, "writeStream") write <- callJMethod(write, "format", source) if (!is.null(outputMode)) { write <- callJMethod(write, "outputMode", outputMode) } + if (!is.null(cols)) { + write <- callJMethod(write, "partitionBy", cols) + } + if (!is.null(jtrigger)) { + write <- callJMethod(write, "trigger", jtrigger) + } write <- callJMethod(write, "options", options) ssq <- handledCallJMethod(write, "start") streamingQuery(ssq) @@ -3967,3 +4011,47 @@ setMethod("broadcast", sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf) dataFrame(sdf) }) + +#' withWatermark +#' +#' Defines an event time watermark for this streaming SparkDataFrame. A watermark tracks a point in +#' time before which we assume no more late data is going to arrive. +#' +#' Spark will use this watermark for several purposes: +#' \itemize{ +#' \item{-} To know when a given time window aggregation can be finalized and thus can be emitted +#' when using output modes that do not allow updates. +#' \item{-} To minimize the amount of state that we need to keep for on-going aggregations. +#' } +#' The current watermark is computed by looking at the \code{MAX(eventTime)} seen across +#' all of the partitions in the query minus a user specified \code{delayThreshold}. Due to the cost +#' of coordinating this value across partitions, the actual watermark used is only guaranteed +#' to be at least \code{delayThreshold} behind the actual event time. In some cases we may still +#' process records that arrive more than \code{delayThreshold} late. +#' +#' @param x a streaming SparkDataFrame +#' @param eventTime a string specifying the name of the Column that contains the event time of the +#' row. +#' @param delayThreshold a string specifying the minimum delay to wait to data to arrive late, +#' relative to the latest record that has been processed in the form of an +#' interval (e.g. "1 minute" or "5 hours"). NOTE: This should not be negative. +#' @return a SparkDataFrame. +#' @aliases withWatermark,SparkDataFrame,character,character-method +#' @family SparkDataFrame functions +#' @rdname withWatermark +#' @name withWatermark +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' schema <- structType(structField("time", "timestamp"), structField("value", "double")) +#' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' df <- withWatermark(df, "time", "10 minutes") +#' } +#' @note withWatermark since 2.3.0 +setMethod("withWatermark", + signature(x = "SparkDataFrame", eventTime = "character", delayThreshold = "character"), + function(x, eventTime, delayThreshold) { + sdf <- callJMethod(x@sdf, "withWatermark", eventTime, delayThreshold) + dataFrame(sdf) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 3b7f71bbbffb8..9d0a2d5e074e4 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -727,7 +727,9 @@ read.jdbc <- function(url, tableName, #' @param schema The data schema defined in structType or a DDL-formatted string, this is #' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for -#' file-based streaming data source +#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to +#' parse timestamps in the JSON/CSV data sources or partition values; If it isn't set, it +#' uses the default value, session local timezone. #' @return SparkDataFrame #' @rdname read.stream #' @name read.stream diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5369c32544e5e..e0dde3339fabc 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -799,6 +799,12 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) +#' @rdname withWatermark +#' @export +setGeneric("withWatermark", function(x, eventTime, delayThreshold) { + standardGeneric("withWatermark") +}) + #' @rdname write.df #' @export setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") }) diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index 54f40bbd5f517..a354d50c6b54e 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -172,6 +172,113 @@ test_that("Terminated by error", { stopQuery(q) }) +test_that("PartitionBy", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + checkpointPath <- tempfile(pattern = "sparkr-test", fileext = ".checkpoint") + textPath <- tempfile(pattern = "sparkr-test", fileext = ".text") + df <- read.df(jsonPath, "json", stringSchema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = stringSchema) + + expect_error(write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = c(1, 2)), + "All partitionBy column names should be characters") + + q <- write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = "name") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + dirs <- list.files(textPath) + expect_equal(length(dirs[substring(dirs, 1, nchar("name=")) == "name="]), 3) + + unlink(checkpointPath) + unlink(textPath) + unlink(parquetPath) +}) + +test_that("Watermark", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + t <- Sys.time() + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + df <- withColumn(df, "eventTime", cast(df$value, "timestamp")) + df <- withWatermark(df, "eventTime", "10 seconds") + counts <- count(group_by(df, "eventTime")) + q <- write.stream(counts, "memory", queryName = "times", outputMode = "append") + + # first events + df <- as.DataFrame(lapply(list(t + 1, t, t + 2), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # advance watermark to 15 + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # old events, should be dropped + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # evict events less than previous watermark + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + times <- collect(sql("SELECT * FROM times")) + # looks like write timing can affect the first bucket; but it should be t + expect_equal(times[order(times$eventTime),][1, 2], 2) + + stopQuery(q) + unlink(parquetPath) +}) + +test_that("Trigger", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "", trigger.once = ""), "Multiple triggers not allowed.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = ""), + "Value for trigger.processingTime must be a non-empty string.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "invalid"), "illegal argument") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = ""), "Value for trigger.once must be TRUE.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = FALSE), "Value for trigger.once must be TRUE.") + + q <- write.stream(df, "memory", queryName = "times", outputMode = "append", trigger.once = TRUE) + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + expect_equal(nrow(collect(sql("SELECT * FROM times"))), 1) + + stopQuery(q) + unlink(parquetPath) +}) + unlink(jsonPath) unlink(jsonPathNa) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fb228f99ba7ab..24ae3776a217b 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -793,6 +793,10 @@ def trigger(self, processingTime=None, once=None): .. note:: Evolving. :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. + Set a trigger that runs a query periodically based on the processing + time. Only one trigger can be set. + :param once: if set to True, set a trigger that processes only one batch of data in a + streaming query then terminates the query. Only one trigger can be set. >>> # trigger the query for execution every 5 seconds >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index 271bc4da99c08..19e3e55cb2829 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.streaming.Trigger /** - * A [[Trigger]] that process only one batch of data in a streaming query then terminates + * A [[Trigger]] that processes only one batch of data in a streaming query then terminates * the query. */ @Experimental From 9fa703e89318922393bae03c0db4575f4f4b4c56 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 4 Jan 2018 19:10:10 +0800 Subject: [PATCH 010/774] [SPARK-22950][SQL] Handle ChildFirstURLClassLoader's parent ## What changes were proposed in this pull request? ChildFirstClassLoader's parent is set to null, so we can't get jars from its parent. This will cause ClassNotFoundException during HiveClient initialization with builtin hive jars, where we may should use spark context loader instead. ## How was this patch tested? add new ut cc cloud-fan gatorsmile Author: Kent Yao Closes #20145 from yaooqinn/SPARK-22950. --- .../org/apache/spark/sql/hive/HiveUtils.scala | 4 +++- .../spark/sql/hive/HiveUtilsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c489690af8cd1..c7717d70c996f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.internal.StaticSQLConf.{CATALOG_IMPLEMENTATION, WAREHOUSE_PATH} import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ChildFirstURLClassLoader, Utils} private[spark] object HiveUtils extends Logging { @@ -312,6 +312,8 @@ private[spark] object HiveUtils extends Logging { // starting from the given classLoader. def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { case null => Array.empty[URL] + case childFirst: ChildFirstURLClassLoader => + childFirst.getURLs() ++ allJars(Utils.getSparkClassLoader) case urlClassLoader: URLClassLoader => urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) case other => allJars(other.getParent) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index fdbfcf1a68440..8697d47e89e89 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.hive +import java.net.URL + import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader} class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -42,4 +47,19 @@ class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton assert(hiveConf("foo") === "bar") } } + + test("ChildFirstURLClassLoader's parent is null, get spark classloader instead") { + val conf = new SparkConf + val contextClassLoader = Thread.currentThread().getContextClassLoader + val loader = new ChildFirstURLClassLoader(Array(), contextClassLoader) + try { + Thread.currentThread().setContextClassLoader(loader) + HiveUtils.newClientForMetadata( + conf, + SparkHadoopUtil.newConfiguration(conf), + HiveUtils.newTemporaryConfiguration(useInMemoryDerby = true)) + } finally { + Thread.currentThread().setContextClassLoader(contextClassLoader) + } + } } From d5861aba9d80ca15ad3f22793b79822e470d6913 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 19:17:22 +0800 Subject: [PATCH 011/774] [SPARK-22945][SQL] add java UDF APIs in the functions object ## What changes were proposed in this pull request? Currently Scala users can use UDF like ``` val foo = udf((i: Int) => Math.random() + i).asNondeterministic df.select(foo('a)) ``` Python users can also do it with similar APIs. However Java users can't do it, we should add Java UDF APIs in the functions object. ## How was this patch tested? new tests Author: Wenchen Fan Closes #20141 from cloud-fan/udf. --- .../apache/spark/sql/UDFRegistration.scala | 90 ++--- .../sql/expressions/UserDefinedFunction.scala | 1 + .../org/apache/spark/sql/functions.scala | 313 ++++++++++++++---- .../apache/spark/sql/JavaDataFrameSuite.java | 11 + .../scala/org/apache/spark/sql/UDFSuite.scala | 12 +- 5 files changed, 315 insertions(+), 112 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index dc2468a721e41..f94baef39dfad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.lang.reflect.{ParameterizedType, Type} +import java.lang.reflect.ParameterizedType import scala.reflect.runtime.universe.TypeTag import scala.util.Try @@ -110,29 +110,29 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends /* register 0-22 were generated by this script - (0 to 22).map { x => + (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) - val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"}) println(s""" - /** - * Registers a deterministic Scala closure of ${x} arguments as user-defined function (UDF). - * @tparam RT return type of UDF. - * @since 1.3.0 - */ - def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - def builder(e: Seq[Expression]) = if (e.length == $x) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) - } else { - throw new AnalysisException("Invalid number of arguments for function " + name + - ". Expected: $x; Found: " + e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). + | * @tparam RT return type of UDF. + | * @since 1.3.0 + | */ + |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | def builder(e: Seq[Expression]) = if (e.length == $x) { + | ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + | } else { + | throw new AnalysisException("Invalid number of arguments for function " + name + + | ". Expected: $x; Found: " + e.length) + | } + | functionRegistry.createOrReplaceTempFunction(name, builder) + | val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) } (0 to 22).foreach { i => @@ -144,7 +144,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val funcCall = if (i == 0) "() => func" else "func" println(s""" |/** - | * Register a user-defined function with ${i} arguments. + | * Register a deterministic Java UDF$i instance as user-defined function (UDF). | * @since $version | */ |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { @@ -689,7 +689,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 0 arguments. + * Register a deterministic Java UDF0 instance as user-defined function (UDF). * @since 2.3.0 */ def register(name: String, f: UDF0[_], returnType: DataType): Unit = { @@ -704,7 +704,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 1 arguments. + * Register a deterministic Java UDF1 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { @@ -719,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 2 arguments. + * Register a deterministic Java UDF2 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { @@ -734,7 +734,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 3 arguments. + * Register a deterministic Java UDF3 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { @@ -749,7 +749,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 4 arguments. + * Register a deterministic Java UDF4 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { @@ -764,7 +764,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 5 arguments. + * Register a deterministic Java UDF5 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { @@ -779,7 +779,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 6 arguments. + * Register a deterministic Java UDF6 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -794,7 +794,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 7 arguments. + * Register a deterministic Java UDF7 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -809,7 +809,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 8 arguments. + * Register a deterministic Java UDF8 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -824,7 +824,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 9 arguments. + * Register a deterministic Java UDF9 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -839,7 +839,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 10 arguments. + * Register a deterministic Java UDF10 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -854,7 +854,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 11 arguments. + * Register a deterministic Java UDF11 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -869,7 +869,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 12 arguments. + * Register a deterministic Java UDF12 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -884,7 +884,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 13 arguments. + * Register a deterministic Java UDF13 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -899,7 +899,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 14 arguments. + * Register a deterministic Java UDF14 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -914,7 +914,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 15 arguments. + * Register a deterministic Java UDF15 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -929,7 +929,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 16 arguments. + * Register a deterministic Java UDF16 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -944,7 +944,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 17 arguments. + * Register a deterministic Java UDF17 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -959,7 +959,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 18 arguments. + * Register a deterministic Java UDF18 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -974,7 +974,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 19 arguments. + * Register a deterministic Java UDF19 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -989,7 +989,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 20 arguments. + * Register a deterministic Java UDF20 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1004,7 +1004,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 21 arguments. + * Register a deterministic Java UDF21 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1019,7 +1019,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 22 arguments. + * Register a deterministic Java UDF22 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 03b654f830520..40a058d2cadd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -66,6 +66,7 @@ case class UserDefinedFunction protected[sql] ( * * @since 1.3.0 */ + @scala.annotation.varargs def apply(exprs: Column*): Column = { Column(ScalaUDF( f, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 530a525a01dec..0d11682d80a3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,6 +24,7 @@ import scala.util.Try import scala.util.control.NonFatal import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -32,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -3254,42 +3254,66 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } - ////////////////////////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////////////////////////// - // scalastyle:off line.size.limit // scalastyle:off parameter.number /* Use the following code to generate: - (0 to 10).map { x => + + (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) println(s""" - /** - * Defines a deterministic user-defined function of ${x} arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. - * - * @group udf_funcs - * @since 1.3.0 - */ - def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Defines a Scala closure of $x arguments as user-defined function (UDF). + | * The data types are automatically inferred based on the Scala closure's + | * signature. By default the returned UDF is deterministic. To change it to + | * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 1.3.0 + | */ + |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | val udf = UserDefinedFunction(f, dataType, inputTypes) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) + } + + (0 to 10).foreach { i => + val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") + val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ") + val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]" + val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") + val funcCall = if (i == 0) "() => func" else "func" + println(s""" + |/** + | * Defines a Java UDF$i instance as user-defined function (UDF). + | * The caller must specify the output data type, and there is no automatic input type coercion. + | * By default the returned UDF is deterministic. To change it to nondeterministic, call the + | * API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 2.3.0 + | */ + |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { + | val func = f$anyCast.call($anyParams) + | UserDefinedFunction($funcCall, returnType, inputTypes = None) + |}""".stripMargin) } */ + ////////////////////////////////////////////////////////////////////////////////////////////// + // Scala UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + /** - * Defines a deterministic user-defined function of 0 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 0 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3302,10 +3326,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 1 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3318,10 +3342,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 2 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3334,10 +3358,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 3 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 3 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3350,10 +3374,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 4 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3366,10 +3390,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 5 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3382,10 +3406,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 6 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 6 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3398,10 +3422,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 7 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 7 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3414,10 +3438,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 8 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 8 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3430,10 +3454,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 9 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 9 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3446,10 +3470,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 10 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 10 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3461,13 +3485,172 @@ object functions { if (nullable) udf else udf.asNonNullable() } + ////////////////////////////////////////////////////////////////////////////////////////////// + // Java UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Defines a Java UDF0 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF0[Any]].call() + UserDefinedFunction(() => func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF1 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF2 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF3 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF4 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF5 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF6 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF7 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF8 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF9 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF10 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + // scalastyle:on parameter.number // scalastyle:on line.size.limit /** * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant, * the caller must specify the output data type, and there is no automatic input type coercion. - * To change a UDF to nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. * * @param f A closure in Scala * @param dataType The output data type of the UDF diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index b007093dad84b..4f8a31f185724 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -36,6 +36,7 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.expressions.UserDefinedFunction; import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.*; import org.apache.spark.util.sketch.BloomFilter; @@ -455,4 +456,14 @@ public void testCircularReferenceBean() { CircularReference1Bean bean = new CircularReference1Bean(); spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class); } + + @Test + public void testUDF() { + UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType); + Dataset df = spark.table("testData").select(foo.apply(col("key"), col("value"))); + String[] result = df.collectAsList().stream().map(row -> row.getString(0)).toArray(String[]::new); + String[] expected = spark.table("testData").collectAsList().stream() + .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new); + Assert.assertArrayEquals(expected, result); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 7f1c009ca6e7a..db37be68e42e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.command.ExplainCommand -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -import org.apache.spark.sql.types.DataTypes +import org.apache.spark.sql.types.{DataTypes, DoubleType} private case class FunctionResult(f1: String, f2: String) @@ -128,6 +129,13 @@ class UDFSuite extends QueryTest with SharedSQLContext { val df2 = testData.select(bar()) assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) assert(df2.head().getDouble(0) >= 0.0) + + val javaUdf = udf(new UDF0[Double] { + override def call(): Double = Math.random() + }, DoubleType).asNondeterministic() + val df3 = testData.select(javaUdf()) + assert(df3.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) + assert(df3.head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { From 5aadbc929cb194e06dbd3bab054a161569289af5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 4 Jan 2018 21:07:31 +0800 Subject: [PATCH 012/774] [SPARK-22939][PYSPARK] Support Spark UDF in registerFunction ## What changes were proposed in this pull request? ```Python import random from pyspark.sql.functions import udf from pyspark.sql.types import IntegerType, StringType random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() spark.catalog.registerFunction("random_udf", random_udf, StringType()) spark.sql("SELECT random_udf()").collect() ``` We will get the following error. ``` Py4JError: An error occurred while calling o29.__getnewargs__. Trace: py4j.Py4JException: Method __getnewargs__([]) does not exist at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318) at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326) at py4j.Gateway.invoke(Gateway.java:274) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:214) at java.lang.Thread.run(Thread.java:745) ``` This PR is to support it. ## How was this patch tested? WIP Author: gatorsmile Closes #20137 from gatorsmile/registerFunction. --- python/pyspark/sql/catalog.py | 27 +++++++++++++++---- python/pyspark/sql/context.py | 16 +++++++++--- python/pyspark/sql/tests.py | 49 +++++++++++++++++++++++++---------- python/pyspark/sql/udf.py | 21 ++++++++++----- 4 files changed, 84 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 659bc65701a0c..156603128d063 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -227,15 +227,15 @@ def dropGlobalTempView(self, viewName): @ignore_unicode_prefix @since(2.0) def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. + """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` + as a UDF. The registered UDF can be used in SQL statement. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. :param name: name of the UDF - :param f: python function + :param f: a Python function, or a wrapped/native UserDefinedFunction :param returnType: a :class:`pyspark.sql.types.DataType` object :return: a wrapped :class:`UserDefinedFunction` @@ -255,9 +255,26 @@ def registerFunction(self, name, f, returnType=StringType()): >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType, StringType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType()) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=u'82')] + >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP + [Row(random_udf()=u'62')] """ - udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) + + # This is to check whether the input function is a wrapped/native UserDefinedFunction + if hasattr(f, 'asNondeterministic'): + udf = UserDefinedFunction(f.func, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=f.deterministic) + else: + udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) self._jsparkSession.udf().registerPython(name, udf._judf) return udf._wrapped() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b1e723cdecef3..b8d86cc098e94 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -175,15 +175,15 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. + """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` + as a UDF. The registered UDF can be used in SQL statement. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. :param name: name of the UDF - :param f: python function + :param f: a Python function, or a wrapped/native UserDefinedFunction :param returnType: a :class:`pyspark.sql.types.DataType` object :return: a wrapped :class:`UserDefinedFunction` @@ -203,6 +203,16 @@ def registerFunction(self, name, f, returnType=StringType()): >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType, StringType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType()) + >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=u'82')] + >>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP + [Row(random_udf()=u'62')] """ return self.sparkSession.catalog.registerFunction(name, f, returnType) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 67bdb3d72d93b..6dc767f9ec46e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -378,6 +378,41 @@ def test_udf2(self): [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) + def test_udf3(self): + twoargs = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType()) + self.assertEqual(twoargs.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_nondeterministic_udf(self): + from pyspark.sql.functions import udf + import random + udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() + self.assertEqual(udf_random_col.deterministic, False) + df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) + udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) + [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() + self.assertEqual(row[0] + 10, row[1]) + + def test_nondeterministic_udf2(self): + import random + from pyspark.sql.functions import udf + random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() + self.assertEqual(random_udf.deterministic, False) + random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType()) + self.assertEqual(random_udf1.deterministic, False) + [row] = self.spark.sql("SELECT randInt()").collect() + self.assertEqual(row[0], "6") + [row] = self.spark.range(1).select(random_udf1()).collect() + self.assertEqual(row[0], "6") + [row] = self.spark.range(1).select(random_udf()).collect() + self.assertEqual(row[0], 6) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) + pydoc.render_doc(random_udf) + pydoc.render_doc(random_udf1) + def test_chained_udf(self): self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.spark.sql("SELECT double(1)").collect() @@ -435,15 +470,6 @@ def test_udf_with_array_type(self): self.assertEqual(list(range(3)), l1) self.assertEqual(1, l2) - def test_nondeterministic_udf(self): - from pyspark.sql.functions import udf - import random - udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() - df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) - udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) - [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() - self.assertEqual(row[0] + 10, row[1]) - def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) @@ -567,7 +593,6 @@ def test_read_multiple_orc_file(self): def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType sourceFile = udf(lambda path: path, StringType()) filePath = "python/test_support/sql/people1.json" row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() @@ -575,7 +600,6 @@ def test_udf_with_input_file_name(self): def test_udf_with_input_file_name_for_hadooprdd(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType def filename(path): return path @@ -635,7 +659,6 @@ def test_udf_with_string_return_type(self): def test_udf_shouldnt_accept_noncallable_object(self): from pyspark.sql.functions import UserDefinedFunction - from pyspark.sql.types import StringType non_callable = None self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) @@ -1299,7 +1322,6 @@ def test_between_function(self): df.filter(df.a.between(df.b, df.c)).collect()) def test_struct_type(self): - from pyspark.sql.types import StructType, StringType, StructField struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) struct2 = StructType([StructField("f1", StringType(), True), StructField("f2", StringType(), True, None)]) @@ -1368,7 +1390,6 @@ def test_parse_datatype_string(self): _parse_datatype_string("a INT, c DOUBLE")) def test_metadata_null(self): - from pyspark.sql.types import StructType, StringType, StructField schema = StructType([StructField("f1", StringType(), True, None), StructField("f2", StringType(), True, {'a': None})]) rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 54b5a8656e1c8..5e75eb6545333 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType): ) # Set the name of the UserDefinedFunction object to be the name of function f - udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType) + udf_obj = UserDefinedFunction( + f, returnType=returnType, name=None, evalType=evalType, deterministic=True) return udf_obj._wrapped() @@ -67,8 +68,10 @@ class UserDefinedFunction(object): .. versionadded:: 1.3 """ def __init__(self, func, - returnType=StringType(), name=None, - evalType=PythonEvalType.SQL_BATCHED_UDF): + returnType=StringType(), + name=None, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=True): if not callable(func): raise TypeError( "Invalid function: not a function or callable (__call__ is not defined): " @@ -92,7 +95,7 @@ def __init__(self, func, func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) self.evalType = evalType - self._deterministic = True + self.deterministic = deterministic @property def returnType(self): @@ -130,7 +133,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.evalType, self._deterministic) + self._name, wrapped_func, jdt, self.evalType, self.deterministic) return judf def __call__(self, *cols): @@ -138,6 +141,9 @@ def __call__(self, *cols): sc = SparkContext._active_spark_context return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + # This function is for improving the online help system in the interactive interpreter. + # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and + # argument annotation. (See: SPARK-19161) def _wrapped(self): """ Wrap this udf with a function and attach docstring from func @@ -162,7 +168,8 @@ def wrapper(*args): wrapper.func = self.func wrapper.returnType = self.returnType wrapper.evalType = self.evalType - wrapper.asNondeterministic = self.asNondeterministic + wrapper.deterministic = self.deterministic + wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped() return wrapper @@ -172,5 +179,5 @@ def asNondeterministic(self): .. versionadded:: 2.3 """ - self._deterministic = False + self.deterministic = False return self From 6f68316e98fad72b171df422566e1fc9a7bbfcde Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 4 Jan 2018 21:15:10 +0800 Subject: [PATCH 013/774] [SPARK-22771][SQL] Add a missing return statement in Concat.checkInputDataTypes ## What changes were proposed in this pull request? This pr is a follow-up to fix a bug left in #19977. ## How was this patch tested? Added tests in `StringExpressionsSuite`. Author: Takeshi Yamamuro Closes #20149 from maropu/SPARK-22771-FOLLOWUP. --- .../sql/catalyst/expressions/stringExpressions.scala | 2 +- .../expressions/StringExpressionsSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b0da55a4a961b..41dc762154a4c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -58,7 +58,7 @@ case class Concat(children: Seq[Expression]) extends Expression { } else { val childTypes = children.map(_.dataType) if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { - TypeCheckResult.TypeCheckFailure( + return TypeCheckResult.TypeCheckFailure( s"input to function $prettyName should have StringType or BinaryType, but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 54cde77176e27..97ddbeba2c5ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -51,6 +51,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Concat(strs.map(Literal.create(_, StringType))), strs.mkString, EmptyRow) } + test("SPARK-22771 Check Concat.checkInputDataTypes results") { + assert(Concat(Seq.empty[Expression]).checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a") :: Literal.create("b") :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a".getBytes) :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create(1) :: Literal.create(2) :: Nil) + .checkInputDataTypes().isFailure) + assert(Concat(Literal.create("a") :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isFailure) + } + test("concat_ws") { def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = { val inputExprs = inputs.map { From 93f92c0ed7442a4382e97254307309977ff676f8 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 4 Jan 2018 11:39:42 -0800 Subject: [PATCH 014/774] [SPARK-21475][CORE][2ND ATTEMPT] Change to use NIO's Files API for external shuffle service ## What changes were proposed in this pull request? This PR is the second attempt of #18684 , NIO's Files API doesn't override `skip` method for `InputStream`, so it will bring in performance issue (mentioned in #20119). But using `FileInputStream`/`FileOutputStream` will also bring in memory issue (https://dzone.com/articles/fileinputstream-fileoutputstream-considered-harmful), which is severe for long running external shuffle service. So here in this proposal, only fixing the external shuffle service related code. ## How was this patch tested? Existing tests. Author: jerryshao Closes #20144 from jerryshao/SPARK-21475-v2. --- .../apache/spark/network/buffer/FileSegmentManagedBuffer.java | 3 ++- .../apache/spark/network/shuffle/ShuffleIndexInformation.java | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index c20fab83c3460..8b8f9892847c3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -24,6 +24,7 @@ import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; +import java.nio.file.StandardOpenOption; import com.google.common.base.Objects; import com.google.common.io.ByteStreams; @@ -132,7 +133,7 @@ public Object convertToNetty() throws IOException { if (conf.lazyFileDescriptor()) { return new DefaultFileRegion(file, offset, length); } else { - FileChannel fileChannel = new FileInputStream(file).getChannel(); + FileChannel fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ); return new DefaultFileRegion(fileChannel, offset, length); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index eacf485344b76..386738ece51a6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -19,10 +19,10 @@ import java.io.DataInputStream; import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.LongBuffer; +import java.nio.file.Files; /** * Keeps the index information for a particular map output @@ -39,7 +39,7 @@ public ShuffleIndexInformation(File indexFile) throws IOException { offsets = buffer.asLongBuffer(); DataInputStream dis = null; try { - dis = new DataInputStream(new FileInputStream(indexFile)); + dis = new DataInputStream(Files.newInputStream(indexFile.toPath())); dis.readFully(buffer.array()); } finally { if (dis != null) { From d2cddc88eac32f26b18ec26bb59e85c6f09a8c88 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 16:19:00 -0600 Subject: [PATCH 015/774] [SPARK-22850][CORE] Ensure queued events are delivered to all event queues. The code in LiveListenerBus was queueing events before start in the queues themselves; so in situations like the following: bus.post(someEvent) bus.addToEventLogQueue(listener) bus.start() "someEvent" would not be delivered to "listener" if that was the first listener in the queue, because the queue wouldn't exist when the event was posted. This change buffers the events before starting the bus in the bus itself, so that they can be delivered to all registered queues when the bus is started. Also tweaked the unit tests to cover the behavior above. Author: Marcelo Vanzin Closes #20039 from vanzin/SPARK-22850. --- .../spark/scheduler/LiveListenerBus.scala | 45 ++++++++++++++++--- .../spark/scheduler/SparkListenerSuite.scala | 21 +++++---- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 23121402b1025..ba6387a8f08ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -62,6 +62,9 @@ private[spark] class LiveListenerBus(conf: SparkConf) { private val queues = new CopyOnWriteArrayList[AsyncEventQueue]() + // Visible for testing. + @volatile private[scheduler] var queuedEvents = new mutable.ListBuffer[SparkListenerEvent]() + /** Add a listener to queue shared by all non-internal listeners. */ def addToSharedQueue(listener: SparkListenerInterface): Unit = { addToQueue(listener, SHARED_QUEUE) @@ -125,13 +128,39 @@ private[spark] class LiveListenerBus(conf: SparkConf) { /** Post an event to all queues. */ def post(event: SparkListenerEvent): Unit = { - if (!stopped.get()) { - metrics.numEventsPosted.inc() - val it = queues.iterator() - while (it.hasNext()) { - it.next().post(event) + if (stopped.get()) { + return + } + + metrics.numEventsPosted.inc() + + // If the event buffer is null, it means the bus has been started and we can avoid + // synchronization and post events directly to the queues. This should be the most + // common case during the life of the bus. + if (queuedEvents == null) { + postToQueues(event) + return + } + + // Otherwise, need to synchronize to check whether the bus is started, to make sure the thread + // calling start() picks up the new event. + synchronized { + if (!started.get()) { + queuedEvents += event + return } } + + // If the bus was already started when the check above was made, just post directly to the + // queues. + postToQueues(event) + } + + private def postToQueues(event: SparkListenerEvent): Unit = { + val it = queues.iterator() + while (it.hasNext()) { + it.next().post(event) + } } /** @@ -149,7 +178,11 @@ private[spark] class LiveListenerBus(conf: SparkConf) { } this.sparkContext = sc - queues.asScala.foreach(_.start(sc)) + queues.asScala.foreach { q => + q.start(sc) + queuedEvents.foreach(q.post) + } + queuedEvents = null metricsSystem.registerSource(metrics) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 1beb36afa95f0..da6ecb82c7e42 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -48,7 +48,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match bus.metrics.metricRegistry.counter(s"queue.$SHARED_QUEUE.numDroppedEvents").getCount } - private def queueSize(bus: LiveListenerBus): Int = { + private def sharedQueueSize(bus: LiveListenerBus): Int = { bus.metrics.metricRegistry.getGauges().get(s"queue.$SHARED_QUEUE.size").getValue() .asInstanceOf[Int] } @@ -73,12 +73,11 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val conf = new SparkConf() val counter = new BasicJobCounter val bus = new LiveListenerBus(conf) - bus.addToSharedQueue(counter) // Metrics are initially empty. assert(bus.metrics.numEventsPosted.getCount === 0) assert(numDroppedEvents(bus) === 0) - assert(queueSize(bus) === 0) + assert(bus.queuedEvents.size === 0) assert(eventProcessingTimeCount(bus) === 0) // Post five events: @@ -87,7 +86,10 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Five messages should be marked as received and queued, but no messages should be posted to // listeners yet because the the listener bus hasn't been started. assert(bus.metrics.numEventsPosted.getCount === 5) - assert(queueSize(bus) === 5) + assert(bus.queuedEvents.size === 5) + + // Add the counter to the bus after messages have been queued for later delivery. + bus.addToSharedQueue(counter) assert(counter.count === 0) // Starting listener bus should flush all buffered events @@ -95,9 +97,12 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match Mockito.verify(mockMetricsSystem).registerSource(bus.metrics) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) - assert(queueSize(bus) === 0) + assert(sharedQueueSize(bus) === 0) assert(eventProcessingTimeCount(bus) === 5) + // After the bus is started, there should be no more queued events. + assert(bus.queuedEvents === null) + // After listener bus has stopped, posting events should not increment counter bus.stop() (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } @@ -188,18 +193,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Post a message to the listener bus and wait for processing to begin: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() - assert(queueSize(bus) === 0) + assert(sharedQueueSize(bus) === 0) assert(numDroppedEvents(bus) === 0) // If we post an additional message then it should remain in the queue because the listener is // busy processing the first event: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(queueSize(bus) === 1) + assert(sharedQueueSize(bus) === 1) assert(numDroppedEvents(bus) === 0) // The queue is now full, so any additional events posted to the listener will be dropped: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(queueSize(bus) === 1) + assert(sharedQueueSize(bus) === 1) assert(numDroppedEvents(bus) === 1) // Allow the the remaining events to be processed so we can stop the listener bus: From 95f9659abe8845f9f3f42fd7ababd79e55c52489 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 15:00:09 -0800 Subject: [PATCH 016/774] [SPARK-22948][K8S] Move SparkPodInitContainer to correct package. Author: Marcelo Vanzin Closes #20156 from vanzin/SPARK-22948. --- dev/sparktestsupport/modules.py | 2 +- .../spark/deploy/{rest => }/k8s/SparkPodInitContainer.scala | 2 +- .../deploy/{rest => }/k8s/SparkPodInitContainerSuite.scala | 2 +- .../docker/src/main/dockerfiles/init-container/Dockerfile | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/{rest => }/k8s/SparkPodInitContainer.scala (99%) rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/{rest => }/k8s/SparkPodInitContainerSuite.scala (98%) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index f834563da9dda..7164180a6a7b0 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -539,7 +539,7 @@ def __hash__(self): kubernetes = Module( name="kubernetes", dependencies=[], - source_file_regexes=["resource-managers/kubernetes/core"], + source_file_regexes=["resource-managers/kubernetes"], build_profile_flags=["-Pkubernetes"], sbt_test_goals=["kubernetes/test"] ) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala similarity index 99% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala index 4a4b628aedbbf..c0f08786b76a1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.rest.k8s +package org.apache.spark.deploy.k8s import java.io.File import java.util.concurrent.TimeUnit diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala similarity index 98% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala index 6c557ec4a7c9a..e0f29ecd0fb53 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.rest.k8s +package org.apache.spark.deploy.k8s import java.io.File import java.util.UUID diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index 055493188fcb7..047056ab2633b 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -21,4 +21,4 @@ FROM spark-base # command should be invoked from the top level directory of the Spark distribution. E.g.: # docker build -t spark-init:latest -f kubernetes/dockerfiles/init-container/Dockerfile . -ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.rest.k8s.SparkPodInitContainer" ] +ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.k8s.SparkPodInitContainer" ] From e288fc87a027ec1e1a21401d1f151df20dbfecf3 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 4 Jan 2018 15:35:20 -0800 Subject: [PATCH 017/774] [SPARK-22953][K8S] Avoids adding duplicated secret volumes when init-container is used ## What changes were proposed in this pull request? User-specified secrets are mounted into both the main container and init-container (when it is used) in a Spark driver/executor pod, using the `MountSecretsBootstrap`. Because `MountSecretsBootstrap` always adds new secret volumes for the secrets to the pod, the same secret volumes get added twice, one when mounting the secrets to the main container, and the other when mounting the secrets to the init-container. This PR fixes the issue by separating `MountSecretsBootstrap.mountSecrets` out into two methods: `addSecretVolumes` for adding secret volumes to a pod and `mountSecrets` for mounting secret volumes to a container, respectively. `addSecretVolumes` is only called once for each pod, whereas `mountSecrets` is called individually for the main container and the init-container (if it is used). Ref: https://github.com/apache-spark-on-k8s/spark/issues/594. ## How was this patch tested? Unit tested and manually tested. vanzin This replaces https://github.com/apache/spark/pull/20148. hex108 foxish kimoonkim Author: Yinan Li Closes #20159 from liyinan926/master. --- .../deploy/k8s/MountSecretsBootstrap.scala | 30 ++++++++++++------- .../k8s/submit/DriverConfigOrchestrator.scala | 16 +++++----- .../steps/BasicDriverConfigurationStep.scala | 2 +- .../submit/steps/DriverMountSecretsStep.scala | 4 +-- .../InitContainerMountSecretsStep.scala | 11 +++---- .../cluster/k8s/ExecutorPodFactory.scala | 6 ++-- .../k8s/{submit => }/SecretVolumeUtils.scala | 18 +++++------ .../BasicDriverConfigurationStepSuite.scala | 4 +-- .../steps/DriverMountSecretsStepSuite.scala | 4 +-- .../InitContainerMountSecretsStepSuite.scala | 7 +---- .../cluster/k8s/ExecutorPodFactorySuite.scala | 14 +++++---- 11 files changed, 61 insertions(+), 55 deletions(-) rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/{submit => }/SecretVolumeUtils.scala (71%) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala index 8286546ce0641..c35e7db51d407 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala @@ -24,26 +24,36 @@ import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBui private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { /** - * Mounts Kubernetes secrets as secret volumes into the given container in the given pod. + * Add new secret volumes for the secrets specified in secretNamesToMountPaths into the given pod. * * @param pod the pod into which the secret volumes are being added. - * @param container the container into which the secret volumes are being mounted. - * @return the updated pod and container with the secrets mounted. + * @return the updated pod with the secret volumes added. */ - def mountSecrets(pod: Pod, container: Container): (Pod, Container) = { + def addSecretVolumes(pod: Pod): Pod = { var podBuilder = new PodBuilder(pod) secretNamesToMountPaths.keys.foreach { name => podBuilder = podBuilder .editOrNewSpec() .addNewVolume() - .withName(secretVolumeName(name)) - .withNewSecret() - .withSecretName(name) - .endSecret() - .endVolume() + .withName(secretVolumeName(name)) + .withNewSecret() + .withSecretName(name) + .endSecret() + .endVolume() .endSpec() } + podBuilder.build() + } + + /** + * Mounts Kubernetes secret volumes of the secrets specified in secretNamesToMountPaths into the + * given container. + * + * @param container the container into which the secret volumes are being mounted. + * @return the updated container with the secrets mounted. + */ + def mountSecrets(container: Container): Container = { var containerBuilder = new ContainerBuilder(container) secretNamesToMountPaths.foreach { case (name, path) => containerBuilder = containerBuilder @@ -53,7 +63,7 @@ private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, .endVolumeMount() } - (podBuilder.build(), containerBuilder.build()) + containerBuilder.build() } private def secretVolumeName(secretName: String): String = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index 00c9c4ee49177..c9cc300d65569 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -127,6 +127,12 @@ private[spark] class DriverConfigOrchestrator( Nil } + val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { + Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) + } else { + Nil + } + val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) { val orchestrator = new InitContainerConfigOrchestrator( sparkJars, @@ -147,19 +153,13 @@ private[spark] class DriverConfigOrchestrator( Nil } - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - Seq( initialSubmissionStep, serviceBootstrapStep, kubernetesCredentialsStep) ++ dependencyResolutionStep ++ - initContainerBootstrapStep ++ - mountSecretsStep + mountSecretsStep ++ + initContainerBootstrapStep } private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index b7a69a7dfd472..eca46b84c6066 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -119,7 +119,7 @@ private[spark] class BasicDriverConfigurationStep( .endEnv() .addNewEnv() .withName(ENV_DRIVER_ARGS) - .withValue(appArgs.map(arg => "\"" + arg + "\"").mkString(" ")) + .withValue(appArgs.mkString(" ")) .endEnv() .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala index f872e0f4b65d1..91e9a9f211335 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala @@ -28,8 +28,8 @@ private[spark] class DriverMountSecretsStep( bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val (pod, container) = bootstrap.mountSecrets( - driverSpec.driverPod, driverSpec.driverContainer) + val pod = bootstrap.addSecretVolumes(driverSpec.driverPod) + val container = bootstrap.mountSecrets(driverSpec.driverContainer) driverSpec.copy( driverPod = pod, driverContainer = container diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala index c0e7bb20cce8c..0daa7b95e8aae 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala @@ -28,12 +28,9 @@ private[spark] class InitContainerMountSecretsStep( bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep { override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = { - val (driverPod, initContainer) = bootstrap.mountSecrets( - spec.driverPod, - spec.initContainer) - spec.copy( - driverPod = driverPod, - initContainer = initContainer - ) + // Mount the secret volumes given that the volumes have already been added to the driver pod + // when mounting the secrets into the main driver container. + val initContainer = bootstrap.mountSecrets(spec.initContainer) + spec.copy(initContainer = initContainer) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index ba5d891f4c77e..066d7e9f70ca5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -214,7 +214,7 @@ private[spark] class ExecutorPodFactory( val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = mountSecretsBootstrap.map { bootstrap => - bootstrap.mountSecrets(executorPod, containerWithLimitCores) + (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) }.getOrElse((executorPod, containerWithLimitCores)) val (bootstrappedPod, bootstrappedContainer) = @@ -227,7 +227,9 @@ private[spark] class ExecutorPodFactory( val (pod, mayBeSecretsMountedInitContainer) = initContainerMountSecretsBootstrap.map { bootstrap => - bootstrap.mountSecrets(podWithInitContainer.pod, podWithInitContainer.initContainer) + // Mount the secret volumes given that the volumes have already been added to the + // executor pod when mounting the secrets into the main executor container. + (podWithInitContainer.pod, bootstrap.mountSecrets(podWithInitContainer.initContainer)) }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer)) val bootstrappedPod = KubernetesUtils.appendInitContainer( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala similarity index 71% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala index 8388c16ded268..16780584a674a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit +package org.apache.spark.deploy.k8s import scala.collection.JavaConverters._ @@ -22,15 +22,15 @@ import io.fabric8.kubernetes.api.model.{Container, Pod} private[spark] object SecretVolumeUtils { - def podHasVolume(driverPod: Pod, volumeName: String): Boolean = { - driverPod.getSpec.getVolumes.asScala.exists(volume => volume.getName == volumeName) + def podHasVolume(pod: Pod, volumeName: String): Boolean = { + pod.getSpec.getVolumes.asScala.exists { volume => + volume.getName == volumeName + } } - def containerHasVolume( - driverContainer: Container, - volumeName: String, - mountPath: String): Boolean = { - driverContainer.getVolumeMounts.asScala.exists(volumeMount => - volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath) + def containerHasVolume(container: Container, volumeName: String, mountPath: String): Boolean = { + container.getVolumeMounts.asScala.exists { volumeMount => + volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath + } } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index e864c6a16eeb1..8ee629ac8ddc1 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -33,7 +33,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" private val APP_NAME = "spark-test" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2", "arg 3") + private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") private val CUSTOM_ANNOTATION_KEY = "customAnnotation" private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1" @@ -82,7 +82,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { assert(envs(ENV_SUBMIT_EXTRA_CLASSPATH) === "/opt/spark/spark-examples.jar") assert(envs(ENV_DRIVER_MEMORY) === "256M") assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) - assert(envs(ENV_DRIVER_ARGS) === "\"arg1\" \"arg2\" \"arg 3\"") + assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala index 9ec0cb55de5aa..960d0bda1d011 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.deploy.k8s.submit.steps import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.{KubernetesDriverSpec, SecretVolumeUtils} +import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec class DriverMountSecretsStepSuite extends SparkFunSuite { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala index eab4e17659456..7ac0bde80dfe6 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.deploy.k8s.submit.steps.initcontainer import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.SecretVolumeUtils +import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} class InitContainerMountSecretsStepSuite extends SparkFunSuite { @@ -44,12 +43,8 @@ class InitContainerMountSecretsStepSuite extends SparkFunSuite { val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap) val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer( baseInitContainerSpec) - - val podWithSecretsMounted = configuredInitContainerSpec.driverPod val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer - Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => - assert(SecretVolumeUtils.podHasVolume(podWithSecretsMounted, volumeName))) Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => assert(SecretVolumeUtils.containerHasVolume( initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH))) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 7121a802c69c1..884da8aabd880 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -25,7 +25,7 @@ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer, SecretVolumeUtils} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -165,17 +165,19 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef val factory = new ExecutorPodFactory( conf, - None, + Some(secretsBootstrap), Some(initContainerBootstrap), Some(secretsBootstrap)) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getVolumes.size() === 1) + assert(SecretVolumeUtils.podHasVolume(executor, "secret1-volume")) + assert(SecretVolumeUtils.containerHasVolume( + executor.getSpec.getContainers.get(0), "secret1-volume", "/var/secret1")) assert(executor.getSpec.getInitContainers.size() === 1) - assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0).getName - === "secret1-volume") - assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0) - .getMountPath === "/var/secret1") + assert(SecretVolumeUtils.containerHasVolume( + executor.getSpec.getInitContainers.get(0), "secret1-volume", "/var/secret1")) checkOwnerReferences(executor, driverPodUid) } From 0428368c2c5e135f99f62be20877bbbda43be310 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 16:34:56 -0800 Subject: [PATCH 018/774] [SPARK-22960][K8S] Make build-push-docker-images.sh more dev-friendly. - Make it possible to build images from a git clone. - Make it easy to use minikube to test things. Also fixed what seemed like a bug: the base image wasn't getting the tag provided in the command line. Adding the tag allows users to use multiple Spark builds in the same kubernetes cluster. Tested by deploying images on minikube and running spark-submit from a dev environment; also by building the images with different tags and verifying "docker images" in minikube. Author: Marcelo Vanzin Closes #20154 from vanzin/SPARK-22960. --- docs/running-on-kubernetes.md | 9 +- .../src/main/dockerfiles/driver/Dockerfile | 3 +- .../src/main/dockerfiles/executor/Dockerfile | 3 +- .../dockerfiles/init-container/Dockerfile | 3 +- .../main/dockerfiles/spark-base/Dockerfile | 7 +- sbin/build-push-docker-images.sh | 120 +++++++++++++++--- 6 files changed, 117 insertions(+), 28 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e491329136a3c..2d69f636472ae 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -16,6 +16,9 @@ Kubernetes scheduler that has been added to Spark. you may setup a test cluster on your local machine using [minikube](https://kubernetes.io/docs/getting-started-guides/minikube/). * We recommend using the latest release of minikube with the DNS addon enabled. + * Be aware that the default minikube configuration is not enough for running Spark applications. + We recommend 3 CPUs and 4g of memory to be able to start a simple Spark application with a single + executor. * You must have appropriate permissions to list, create, edit and delete [pods](https://kubernetes.io/docs/user-guide/pods/) in your cluster. You can verify that you can list these resources by running `kubectl auth can-i pods`. @@ -197,7 +200,7 @@ kubectl port-forward 4040:4040 Then, the Spark driver UI can be accessed on `http://localhost:4040`. -### Debugging +### Debugging There may be several kinds of failures. If the Kubernetes API server rejects the request made from spark-submit, or the connection is refused for a different reason, the submission logic should indicate the error encountered. However, if there @@ -215,8 +218,8 @@ If the pod has encountered a runtime error, the status can be probed further usi kubectl logs ``` -Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark -application, includling all executors, associated service, etc. The driver pod can be thought of as the Kubernetes representation of +Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark +application, including all executors, associated service, etc. The driver pod can be thought of as the Kubernetes representation of the Spark application. ## Kubernetes Features diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index 45fbcd9cd0deb..ff5289e10c21e 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 0f806cf7e148e..3eabb42d4d852 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index 047056ab2633b..e0a249e0ac71f 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # If this docker file is being used in the context of building your images from a Spark distribution, the docker build # command should be invoked from the top level directory of the Spark distribution. E.g.: diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile index 222e777db3a82..da1d6b9e161cc 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile @@ -17,6 +17,9 @@ FROM openjdk:8-alpine +ARG spark_jars +ARG img_path + # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark @@ -34,11 +37,11 @@ RUN set -ex && \ ln -sv /bin/bash /bin/sh && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd -COPY jars /opt/spark/jars +COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY conf /opt/spark/conf -COPY kubernetes/dockerfiles/spark-base/entrypoint.sh /opt/ +COPY ${img_path}/spark-base/entrypoint.sh /opt/ ENV SPARK_HOME /opt/spark diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index b3137598692d8..bb8806dd33f37 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -19,29 +19,94 @@ # This script builds and pushes docker images when run from a release of Spark # with Kubernetes support. -declare -A path=( [spark-driver]=kubernetes/dockerfiles/driver/Dockerfile \ - [spark-executor]=kubernetes/dockerfiles/executor/Dockerfile \ - [spark-init]=kubernetes/dockerfiles/init-container/Dockerfile ) +function error { + echo "$@" 1>&2 + exit 1 +} + +# Detect whether this is a git clone or a Spark distribution and adjust paths +# accordingly. +if [ -z "${SPARK_HOME}" ]; then + SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi +. "${SPARK_HOME}/bin/load-spark-env.sh" + +if [ -f "$SPARK_HOME/RELEASE" ]; then + IMG_PATH="kubernetes/dockerfiles" + SPARK_JARS="jars" +else + IMG_PATH="resource-managers/kubernetes/docker/src/main/dockerfiles" + SPARK_JARS="assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +if [ ! -d "$IMG_PATH" ]; then + error "Cannot find docker images. This script must be run from a runnable distribution of Apache Spark." +fi + +declare -A path=( [spark-driver]="$IMG_PATH/driver/Dockerfile" \ + [spark-executor]="$IMG_PATH/executor/Dockerfile" \ + [spark-init]="$IMG_PATH/init-container/Dockerfile" ) + +function image_ref { + local image="$1" + local add_repo="${2:-1}" + if [ $add_repo = 1 ] && [ -n "$REPO" ]; then + image="$REPO/$image" + fi + if [ -n "$TAG" ]; then + image="$image:$TAG" + fi + echo "$image" +} function build { - docker build -t spark-base -f kubernetes/dockerfiles/spark-base/Dockerfile . + local base_image="$(image_ref spark-base 0)" + docker build --build-arg "spark_jars=$SPARK_JARS" \ + --build-arg "img_path=$IMG_PATH" \ + -t "$base_image" \ + -f "$IMG_PATH/spark-base/Dockerfile" . for image in "${!path[@]}"; do - docker build -t ${REPO}/$image:${TAG} -f ${path[$image]} . + docker build --build-arg "base_image=$base_image" -t "$(image_ref $image)" -f ${path[$image]} . done } - function push { for image in "${!path[@]}"; do - docker push ${REPO}/$image:${TAG} + docker push "$(image_ref $image)" done } function usage { - echo "This script must be run from a runnable distribution of Apache Spark." - echo "Usage: ./sbin/build-push-docker-images.sh -r -t build" - echo " ./sbin/build-push-docker-images.sh -r -t push" - echo "for example: ./sbin/build-push-docker-images.sh -r docker.io/myrepo -t v2.3.0 push" + cat </dev/null; then + error "Cannot find minikube." + fi + eval $(minikube docker-env) + ;; esac done -if [ -z "$REPO" ] || [ -z "$TAG" ]; then +case "${@: -1}" in + build) + build + ;; + push) + if [ -z "$REPO" ]; then + usage + exit 1 + fi + push + ;; + *) usage -else - case "${@: -1}" in - build) build;; - push) push;; - *) usage;; - esac -fi + exit 1 + ;; +esac From df7fc3ef3899cadd252d2837092bebe3442d6523 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 5 Jan 2018 10:16:34 +0800 Subject: [PATCH 019/774] [SPARK-22957] ApproxQuantile breaks if the number of rows exceeds MaxInt ## What changes were proposed in this pull request? 32bit Int was used for row rank. That overflowed in a dataframe with more than 2B rows. ## How was this patch tested? Added test, but ignored, as it takes 4 minutes. Author: Juliusz Sompolski Closes #20152 from juliuszsompolski/SPARK-22957. --- .../aggregate/ApproximatePercentile.scala | 12 ++++++------ .../spark/sql/catalyst/util/QuantileSummaries.scala | 8 ++++---- .../org/apache/spark/sql/DataFrameStatSuite.scala | 8 ++++++++ 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 149ac265e6ed5..a45854a3b5146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -296,8 +296,8 @@ object ApproximatePercentile { Ints.BYTES + Doubles.BYTES + Longs.BYTES + // length of summary.sampled Ints.BYTES + - // summary.sampled, Array[Stat(value: Double, g: Int, delta: Int)] - summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES) + // summary.sampled, Array[Stat(value: Double, g: Long, delta: Long)] + summaries.sampled.length * (Doubles.BYTES + Longs.BYTES + Longs.BYTES) } final def serialize(obj: PercentileDigest): Array[Byte] = { @@ -312,8 +312,8 @@ object ApproximatePercentile { while (i < summary.sampled.length) { val stat = summary.sampled(i) buffer.putDouble(stat.value) - buffer.putInt(stat.g) - buffer.putInt(stat.delta) + buffer.putLong(stat.g) + buffer.putLong(stat.delta) i += 1 } buffer.array() @@ -330,8 +330,8 @@ object ApproximatePercentile { var i = 0 while (i < sampledLength) { val value = buffer.getDouble() - val g = buffer.getInt() - val delta = buffer.getInt() + val g = buffer.getLong() + val delta = buffer.getLong() sampled(i) = Stats(value, g, delta) i += 1 } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index eb7941cf9e6af..b013add9c9778 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -105,7 +105,7 @@ class QuantileSummaries( if (newSamples.isEmpty || (sampleIdx == sampled.length && opsIdx == sorted.length - 1)) { 0 } else { - math.floor(2 * relativeError * currentCount).toInt + math.floor(2 * relativeError * currentCount).toLong } val tuple = Stats(currentSample, 1, delta) @@ -192,10 +192,10 @@ class QuantileSummaries( } // Target rank - val rank = math.ceil(quantile * count).toInt + val rank = math.ceil(quantile * count).toLong val targetError = relativeError * count // Minimum rank at current sample - var minRank = 0 + var minRank = 0L var i = 0 while (i < sampled.length - 1) { val curSample = sampled(i) @@ -235,7 +235,7 @@ object QuantileSummaries { * @param g the minimum rank jump from the previous value's minimum rank * @param delta the maximum span of the rank. */ - case class Stats(value: Double, g: Int, delta: Int) + case class Stats(value: Double, g: Long, delta: Long) private def compressImmut( currentSamples: IndexedSeq[Stats], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 46b21c3b64a2e..5169d2b5fc6b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -260,6 +260,14 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(res2(1).isEmpty) } + // SPARK-22957: check for 32bit overflow when computing rank. + // ignored - takes 4 minutes to run. + ignore("approx quantile 4: test for Int overflow") { + val res = spark.range(3000000000L).stat.approxQuantile("id", Array(0.8, 0.9), 0.05) + assert(res(0) > 2200000000.0) + assert(res(1) > 2200000000.0) + } + test("crosstab") { withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { val rng = new Random() From 52fc5c17d9d784b846149771b398e741621c0b5c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 5 Jan 2018 14:02:21 +0800 Subject: [PATCH 020/774] [SPARK-22825][SQL] Fix incorrect results of Casting Array to String ## What changes were proposed in this pull request? This pr fixed the issue when casting arrays into strings; ``` scala> val df = spark.range(10).select('id.cast("integer")).agg(collect_list('id).as('ids)) scala> df.write.saveAsTable("t") scala> sql("SELECT cast(ids as String) FROM t").show(false) +------------------------------------------------------------------+ |ids | +------------------------------------------------------------------+ |org.apache.spark.sql.catalyst.expressions.UnsafeArrayData8bc285df| +------------------------------------------------------------------+ ``` This pr modified the result into; ``` +------------------------------+ |ids | +------------------------------+ |[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]| +------------------------------+ ``` ## How was this patch tested? Added tests in `CastSuite` and `SQLQuerySuite`. Author: Takeshi Yamamuro Closes #20024 from maropu/SPARK-22825. --- .../codegen/UTF8StringBuilder.java | 78 +++++++++++++++++++ .../spark/sql/catalyst/expressions/Cast.scala | 68 ++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 25 ++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 2 - 4 files changed, 171 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java new file mode 100644 index 0000000000000..f0f66bae245fd --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated + * {@link UTF8String} at the end. + */ +public class UTF8StringBuilder { + + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; + + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; + + public UTF8StringBuilder() { + // Since initial buffer size is 16 in `StringBuilder`, we set the same size here + this.buffer = new byte[16]; + } + + // Grows the buffer by at least `neededSize` + private void grow(int neededSize) { + if (neededSize > ARRAY_MAX - totalSize()) { + throw new UnsupportedOperationException( + "Cannot grow internal buffer by size " + neededSize + " because the size after growing " + + "exceeds size limitation " + ARRAY_MAX); + } + final int length = totalSize() + neededSize; + if (buffer.length < length) { + int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX; + final byte[] tmp = new byte[newLength]; + Platform.copyMemory( + buffer, + Platform.BYTE_ARRAY_OFFSET, + tmp, + Platform.BYTE_ARRAY_OFFSET, + totalSize()); + buffer = tmp; + } + } + + private int totalSize() { + return cursor - Platform.BYTE_ARRAY_OFFSET; + } + + public void append(UTF8String value) { + grow(value.numBytes()); + value.writeToMemory(buffer, cursor); + cursor += value.numBytes(); + } + + public void append(String value) { + append(UTF8String.fromString(value)); + } + + public UTF8String build() { + return UTF8String.fromBytes(buffer, 0, totalSize()); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 274d8813f16db..d4fc5e0f168a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -206,6 +206,28 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d))) case TimestampType => buildCast[Long](_, t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone))) + case ArrayType(et, _) => + buildCast[ArrayData](_, array => { + val builder = new UTF8StringBuilder + builder.append("[") + if (array.numElements > 0) { + val toUTF8String = castToString(et) + if (!array.isNullAt(0)) { + builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < array.numElements) { + builder.append(",") + if (!array.isNullAt(i)) { + builder.append(" ") + builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -597,6 +619,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ } + private def writeArrayToStringBuilder( + et: DataType, + array: String, + buffer: String, + ctx: CodegenContext): String = { + val elementToStringCode = castToStringCode(et, ctx) + val funcName = ctx.freshName("elementToString") + val elementToStringFunc = ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${ctx.javaType(et)} element) { + | UTF8String elementStr = null; + | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} + | return elementStr; + |} + """.stripMargin) + + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($array.numElements() > 0) { + | if (!$array.isNullAt(0)) { + | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { + | $buffer.append(","); + | if (!$array.isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -608,6 +665,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val tz = ctx.addReferenceObj("timeZone", timeZone) (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" + case ArrayType(et, _) => + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeArrayElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1dd040e4696a1..e3ed7171defd8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -853,4 +853,29 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { cast("2", LongType).genCode(ctx) assert(ctx.inlinedMutableStates.length == 0) } + + test("SPARK-22825 Cast array to string") { + val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType) + checkEvaluation(ret1, "[1, 2, 3, 4, 5]") + val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType) + checkEvaluation(ret2, "[ab, cde, f]") + val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType) + checkEvaluation(ret3, "[ab,, c]") + val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType) + checkEvaluation(ret4, "[ab, cde, f]") + val ret5 = cast( + Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)), + StringType) + checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]") + val ret6 = cast( + Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)), + StringType) + checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]") + val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType) + checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]") + val ret8 = cast( + Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))), + StringType) + checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5e077285ade55..96bf65fce9c4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -28,8 +28,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} -import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf From cf0aa65576acbe0209c67f04c029058fd73555c1 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 4 Jan 2018 22:45:15 -0800 Subject: [PATCH 021/774] [SPARK-22949][ML] Apply CrossValidator approach to Driver/Distributed memory tradeoff for TrainValidationSplit ## What changes were proposed in this pull request? Avoid holding all models in memory for `TrainValidationSplit`. ## How was this patch tested? Existing tests. Author: Bago Amirbekian Closes #20143 from MrBago/trainValidMemoryFix. --- .../spark/ml/tuning/CrossValidator.scala | 4 +++- .../spark/ml/tuning/TrainValidationSplit.scala | 18 ++++-------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 095b54c0fe83f..a0b507d2e718c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -160,8 +160,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) } (executionContext) } - // Wait for metrics to be calculated before unpersisting validation dataset + // Wait for metrics to be calculated val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) + + // Unpersist training & validation set once all metrics have been produced trainingDataset.unpersist() validationDataset.unpersist() foldMetrics diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index c73bd18475475..8826ef3271bc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -143,24 +143,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Fit models in a Future for training in parallel logDebug(s"Train split with multiple sets of parameters.") - val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => - Future[Model[_]] { + val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => + Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] if (collectSubModelsParam) { subModels.get(paramIndex) = model } - model - } (executionContext) - } - - // Unpersist training data only when all models have trained - Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext) - .onComplete { _ => trainingDataset.unpersist() } (executionContext) - - // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up - val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) => - modelFuture.map { model => // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) logDebug(s"Got metric $metric for model trained with $paramMap.") @@ -171,7 +160,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Wait for all metrics to be calculated val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) - // Unpersist validation set once all metrics have been produced + // Unpersist training & validation set once all metrics have been produced + trainingDataset.unpersist() validationDataset.unpersist() logInfo(s"Train validation split metrics: ${metrics.toSeq}") From 6cff7d19f6a905fe425bd6892fe7ca014c0e696b Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 4 Jan 2018 23:23:41 -0800 Subject: [PATCH 022/774] [SPARK-22757][K8S] Enable spark.jars and spark.files in KUBERNETES mode ## What changes were proposed in this pull request? We missed enabling `spark.files` and `spark.jars` in https://github.com/apache/spark/pull/19954. The result is that remote dependencies specified through `spark.files` or `spark.jars` are not included in the list of remote dependencies to be downloaded by the init-container. This PR fixes it. ## How was this patch tested? Manual tests. vanzin This replaces https://github.com/apache/spark/pull/20157. foxish Author: Yinan Li Closes #20160 from liyinan926/SPARK-22757. --- .../src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index cbe1f2c3e08a1..1e381965c52ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -584,10 +584,11 @@ object SparkSubmit extends CommandLineUtils with Logging { confKey = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.cores.max"), - OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, + OptionAssigner(args.files, LOCAL | STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.files"), OptionAssigner(args.jars, LOCAL, CLIENT, confKey = "spark.jars"), - OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, confKey = "spark.jars"), + OptionAssigner(args.jars, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, + confKey = "spark.jars"), OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, confKey = "spark.driver.memory"), OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, From 51c33bd0d402af9e0284c6cbc0111f926446bfba Mon Sep 17 00:00:00 2001 From: Adrian Ionescu Date: Fri, 5 Jan 2018 21:32:39 +0800 Subject: [PATCH 023/774] [SPARK-22961][REGRESSION] Constant columns should generate QueryPlanConstraints ## What changes were proposed in this pull request? #19201 introduced the following regression: given something like `df.withColumn("c", lit(2))`, we're no longer picking up `c === 2` as a constraint and infer filters from it when joins are involved, which may lead to noticeable performance degradation. This patch re-enables this optimization by picking up Aliases of Literals in Projection lists as constraints and making sure they're not treated as aliased columns. ## How was this patch tested? Unit test was added. Author: Adrian Ionescu Closes #20155 from adrian-ionescu/constant_constraints. --- .../sql/catalyst/plans/logical/LogicalPlan.scala | 2 ++ .../plans/logical/QueryPlanConstraints.scala | 2 +- .../InferFiltersFromConstraintsSuite.scala | 13 +++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index a38458add7b5e..ff2a0ec588567 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -247,6 +247,8 @@ abstract class UnaryNode extends LogicalPlan { protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { var allConstraints = child.constraints.asInstanceOf[Set[Expression]] projectList.foreach { + case a @ Alias(l: Literal, _) => + allConstraints += EqualTo(a.toAttribute, l) case a @ Alias(e, _) => // For every alias in `projectList`, replace the reference in constraints by its attribute. allConstraints ++= allConstraints.map(_ transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index b0f611fd38dea..9c0a30a47f839 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -98,7 +98,7 @@ trait QueryPlanConstraints { self: LogicalPlan => // we may avoid producing recursive constraints. private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( expressions.collect { - case a: Alias => (a.toAttribute, a.child) + case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child) } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)) // Note: the explicit cast is necessary, since Scala compiler fails to infer the type. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 5580f8604ec72..a0708bf7eee9a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -236,4 +236,17 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, originalQuery) } } + + test("constraints should be inferred from aliased literals") { + val originalLeft = testRelation.subquery('left).as("left") + val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a === 2).as("left") + + val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right") + val condition = Some("left.a".attr === "right.two".attr) + + val original = originalLeft.join(right, Inner, condition) + val correct = optimizedLeft.join(right, Inner, condition) + + comparePlans(Optimize.execute(original.analyze), correct.analyze) + } } From c0b7424ecacb56d3e7a18acc11ba3d5e7be57c43 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 5 Jan 2018 09:58:28 -0800 Subject: [PATCH 024/774] [SPARK-22940][SQL] HiveExternalCatalogVersionsSuite should succeed on platforms that don't have wget ## What changes were proposed in this pull request? Modified HiveExternalCatalogVersionsSuite.scala to use Utils.doFetchFile to download different versions of Spark binaries rather than launching wget as an external process. On platforms that don't have wget installed, this suite fails with an error. cloud-fan : would you like to check this change? ## How was this patch tested? 1) test-only of HiveExternalCatalogVersionsSuite on several platforms. Tested bad mirror, read timeout, and redirects. 2) ./dev/run-tests Author: Bruce Robbins Closes #20147 from bersprockets/SPARK-22940-alt. --- .../HiveExternalCatalogVersionsSuite.scala | 48 ++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index a3d5b941a6761..ae4aeb7b4ce4a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -18,11 +18,14 @@ package org.apache.spark.sql.hive import java.io.File -import java.nio.file.Files +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} import scala.sys.process._ -import org.apache.spark.TestUtils +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SecurityManager, SparkConf, TestUtils} import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType @@ -55,14 +58,19 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private def tryDownloadSpark(version: String, path: String): Unit = { // Try mirrors a few times until one succeeds for (i <- 0 until 3) { + // we don't retry on a failure to get mirror url. If we can't get a mirror url, + // the test fails (getStringFromUrl will throw an exception) val preferredMirror = - Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim - val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" + getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true") + val filename = s"spark-$version-bin-hadoop2.7.tgz" + val url = s"$preferredMirror/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") - if (Seq("wget", url, "-q", "-P", path).! == 0) { + try { + getFileFromUrl(url, path, filename) return + } catch { + case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) } - logWarning(s"Failed to download Spark $version from $url") } fail(s"Unable to download Spark $version") } @@ -85,6 +93,34 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { new File(tmpDataDir, name).getCanonicalPath } + private def getFileFromUrl(urlString: String, targetDir: String, filename: String): Unit = { + val conf = new SparkConf + // if the caller passes the name of an existing file, we want doFetchFile to write over it with + // the contents from the specified url. + conf.set("spark.files.overwrite", "true") + val securityManager = new SecurityManager(conf) + val hadoopConf = new Configuration + + val outDir = new File(targetDir) + if (!outDir.exists()) { + outDir.mkdirs() + } + + // propagate exceptions up to the caller of getFileFromUrl + Utils.doFetchFile(urlString, outDir, filename, conf, securityManager, hadoopConf) + } + + private def getStringFromUrl(urlString: String): String = { + val contentFile = File.createTempFile("string-", ".txt") + contentFile.deleteOnExit() + + // exceptions will propagate to the caller of getStringFromUrl + getFileFromUrl(urlString, contentFile.getParent, contentFile.getName) + + val contentPath = Paths.get(contentFile.toURI) + new String(Files.readAllBytes(contentPath), StandardCharsets.UTF_8) + } + override def beforeAll(): Unit = { super.beforeAll() From 930b90a84871e2504b57ed50efa7b8bb52d3ba44 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 5 Jan 2018 11:51:25 -0800 Subject: [PATCH 025/774] [SPARK-13030][ML] Follow-up cleanups for OneHotEncoderEstimator ## What changes were proposed in this pull request? Follow-up cleanups for the OneHotEncoderEstimator PR. See some discussion in the original PR: https://github.com/apache/spark/pull/19527 or read below for what this PR includes: * configedCategorySize: I reverted this to return an Array. I realized the original setup (which I had recommended in the original PR) caused the whole model to be serialized in the UDF. * encoder: I reorganized the logic to show what I meant in the comment in the previous PR. I think it's simpler but am open to suggestions. I also made some small style cleanups based on IntelliJ warnings. ## How was this patch tested? Existing unit tests Author: Joseph K. Bradley Closes #20132 from jkbradley/viirya-SPARK-13030. --- .../ml/feature/OneHotEncoderEstimator.scala | 92 ++++++++++--------- 1 file changed, 49 insertions(+), 43 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index 074622d41e28d..bd1e3426c8780 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -30,24 +30,27 @@ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions.{col, lit, udf} -import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType} +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid with HasInputCols with HasOutputCols { /** - * Param for how to handle invalid data. + * Param for how to handle invalid data during transform(). * Options are 'keep' (invalid data presented as an extra categorical feature) or * 'error' (throw an error). + * Note that this Param is only used during transform; during fitting, invalid data + * will result in an error. * Default: "error" * @group param */ @Since("2.3.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - "How to handle invalid data " + + "How to handle invalid data during transform(). " + "Options are 'keep' (invalid data presented as an extra categorical feature) " + - "or error (throw an error).", + "or error (throw an error). Note that this Param is only used during transform; " + + "during fitting, invalid data will result in an error.", ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids)) setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID) @@ -66,10 +69,11 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid def getDropLast: Boolean = $(dropLast) protected def validateAndTransformSchema( - schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = { + schema: StructType, + dropLast: Boolean, + keepInvalid: Boolean): StructType = { val inputColNames = $(inputCols) val outputColNames = $(outputCols) - val existingFields = schema.fields require(inputColNames.length == outputColNames.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -197,6 +201,10 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat override def load(path: String): OneHotEncoderEstimator = super.load(path) } +/** + * @param categorySizes Original number of categories for each feature being encoded. + * The array contains one value for each input column, in order. + */ @Since("2.3.0") class OneHotEncoderModel private[ml] ( @Since("2.3.0") override val uid: String, @@ -205,60 +213,58 @@ class OneHotEncoderModel private[ml] ( import OneHotEncoderModel._ - // Returns the category size for a given index with `dropLast` and `handleInvalid` + // Returns the category size for each index with `dropLast` and `handleInvalid` // taken into account. - private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = { + private def getConfigedCategorySizes: Array[Int] = { val dropLast = getDropLast val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID if (!dropLast && keepInvalid) { // When `handleInvalid` is "keep", an extra category is added as last category // for invalid data. - orgCategorySize + 1 + categorySizes.map(_ + 1) } else if (dropLast && !keepInvalid) { // When `dropLast` is true, the last category is removed. - orgCategorySize - 1 + categorySizes.map(_ - 1) } else { // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid // data is removed. Thus, it is the same as the plain number of categories. - orgCategorySize + categorySizes } } private def encoder: UserDefinedFunction = { - val oneValue = Array(1.0) - val emptyValues = Array.empty[Double] - val emptyIndices = Array.empty[Int] - val dropLast = getDropLast - val handleInvalid = getHandleInvalid - val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID + val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID + val configedSizes = getConfigedCategorySizes + val localCategorySizes = categorySizes // The udf performed on input data. The first parameter is the input value. The second - // parameter is the index of input. - udf { (label: Double, idx: Int) => - val plainNumCategories = categorySizes(idx) - val size = configedCategorySize(plainNumCategories, idx) - - if (label < 0) { - throw new SparkException(s"Negative value: $label. Input can't be negative.") - } else if (label == size && dropLast && !keepInvalid) { - // When `dropLast` is true and `handleInvalid` is not "keep", - // the last category is removed. - Vectors.sparse(size, emptyIndices, emptyValues) - } else if (label >= plainNumCategories && keepInvalid) { - // When `handleInvalid` is "keep", encodes invalid data to last category (and removed - // if `dropLast` is true) - if (dropLast) { - Vectors.sparse(size, emptyIndices, emptyValues) + // parameter is the index in inputCols of the column being encoded. + udf { (label: Double, colIdx: Int) => + val origCategorySize = localCategorySizes(colIdx) + // idx: index in vector of the single 1-valued element + val idx = if (label >= 0 && label < origCategorySize) { + label + } else { + if (keepInvalid) { + origCategorySize } else { - Vectors.sparse(size, Array(size - 1), oneValue) + if (label < 0) { + throw new SparkException(s"Negative value: $label. Input can't be negative. " + + s"To handle invalid values, set Param handleInvalid to " + + s"${OneHotEncoderEstimator.KEEP_INVALID}") + } else { + throw new SparkException(s"Unseen value: $label. To handle unseen values, " + + s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + } } - } else if (label < plainNumCategories) { - Vectors.sparse(size, Array(label.toInt), oneValue) + } + + val size = configedSizes(colIdx) + if (idx < size) { + Vectors.sparse(size, Array(idx.toInt), Array(1.0)) } else { - assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) - throw new SparkException(s"Unseen value: $label. To handle unseen values, " + - s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + Vectors.sparse(size, Array.empty[Int], Array.empty[Double]) } } } @@ -282,7 +288,6 @@ class OneHotEncoderModel private[ml] ( @Since("2.3.0") override def transformSchema(schema: StructType): StructType = { val inputColNames = $(inputCols) - val outputColNames = $(outputCols) require(inputColNames.length == categorySizes.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -300,6 +305,7 @@ class OneHotEncoderModel private[ml] ( * account. Mismatched numbers will cause exception. */ private def verifyNumOfValues(schema: StructType): StructType = { + val configedSizes = getConfigedCategorySizes $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => val inputColName = $(inputCols)(idx) val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) @@ -308,9 +314,9 @@ class OneHotEncoderModel private[ml] ( // comparing with expected category number with `handleInvalid` and // `dropLast` taken into account. if (attrGroup.attributes.nonEmpty) { - val numCategories = configedCategorySize(categorySizes(idx), idx) + val numCategories = configedSizes(idx) require(attrGroup.size == numCategories, "OneHotEncoderModel expected " + - s"$numCategories categorical values for input column ${inputColName}, " + + s"$numCategories categorical values for input column $inputColName, " + s"but the input column had metadata specifying ${attrGroup.size} values.") } } @@ -322,7 +328,7 @@ class OneHotEncoderModel private[ml] ( val transformedSchema = transformSchema(dataset.schema, logging = true) val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID - val encodedColumns = (0 until $(inputCols).length).map { idx => + val encodedColumns = $(inputCols).indices.map { idx => val inputColName = $(inputCols)(idx) val outputColName = $(outputCols)(idx) From ea956833017fcbd8ed2288368bfa2e417a2251c5 Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Fri, 5 Jan 2018 17:25:28 -0800 Subject: [PATCH 026/774] [SPARK-22914][DEPLOY] Register history.ui.port ## What changes were proposed in this pull request? Register spark.history.ui.port as a known spark conf to be used in substitution expressions even if it's not set explicitly. ## How was this patch tested? Added unit test to demonstrate the issue Author: Gera Shegalov Author: Gera Shegalov Closes #20098 from gerashegalov/gera/register-SHS-port-conf. --- .../spark/deploy/history/HistoryServer.scala | 3 +- .../apache/spark/deploy/history/config.scala | 5 +++ .../spark/deploy/yarn/ApplicationMaster.scala | 17 +++++--- .../deploy/yarn/ApplicationMasterSuite.scala | 43 +++++++++++++++++++ 4 files changed, 62 insertions(+), 6 deletions(-) create mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 75484f5c9f30f..0ec4afad0308c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -28,6 +28,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.history.config.HISTORY_SERVER_UI_PORT import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, UIRoot} @@ -276,7 +277,7 @@ object HistoryServer extends Logging { .newInstance(conf) .asInstanceOf[ApplicationHistoryProvider] - val port = conf.getInt("spark.history.ui.port", 18080) + val port = conf.get(HISTORY_SERVER_UI_PORT) val server = new HistoryServer(conf, provider, securityManager, port) server.bind() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala index 22b6d49d8e2a4..efdbf672bb52f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -44,4 +44,9 @@ private[spark] object config { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("10g") + val HISTORY_SERVER_UI_PORT = ConfigBuilder("spark.history.ui.port") + .doc("Web UI port to bind Spark History Server") + .intConf + .createWithDefault(18080) + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index b2576b0d72633..4d5e3bb043671 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -427,11 +427,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends uiAddress: Option[String]) = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() - val historyAddress = - _sparkConf.get(HISTORY_SERVER_ADDRESS) - .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } - .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } - .getOrElse("") + val historyAddress = ApplicationMaster + .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId) val driverUrl = RpcEndpointAddress( _sparkConf.get("spark.driver.host"), @@ -834,6 +831,16 @@ object ApplicationMaster extends Logging { master.getAttemptId } + private[spark] def getHistoryServerAddress( + sparkConf: SparkConf, + yarnConf: YarnConfiguration, + appId: String, + attemptId: String): String = { + sparkConf.get(HISTORY_SERVER_ADDRESS) + .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } + .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } + .getOrElse("") + } } /** diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala new file mode 100644 index 0000000000000..695a82f3583e6 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.apache.hadoop.yarn.conf.YarnConfiguration + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class ApplicationMasterSuite extends SparkFunSuite { + + test("history url with hadoop and spark substitutions") { + val host = "rm.host.com" + val port = 18080 + val sparkConf = new SparkConf() + + sparkConf.set("spark.yarn.historyServer.address", + "http://${hadoopconf-yarn.resourcemanager.hostname}:${spark.history.ui.port}") + val yarnConf = new YarnConfiguration() + yarnConf.set("yarn.resourcemanager.hostname", host) + val appId = "application_123_1" + val attemptId = appId + "_1" + + val shsAddr = ApplicationMaster + .getHistoryServerAddress(sparkConf, yarnConf, appId, attemptId) + + assert(shsAddr === s"http://${host}:${port}/history/${appId}/${attemptId}") + } +} From e8af7e8aeca15a6107248f358d9514521ffdc6d3 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 6 Jan 2018 09:26:03 +0800 Subject: [PATCH 027/774] [SPARK-22937][SQL] SQL elt output binary for binary inputs ## What changes were proposed in this pull request? This pr modified `elt` to output binary for binary inputs. `elt` in the current master always output data as a string. But, in some databases (e.g., MySQL), if all inputs are binary, `elt` also outputs binary (Also, this might be a small surprise). This pr is related to #19977. ## How was this patch tested? Added tests in `SQLQueryTestSuite` and `TypeCoercionSuite`. Author: Takeshi Yamamuro Closes #20135 from maropu/SPARK-22937. --- docs/sql-programming-guide.md | 2 + .../sql/catalyst/analysis/TypeCoercion.scala | 29 +++++ .../expressions/stringExpressions.scala | 46 ++++--- .../apache/spark/sql/internal/SQLConf.scala | 8 ++ .../catalyst/analysis/TypeCoercionSuite.scala | 54 ++++++++ .../inputs/typeCoercion/native/elt.sql | 44 +++++++ .../results/typeCoercion/native/elt.sql.out | 115 ++++++++++++++++++ 7 files changed, 281 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index dc3e384008d27..b50f9360b866c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1783,6 +1783,8 @@ options. - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. + - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e9436367c7e2e..e8669c4637d06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -54,6 +54,7 @@ object TypeCoercion { BooleanEquality :: FunctionArgumentConversion :: ConcatCoercion(conf) :: + EltCoercion(conf) :: CaseWhenCoercion :: IfCoercion :: StackCoercion :: @@ -684,6 +685,34 @@ object TypeCoercion { } } + /** + * Coerces the types of [[Elt]] children to expected ones. + * + * If `spark.sql.function.eltOutputAsString` is false and all children types are binary, + * the expected types are binary. Otherwise, the expected ones are strings. + */ + case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { + + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or not enough children + case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c + case c @ Elt(children) => + val index = children.head + val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) + val newInputs = if (conf.eltOutputAsString || + !children.tail.map(_.dataType).forall(_ == BinaryType)) { + children.tail.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + } else { + children.tail + } + c.copy(children = newIndex +: newInputs) + } + } + } + /** * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType * to TimeAdd/TimeSub diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 41dc762154a4c..e004bfc6af473 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -271,33 +271,45 @@ case class ConcatWs(children: Seq[Expression]) } } +/** + * An expression that returns the `n`-th input in given inputs. + * If all inputs are binary, `elt` returns an output as binary. Otherwise, it returns as string. + * If any input is null, `elt` returns null. + */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(n, str1, str2, ...) - Returns the `n`-th string, e.g., returns `str2` when `n` is 2.", + usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.", examples = """ Examples: > SELECT _FUNC_(1, 'scala', 'java'); scala """) // scalastyle:on line.size.limit -case class Elt(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes { +case class Elt(children: Seq[Expression]) extends Expression { private lazy val indexExpr = children.head - private lazy val stringExprs = children.tail.toArray + private lazy val inputExprs = children.tail.toArray /** This expression is always nullable because it returns null if index is out of range. */ override def nullable: Boolean = true - override def dataType: DataType = StringType - - override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType) + override def dataType: DataType = inputExprs.map(_.dataType).headOption.getOrElse(StringType) override def checkInputDataTypes(): TypeCheckResult = { if (children.size < 2) { TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments") } else { - super[ImplicitCastInputTypes].checkInputDataTypes() + val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) + if (indexType != IntegerType) { + return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + + s"have IntegerType, but it's $indexType") + } + if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have StringType or BinaryType, but it's " + + inputTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") } } @@ -307,27 +319,27 @@ case class Elt(children: Seq[Expression]) null } else { val index = indexObj.asInstanceOf[Int] - if (index <= 0 || index > stringExprs.length) { + if (index <= 0 || index > inputExprs.length) { null } else { - stringExprs(index - 1).eval(input) + inputExprs(index - 1).eval(input) } } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val index = indexExpr.genCode(ctx) - val strings = stringExprs.map(_.genCode(ctx)) + val inputs = inputExprs.map(_.genCode(ctx)) val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val stringVal = ctx.addMutableState(ctx.javaType(dataType), "stringVal") + val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal") - val assignStringValue = strings.zipWithIndex.map { case (eval, index) => + val assignInputValue = inputs.zipWithIndex.map { case (eval, index) => s""" |if ($indexVal == ${index + 1}) { | ${eval.code} - | $stringVal = ${eval.isNull} ? null : ${eval.value}; + | $inputVal = ${eval.isNull} ? null : ${eval.value}; | $indexMatched = true; | continue; |} @@ -335,7 +347,7 @@ case class Elt(children: Seq[Expression]) } val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = assignStringValue, + expressions = assignInputValue, funcName = "eltFunc", extraArguments = ("int", indexVal) :: Nil, returnType = ctx.JAVA_BOOLEAN, @@ -361,11 +373,11 @@ case class Elt(children: Seq[Expression]) |${index.code} |final int $indexVal = ${index.value}; |${ctx.JAVA_BOOLEAN} $indexMatched = false; - |$stringVal = null; + |$inputVal = null; |do { | $codes |} while (false); - |final UTF8String ${ev.value} = $stringVal; + |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5d6edf6b8abec..80b8965e084a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1052,6 +1052,12 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ELT_OUTPUT_AS_STRING = buildConf("spark.sql.function.eltOutputAsString") + .doc("When this option is set to false and all inputs are binary, `elt` returns " + + "an output as binary. Otherwise, it returns as a string. ") + .booleanConf + .createWithDefault(false) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -1412,6 +1418,8 @@ class SQLConf extends Serializable with Logging { def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 3661530cd622b..52a7ebdafd7c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -923,6 +923,60 @@ class TypeCoercionSuite extends AnalysisTest { } } + test("type coercion for Elt") { + val rule = TypeCoercion.EltCoercion(conf) + + ruleTest(rule, + Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))), + Elt(Seq(Literal(1), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))), + Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(null), Literal("abc"))), + Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(1), Literal("234"))), + Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234")))) + ruleTest(rule, + Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))), + Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), + Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))), + Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), + Cast(Literal(3.toShort), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(1L), Literal(0.1))), + Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(Decimal(10)))), + Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType), + Cast(Literal(new Timestamp(0)), StringType)))) + + withSQLConf("spark.sql.function.eltOutputAsString" -> "true") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType), + Cast(Literal("456".getBytes), StringType)))) + } + + withSQLConf("spark.sql.function.eltOutputAsString" -> "false") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) + } + } + test("BooleanEquality type cast") { val be = TypeCoercion.BooleanEquality // Use something more than a literal to avoid triggering the simplification rules. diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql new file mode 100644 index 0000000000000..717616f91db05 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql @@ -0,0 +1,44 @@ +-- Mixed inputs (output type is string) +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +); + +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +-- turn on eltOutputAsString +set spark.sql.function.eltOutputAsString=true; + +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); + +-- turn off eltOutputAsString +set spark.sql.function.eltOutputAsString=false; + +-- Elt binary inputs (output type is binary) +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out new file mode 100644 index 0000000000000..b62e1b6826045 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out @@ -0,0 +1,115 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +) +-- !query 0 schema +struct +-- !query 0 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 1 +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 1 schema +struct +-- !query 1 output +10 +11 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 2 +set spark.sql.function.eltOutputAsString=true +-- !query 2 schema +struct +-- !query 2 output +spark.sql.function.eltOutputAsString true + + +-- !query 3 +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 3 schema +struct +-- !query 3 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 4 +set spark.sql.function.eltOutputAsString=false +-- !query 4 schema +struct +-- !query 4 output +spark.sql.function.eltOutputAsString false + + +-- !query 5 +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 5 schema +struct +-- !query 5 output +1 +10 +2 +3 +4 +5 +6 +7 +8 +9 From bf65cd3cda46d5480bfcd13110975c46ca631972 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Fri, 5 Jan 2018 17:29:27 -0800 Subject: [PATCH 028/774] [SPARK-22960][K8S] Revert use of ARG base_image in images ## What changes were proposed in this pull request? This PR reverts the `ARG base_image` before `FROM` in the images of driver, executor, and init-container, introduced in https://github.com/apache/spark/pull/20154. The reason is Docker versions before 17.06 do not support this use (`ARG` before `FROM`). ## How was this patch tested? Tested manually. vanzin foxish kimoonkim Author: Yinan Li Closes #20170 from liyinan926/master. --- .../docker/src/main/dockerfiles/driver/Dockerfile | 3 +-- .../docker/src/main/dockerfiles/executor/Dockerfile | 3 +-- .../docker/src/main/dockerfiles/init-container/Dockerfile | 3 +-- sbin/build-push-docker-images.sh | 8 ++++---- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index ff5289e10c21e..45fbcd9cd0deb 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 3eabb42d4d852..0f806cf7e148e 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index e0a249e0ac71f..047056ab2633b 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # If this docker file is being used in the context of building your images from a Spark distribution, the docker build # command should be invoked from the top level directory of the Spark distribution. E.g.: diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index bb8806dd33f37..b9532597419a5 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -60,13 +60,13 @@ function image_ref { } function build { - local base_image="$(image_ref spark-base 0)" - docker build --build-arg "spark_jars=$SPARK_JARS" \ + docker build \ + --build-arg "spark_jars=$SPARK_JARS" \ --build-arg "img_path=$IMG_PATH" \ - -t "$base_image" \ + -t spark-base \ -f "$IMG_PATH/spark-base/Dockerfile" . for image in "${!path[@]}"; do - docker build --build-arg "base_image=$base_image" -t "$(image_ref $image)" -f ${path[$image]} . + docker build -t "$(image_ref $image)" -f ${path[$image]} . done } From f2dd8b923759e8771b0e5f59bfa7ae4ad7e6a339 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sat, 6 Jan 2018 16:11:20 +0800 Subject: [PATCH 029/774] [SPARK-22930][PYTHON][SQL] Improve the description of Vectorized UDFs for non-deterministic cases ## What changes were proposed in this pull request? Add tests for using non deterministic UDFs in aggregate. Update pandas_udf docstring w.r.t to determinism. ## How was this patch tested? test_nondeterministic_udf_in_aggregate Author: Li Jin Closes #20142 from icexelloss/SPARK-22930-pandas-udf-deterministic. --- python/pyspark/sql/functions.py | 12 +++++++- python/pyspark/sql/tests.py | 52 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a4ed562ad48b4..733e32bd825b0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2214,7 +2214,17 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` - .. note:: The user-defined function must be deterministic. + .. note:: The user-defined functions are considered deterministic by default. Due to + optimization, duplicate invocations may be eliminated or the function may even be invoked + more times than it is present in the query. If your function is not deterministic, call + `asNondeterministic` on the user defined function. E.g.: + + >>> @pandas_udf('double', PandasUDFType.SCALAR) # doctest: +SKIP + ... def random(v): + ... import numpy as np + ... import pandas as pd + ... return pd.Series(np.random.randn(len(v)) + >>> random = random.asNondeterministic() # doctest: +SKIP .. note:: The user-defined functions do not support conditional expressions or short curcuiting in boolean expressions and it ends up with being executed all internally. If the functions diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6dc767f9ec46e..689736d8e6456 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -386,6 +386,7 @@ def test_udf3(self): self.assertEqual(row[0], 5) def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations from pyspark.sql.functions import udf import random udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() @@ -413,6 +414,18 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(random_udf) pydoc.render_doc(random_udf1) + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import udf, sum + import random + udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic() + df = self.spark.range(10) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.groupby('id').agg(sum(udf_random_col())).collect() + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.agg(sum(udf_random_col())).collect() + def test_chained_udf(self): self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.spark.sql("SELECT double(1)").collect() @@ -3567,6 +3580,18 @@ def tearDownClass(cls): time.tzset() ReusedSQLTestCase.tearDownClass() + @property + def random_udf(self): + from pyspark.sql.functions import pandas_udf + + @pandas_udf('double') + def random_udf(v): + import pandas as pd + import numpy as np + return pd.Series(np.random.random(len(v))) + random_udf = random_udf.asNondeterministic() + return random_udf + def test_vectorized_udf_basic(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( @@ -3950,6 +3975,33 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations + from pyspark.sql.functions import udf, pandas_udf, col + + @pandas_udf('double') + def plus_ten(v): + return v + 10 + random_udf = self.random_udf + + df = self.spark.range(10).withColumn('rand', random_udf(col('id'))) + result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas() + + self.assertEqual(random_udf.deterministic, False) + self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) + + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import pandas_udf, sum + + df = self.spark.range(10) + random_udf = self.random_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.groupby(df.id).agg(sum(random_udf(df.id))).collect() + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.agg(sum(random_udf(df.id))).collect() + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedSQLTestCase): From be9a804f2ef77a5044d3da7d9374976daf59fc16 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Sat, 6 Jan 2018 18:07:45 +0800 Subject: [PATCH 030/774] [SPARK-22793][SQL] Memory leak in Spark Thrift Server # What changes were proposed in this pull request? 1. Start HiveThriftServer2. 2. Connect to thriftserver through beeline. 3. Close the beeline. 4. repeat step2 and step 3 for many times. we found there are many directories never be dropped under the path `hive.exec.local.scratchdir` and `hive.exec.scratchdir`, as we know the scratchdir has been added to deleteOnExit when it be created. So it means that the cache size of FileSystem `deleteOnExit` will keep increasing until JVM terminated. In addition, we use `jmap -histo:live [PID]` to printout the size of objects in HiveThriftServer2 Process, we can find the object `org.apache.spark.sql.hive.client.HiveClientImpl` and `org.apache.hadoop.hive.ql.session.SessionState` keep increasing even though we closed all the beeline connections, which may caused the leak of Memory. # How was this patch tested? manual tests This PR follw-up the https://github.com/apache/spark/pull/19989 Author: zuotingbing Closes #20029 from zuotingbing/SPARK-22793. --- .../org/apache/spark/sql/hive/HiveSessionStateBuilder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 92cb4ef11c9e3..dc92ad3b0c1ac 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -42,7 +42,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - val client: HiveClient = externalCatalog.client.newSession() + val client: HiveClient = externalCatalog.client new HiveSessionResourceLoader(session, client) } From 7b78041423b6ee330def2336dfd1ff9ae8469c59 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Sat, 6 Jan 2018 18:19:57 +0800 Subject: [PATCH 031/774] [SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'ParquetOptions', `parquet.compression` needs to be considered. [SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'ParquetOptions', `parquet.compression` needs to be considered. ## What changes were proposed in this pull request? Since Hive 1.1, Hive allows users to set parquet compression codec via table-level properties parquet.compression. See the JIRA: https://issues.apache.org/jira/browse/HIVE-7858 . We do support orc.compression for ORC. Thus, for external users, it is more straightforward to support both. See the stackflow question: https://stackoverflow.com/questions/36941122/spark-sql-ignores-parquet-compression-propertie-specified-in-tblproperties In Spark side, our table-level compression conf compression was added by #11464 since Spark 2.0. We need to support both table-level conf. Users might also use session-level conf spark.sql.parquet.compression.codec. The priority rule will be like If other compression codec configuration was found through hive or parquet, the precedence would be compression, parquet.compression, spark.sql.parquet.compression.codec. Acceptable values include: none, uncompressed, snappy, gzip, lzo. The rule for Parquet is consistent with the ORC after the change. Changes: 1.Increased acquiring 'compressionCodecClassName' from `parquet.compression`,and the precedence order is `compression`,`parquet.compression`,`spark.sql.parquet.compression.codec`, just like what we do in `OrcOptions`. 2.Change `spark.sql.parquet.compression.codec` to support "none".Actually in `ParquetOptions`,we do support "none" as equivalent to "uncompressed", but it does not allowed to configured to "none". 3.Change `compressionCode` to `compressionCodecClassName`. ## How was this patch tested? Add test. Author: fjh100456 Closes #20076 from fjh100456/ParquetOptionIssue. --- docs/sql-programming-guide.md | 6 +- .../apache/spark/sql/internal/SQLConf.scala | 14 +- .../datasources/parquet/ParquetOptions.scala | 12 +- ...rquetCompressionCodecPrecedenceSuite.scala | 122 ++++++++++++++++++ 4 files changed, 145 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b50f9360b866c..3ccaaf4d5b1fa 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -953,8 +953,10 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession spark.sql.parquet.compression.codec snappy - Sets the compression codec use when writing Parquet files. Acceptable values include: - uncompressed, snappy, gzip, lzo. + Sets the compression codec used when writing Parquet files. If either `compression` or + `parquet.compression` is specified in the table-specific options/properties, the precedence would be + `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: + none, uncompressed, snappy, gzip, lzo. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80b8965e084a2..7d1217de254a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -325,11 +325,13 @@ object SQLConf { .createWithDefault(false) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") - .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " + - "uncompressed, snappy, gzip, lzo.") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or" + + "`parquet.compression` is specified in the table-specific options/properties, the precedence" + + "would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") @@ -366,8 +368,10 @@ object SQLConf { .createWithDefault(true) val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") - .doc("Sets the compression codec use when writing ORC files. Acceptable values include: " + - "none, uncompressed, snappy, zlib, lzo.") + .doc("Sets the compression codec used when writing ORC files. If either `compression` or" + + "`orc.compress` is specified in the table-specific options/properties, the precedence" + + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, zlib, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "zlib", "lzo")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 772d4565de548..ef67ea7d17cea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.Locale +import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -42,8 +43,15 @@ private[parquet] class ParquetOptions( * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = parameters.getOrElse("compression", - sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT) + // `compression`, `parquet.compression`(i.e., ParquetOutputFormat.COMPRESSION), and + // `spark.sql.parquet.compression.codec` + // are in order of precedence from highest to lowest. + val parquetCompressionConf = parameters.get(ParquetOutputFormat.COMPRESSION) + val codecName = parameters + .get("compression") + .orElse(parquetCompressionConf) + .getOrElse(sqlConf.parquetCompressionCodec) + .toLowerCase(Locale.ROOT) if (!shortParquetCompressionCodecNames.contains(codecName)) { val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala new file mode 100644 index 0000000000000..ed8fd2b453456 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetOutputFormat + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSQLContext { + test("Test `spark.sql.parquet.compression.codec` config") { + Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO").foreach { c => + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> c) { + val expected = if (c == "NONE") "UNCOMPRESSED" else c + val option = new ParquetOptions(Map.empty[String, String], spark.sessionState.conf) + assert(option.compressionCodecClassName == expected) + } + } + } + + test("[SPARK-21786] Test Acquiring 'compressionCodecClassName' for parquet in right order.") { + // When "compression" is configured, it should be the first choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map("compression" -> "uncompressed", ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "UNCOMPRESSED") + } + + // When "compression" is not configured, "parquet.compression" should be the preferred choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map(ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "GZIP") + } + + // When both "compression" and "parquet.compression" are not configured, + // spark.sql.parquet.compression.codec should be the right choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map.empty[String, String] + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "SNAPPY") + } + } + + private def getTableCompressionCodec(path: String): Seq[String] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + codecs.distinct + } + + private def createTableWithCompression( + tableName: String, + isPartitioned: Boolean, + compressionCodec: String, + rootDir: File): Unit = { + val options = + s""" + |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName', + |'parquet.compression'='$compressionCodec') + """.stripMargin + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p)" else "" + sql( + s""" + |CREATE TABLE $tableName USING Parquet $options $partitionCreate + |AS SELECT 1 AS col1, 2 AS p + """.stripMargin) + } + + private def checkCompressionCodec(compressionCodec: String, isPartitioned: Boolean): Unit = { + withTempDir { tmpDir => + val tempTableName = "TempParquetTable" + withTable(tempTableName) { + createTableWithCompression(tempTableName, isPartitioned, compressionCodec, tmpDir) + val partitionPath = if (isPartitioned) "p=2" else "" + val path = s"${tmpDir.getPath.stripSuffix("/")}/$tempTableName/$partitionPath" + val realCompressionCodecs = getTableCompressionCodec(path) + assert(realCompressionCodecs.forall(_ == compressionCodec)) + } + } + } + + test("Create parquet table with compression") { + Seq(true, false).foreach { isPartitioned => + Seq("UNCOMPRESSED", "SNAPPY", "GZIP").foreach { compressionCodec => + checkCompressionCodec(compressionCodec, isPartitioned) + } + } + } + + test("Create table with unknown compression") { + Seq(true, false).foreach { isPartitioned => + val exception = intercept[IllegalArgumentException] { + checkCompressionCodec("aa", isPartitioned) + } + assert(exception.getMessage.contains("Codec [aa] is not available")) + } + } +} From 993f21567a1dd33e43ef9a626e0ddfbe46f83f93 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 6 Jan 2018 23:08:26 +0800 Subject: [PATCH 032/774] [SPARK-22901][PYTHON][FOLLOWUP] Adds the doc for asNondeterministic for wrapped UDF function ## What changes were proposed in this pull request? This PR wraps the `asNondeterministic` attribute in the wrapped UDF function to set the docstring properly. ```python from pyspark.sql.functions import udf help(udf(lambda x: x).asNondeterministic) ``` Before: ``` Help on function in module pyspark.sql.udf: lambda (END ``` After: ``` Help on function asNondeterministic in module pyspark.sql.udf: asNondeterministic() Updates UserDefinedFunction to nondeterministic. .. versionadded:: 2.3 (END) ``` ## How was this patch tested? Manually tested and a simple test was added. Author: hyukjinkwon Closes #20173 from HyukjinKwon/SPARK-22901-followup. --- python/pyspark/sql/tests.py | 1 + python/pyspark/sql/udf.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 689736d8e6456..122a65b83aef9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -413,6 +413,7 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) pydoc.render_doc(random_udf) pydoc.render_doc(random_udf1) + pydoc.render_doc(udf(lambda x: x).asNondeterministic) def test_nondeterministic_udf_in_aggregate(self): from pyspark.sql.functions import udf, sum diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 5e75eb6545333..5e80ab9165867 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -169,8 +169,8 @@ def wrapper(*args): wrapper.returnType = self.returnType wrapper.evalType = self.evalType wrapper.deterministic = self.deterministic - wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped() - + wrapper.asNondeterministic = functools.wraps( + self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped()) return wrapper def asNondeterministic(self): From 9a7048b2889bd0fd66e68a0ce3e07e466315a051 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 7 Jan 2018 00:19:21 +0800 Subject: [PATCH 033/774] [HOTFIX] Fix style checking failure ## What changes were proposed in this pull request? This PR is to fix the style checking failure. ## How was this patch tested? N/A Author: gatorsmile Closes #20175 from gatorsmile/stylefix. --- .../org/apache/spark/sql/internal/SQLConf.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7d1217de254a2..5c61f10bb71ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -325,10 +325,11 @@ object SQLConf { .createWithDefault(false) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") - .doc("Sets the compression codec used when writing Parquet files. If either `compression` or" + - "`parquet.compression` is specified in the table-specific options/properties, the precedence" + - "would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`." + - "Acceptable values include: none, uncompressed, snappy, gzip, lzo.") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or " + + "`parquet.compression` is specified in the table-specific options/properties, the " + + "precedence would be `compression`, `parquet.compression`, " + + "`spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, " + + "snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) @@ -368,8 +369,8 @@ object SQLConf { .createWithDefault(true) val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") - .doc("Sets the compression codec used when writing ORC files. If either `compression` or" + - "`orc.compress` is specified in the table-specific options/properties, the precedence" + + .doc("Sets the compression codec used when writing ORC files. If either `compression` or " + + "`orc.compress` is specified in the table-specific options/properties, the precedence " + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + "Acceptable values include: none, uncompressed, snappy, zlib, lzo.") .stringConf From 18e94149992618a2b4e6f0fd3b3f4594e1745224 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 7 Jan 2018 13:42:01 +0800 Subject: [PATCH 034/774] [SPARK-22973][SQL] Fix incorrect results of Casting Map to String ## What changes were proposed in this pull request? This pr fixed the issue when casting maps into strings; ``` scala> Seq(Map(1 -> "a", 2 -> "b")).toDF("a").write.saveAsTable("t") scala> sql("SELECT cast(a as String) FROM t").show(false) +----------------------------------------------------------------+ |a | +----------------------------------------------------------------+ |org.apache.spark.sql.catalyst.expressions.UnsafeMapData38bdd75d| +----------------------------------------------------------------+ ``` This pr modified the result into; ``` +----------------+ |a | +----------------+ |[1 -> a, 2 -> b]| +----------------+ ``` ## How was this patch tested? Added tests in `CastSuite`. Author: Takeshi Yamamuro Closes #20166 from maropu/SPARK-22973. --- .../spark/sql/catalyst/expressions/Cast.scala | 89 +++++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 28 ++++++ 2 files changed, 117 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d4fc5e0f168a7..f2de4c8e30bec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -228,6 +228,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case MapType(kt, vt, _) => + buildCast[MapData](_, map => { + val builder = new UTF8StringBuilder + builder.append("[") + if (map.numElements > 0) { + val keyArray = map.keyArray() + val valueArray = map.valueArray() + val keyToUTF8String = castToString(kt) + val valueToUTF8String = castToString(vt) + builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(0)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < map.numElements) { + builder.append(", ") + builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(i)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(i, vt)) + .asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -654,6 +685,53 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """.stripMargin } + private def writeMapToStringBuilder( + kt: DataType, + vt: DataType, + map: String, + buffer: String, + ctx: CodegenContext): String = { + + def dataToStringFunc(func: String, dataType: DataType) = { + val funcName = ctx.freshName(func) + val dataToStringCode = castToStringCode(dataType, ctx) + ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${ctx.javaType(dataType)} data) { + | UTF8String dataStr = null; + | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} + | return dataStr; + |} + """.stripMargin) + } + + val keyToStringFunc = dataToStringFunc("keyToString", kt) + val valueToStringFunc = dataToStringFunc("valueToString", vt) + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($map.numElements() > 0) { + | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")})); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt(0)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { + | $buffer.append(", "); + | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)})); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc( + | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -676,6 +754,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case MapType(kt, vt, _) => + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeMapElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index e3ed7171defd8..1445bb8a97d40 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -878,4 +878,32 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType) checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") } + + test("SPARK-22973 Cast map to string") { + val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType) + checkEvaluation(ret1, "[1 -> a, 2 -> b, 3 -> c]") + val ret2 = cast( + Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)), + StringType) + checkEvaluation(ret2, "[1 -> a, 2 ->, 3 -> c]") + val ret3 = cast( + Literal.create(Map( + 1 -> Date.valueOf("2014-12-03"), + 2 -> Date.valueOf("2014-12-04"), + 3 -> Date.valueOf("2014-12-05"))), + StringType) + checkEvaluation(ret3, "[1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05]") + val ret4 = cast( + Literal.create(Map( + 1 -> Timestamp.valueOf("2014-12-03 13:01:00"), + 2 -> Timestamp.valueOf("2014-12-04 15:05:00"))), + StringType) + checkEvaluation(ret4, "[1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00]") + val ret5 = cast( + Literal.create(Map( + 1 -> Array(1, 2, 3), + 2 -> Array(4, 5, 6))), + StringType) + checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]") + } } From 71d65a32158a55285be197bec4e41fedc9225b94 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 8 Jan 2018 11:39:45 +0800 Subject: [PATCH 035/774] [SPARK-22985] Fix argument escaping bug in from_utc_timestamp / to_utc_timestamp codegen ## What changes were proposed in this pull request? This patch adds additional escaping in `from_utc_timestamp` / `to_utc_timestamp` expression codegen in order to a bug where invalid timezones which contain special characters could cause generated code to fail to compile. ## How was this patch tested? New regression tests in `DateExpressionsSuite`. Author: Josh Rosen Closes #20182 from JoshRosen/SPARK-22985-fix-utc-timezone-function-escaping-bugs. --- .../catalyst/expressions/datetimeExpressions.scala | 12 ++++++++---- .../catalyst/expressions/DateExpressionsSuite.scala | 6 ++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 7a674ea7f4d76..424871f2047e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -23,6 +23,8 @@ import java.util.{Calendar, TimeZone} import scala.util.control.NonFatal +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -1008,7 +1010,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { - val tz = right.eval() + val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { ev.copy(code = s""" |boolean ${ev.isNull} = true; @@ -1017,8 +1019,9 @@ case class FromUTCTimestamp(left: Expression, right: Expression) } else { val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val escapedTz = StringEscapeUtils.escapeJava(tz.toString) val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getTimeZone("$tz");""") + v => s"""$v = $dtu.getTimeZone("$escapedTz");""") val utcTerm = "tzUTC" ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") @@ -1185,7 +1188,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { - val tz = right.eval() + val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { ev.copy(code = s""" |boolean ${ev.isNull} = true; @@ -1194,8 +1197,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression) } else { val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val escapedTz = StringEscapeUtils.escapeJava(tz.toString) val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getTimeZone("$tz");""") + v => s"""$v = $dtu.getTimeZone("$escapedTz");""") val utcTerm = "tzUTC" ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 63f6ceeb21b96..786266a2c13c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -22,6 +22,7 @@ import java.text.SimpleDateFormat import java.util.{Calendar, Locale, TimeZone} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT @@ -791,6 +792,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test(null, "UTC", null) test("2015-07-24 00:00:00", null, null) test(null, null, null) + // Test escaping of timezone + GenerateUnsafeProjection.generate( + ToUTCTimestamp(Literal(Timestamp.valueOf("2015-07-24 00:00:00")), Literal("\"quote")) :: Nil) } test("from_utc_timestamp") { @@ -811,5 +815,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test(null, "UTC", null) test("2015-07-24 00:00:00", null, null) test(null, null, null) + // Test escaping of timezone + GenerateUnsafeProjection.generate(FromUTCTimestamp(Literal(0), Literal("\"quote")) :: Nil) } } From 3e40eb3f1ffac3d2f49459a801e3ce171ed34091 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Mon, 8 Jan 2018 14:32:05 +0900 Subject: [PATCH 036/774] [SPARK-22566][PYTHON] Better error message for `_merge_type` in Pandas to Spark DF conversion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? It provides a better error message when doing `spark_session.createDataFrame(pandas_df)` with no schema and an error occurs in the schema inference due to incompatible types. The Pandas column names are propagated down and the error message mentions which column had the merging error. https://issues.apache.org/jira/browse/SPARK-22566 ## How was this patch tested? Manually in the `./bin/pyspark` console, and with new tests: `./python/run-tests` screen shot 2017-11-21 at 13 29 49 I state that the contribution is my original work and that I license the work to the Apache Spark project under the project’s open source license. Author: Guilherme Berger Closes #19792 from gberger/master. --- python/pyspark/sql/session.py | 17 +++--- python/pyspark/sql/tests.py | 100 ++++++++++++++++++++++++++++++++++ python/pyspark/sql/types.py | 28 +++++++--- 3 files changed, 129 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 6e5eec48e8aca..6052fa9e84096 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -325,11 +325,12 @@ def range(self, start, end=None, step=1, numPartitions=None): return DataFrame(jdf, self._wrapped) - def _inferSchemaFromList(self, data): + def _inferSchemaFromList(self, data, names=None): """ Infer schema from list of Row or tuple. :param data: list of Row or tuple + :param names: list of column names :return: :class:`pyspark.sql.types.StructType` """ if not data: @@ -338,12 +339,12 @@ def _inferSchemaFromList(self, data): if type(first) is dict: warnings.warn("inferring schema from dict is deprecated," "please use pyspark.sql.Row instead") - schema = reduce(_merge_type, map(_infer_schema, data)) + schema = reduce(_merge_type, (_infer_schema(row, names) for row in data)) if _has_nulltype(schema): raise ValueError("Some of types cannot be determined after inferring") return schema - def _inferSchema(self, rdd, samplingRatio=None): + def _inferSchema(self, rdd, samplingRatio=None, names=None): """ Infer schema from an RDD of Row or tuple. @@ -360,10 +361,10 @@ def _inferSchema(self, rdd, samplingRatio=None): "Use pyspark.sql.Row instead") if samplingRatio is None: - schema = _infer_schema(first) + schema = _infer_schema(first, names=names) if _has_nulltype(schema): for row in rdd.take(100)[1:]: - schema = _merge_type(schema, _infer_schema(row)) + schema = _merge_type(schema, _infer_schema(row, names=names)) if not _has_nulltype(schema): break else: @@ -372,7 +373,7 @@ def _inferSchema(self, rdd, samplingRatio=None): else: if samplingRatio < 0.99: rdd = rdd.sample(False, float(samplingRatio)) - schema = rdd.map(_infer_schema).reduce(_merge_type) + schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type) return schema def _createFromRDD(self, rdd, schema, samplingRatio): @@ -380,7 +381,7 @@ def _createFromRDD(self, rdd, schema, samplingRatio): Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. """ if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchema(rdd, samplingRatio) + struct = self._inferSchema(rdd, samplingRatio, names=schema) converter = _create_converter(struct) rdd = rdd.map(converter) if isinstance(schema, (list, tuple)): @@ -406,7 +407,7 @@ def _createFromLocal(self, data, schema): data = list(data) if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchemaFromList(data) + struct = self._inferSchemaFromList(data, names=schema) converter = _create_converter(struct) data = map(converter, data) if isinstance(schema, (list, tuple)): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 122a65b83aef9..13576ff57001b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -68,6 +68,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings +from pyspark.sql.types import _merge_type from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window @@ -898,6 +899,15 @@ def test_infer_schema(self): result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) + def test_infer_schema_not_enough_names(self): + df = self.spark.createDataFrame([["a", "b"]], ["col1"]) + self.assertEqual(df.columns, ['col1', '_2']) + + def test_infer_schema_fails(self): + with self.assertRaisesRegexp(TypeError, 'field a'): + self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]), + schema=["a", "b"], samplingRatio=0.99) + def test_infer_nested_schema(self): NestedRow = Row("f1", "f2") nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), @@ -918,6 +928,10 @@ def test_infer_nested_schema(self): df = self.spark.createDataFrame(rdd) self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) + def test_create_dataframe_from_dict_respects_schema(self): + df = self.spark.createDataFrame([{'a': 1}], ["b"]) + self.assertEqual(df.columns, ['b']) + def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] df = self.spark.createDataFrame(data) @@ -1772,6 +1786,92 @@ def test_infer_long_type(self): self.assertEqual(_infer_type(2**61), LongType()) self.assertEqual(_infer_type(2**71), LongType()) + def test_merge_type(self): + self.assertEqual(_merge_type(LongType(), NullType()), LongType()) + self.assertEqual(_merge_type(NullType(), LongType()), LongType()) + + self.assertEqual(_merge_type(LongType(), LongType()), LongType()) + + self.assertEqual(_merge_type( + ArrayType(LongType()), + ArrayType(LongType()) + ), ArrayType(LongType())) + with self.assertRaisesRegexp(TypeError, 'element in array'): + _merge_type(ArrayType(LongType()), ArrayType(DoubleType())) + + self.assertEqual(_merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), LongType()) + ), MapType(StringType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'key of map'): + _merge_type( + MapType(StringType(), LongType()), + MapType(DoubleType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'value of map'): + _merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), DoubleType())) + + self.assertEqual(_merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", LongType()), StructField("f2", StringType())]) + ), StructType([StructField("f1", LongType()), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'field f1'): + _merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", DoubleType()), StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]) + ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))])) + with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'): + _merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", StringType())]))])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]), + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]) + ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'element in array field f1'): + _merge_type( + StructType([ + StructField("f1", ArrayType(LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", ArrayType(DoubleType())), + StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]) + ), StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'value of map field f1'): + _merge_type( + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), DoubleType())), + StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]) + ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])) + with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'): + _merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))]) + ) + def test_filter_with_datetime(self): time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) date = time.date() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 146e673ae9756..0dc5823f72a3c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1073,7 +1073,7 @@ def _infer_type(obj): raise TypeError("not supported type: %s" % type(obj)) -def _infer_schema(row): +def _infer_schema(row, names=None): """Infer the schema from dict/namedtuple/object""" if isinstance(row, dict): items = sorted(row.items()) @@ -1084,7 +1084,10 @@ def _infer_schema(row): elif hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) else: - names = ['_%d' % i for i in range(1, len(row) + 1)] + if names is None: + names = ['_%d' % i for i in range(1, len(row) + 1)] + elif len(names) < len(row): + names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1)) items = zip(names, row) elif hasattr(row, "__dict__"): # object @@ -1109,19 +1112,27 @@ def _has_nulltype(dt): return isinstance(dt, NullType) -def _merge_type(a, b): +def _merge_type(a, b, name=None): + if name is None: + new_msg = lambda msg: msg + new_name = lambda n: "field %s" % n + else: + new_msg = lambda msg: "%s: %s" % (name, msg) + new_name = lambda n: "field %s in %s" % (n, name) + if isinstance(a, NullType): return b elif isinstance(b, NullType): return a elif type(a) is not type(b): # TODO: type cast (such as int -> long) - raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) + raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b)))) # same type if isinstance(a, StructType): nfs = dict((f.name, f.dataType) for f in b.fields) - fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()), + name=new_name(f.name))) for f in a.fields] names = set([f.name for f in fields]) for n in nfs: @@ -1130,11 +1141,12 @@ def _merge_type(a, b): return StructType(fields) elif isinstance(a, ArrayType): - return ArrayType(_merge_type(a.elementType, b.elementType), True) + return ArrayType(_merge_type(a.elementType, b.elementType, + name='element in array %s' % name), True) elif isinstance(a, MapType): - return MapType(_merge_type(a.keyType, b.keyType), - _merge_type(a.valueType, b.valueType), + return MapType(_merge_type(a.keyType, b.keyType, name='key of map %s' % name), + _merge_type(a.valueType, b.valueType, name='value of map %s' % name), True) else: return a From 8fdeb4b9946bd9be045abb919da2e531708b3bd4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 8 Jan 2018 13:59:08 +0800 Subject: [PATCH 037/774] [SPARK-22979][PYTHON][SQL] Avoid per-record type dispatch in Python data conversion (EvaluatePython.fromJava) ## What changes were proposed in this pull request? Seems we can avoid type dispatch for each value when Java objection (from Pyrolite) -> Spark's internal data format because we know the schema ahead. I manually performed the benchmark as below: ```scala test("EvaluatePython.fromJava / EvaluatePython.makeFromJava") { val numRows = 1000 * 1000 val numFields = 30 val random = new Random(System.nanoTime()) val types = Array( BooleanType, ByteType, FloatType, DoubleType, IntegerType, LongType, ShortType, DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal, DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2), new DecimalType(12, 2), new DecimalType(30, 10), CalendarIntervalType) val schema = RandomDataGenerator.randomSchema(random, numFields, types) val rows = mutable.ArrayBuffer.empty[Array[Any]] var i = 0 while (i < numRows) { val row = RandomDataGenerator.randomRow(random, schema) rows += row.toSeq.toArray i += 1 } val benchmark = new Benchmark("EvaluatePython.fromJava / EvaluatePython.makeFromJava", numRows) benchmark.addCase("Before - EvaluatePython.fromJava", 3) { _ => var i = 0 while (i < numRows) { EvaluatePython.fromJava(rows(i), schema) i += 1 } } benchmark.addCase("After - EvaluatePython.makeFromJava", 3) { _ => val fromJava = EvaluatePython.makeFromJava(schema) var i = 0 while (i < numRows) { fromJava(rows(i)) i += 1 } } benchmark.run() } ``` ``` EvaluatePython.fromJava / EvaluatePython.makeFromJava: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Before - EvaluatePython.fromJava 1265 / 1346 0.8 1264.8 1.0X After - EvaluatePython.makeFromJava 571 / 649 1.8 570.8 2.2X ``` If the structure is nested, I think the advantage should be larger than this. ## How was this patch tested? Existing tests should cover this. Also, I manually checked if the values from before / after are actually same via `assert` when performing the benchmarks. Author: hyukjinkwon Closes #20172 from HyukjinKwon/type-dispatch-python-eval. --- .../org/apache/spark/sql/SparkSession.scala | 5 +- .../python/BatchEvalPythonExec.scala | 7 +- .../sql/execution/python/EvaluatePython.scala | 166 ++++++++++++------ 3 files changed, 118 insertions(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 272eb844226d4..734573ba31f71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -742,7 +742,10 @@ class SparkSession private( private[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], schema: StructType): DataFrame = { - val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) + val rowRdd = rdd.mapPartitions { iter => + val fromJava = python.EvaluatePython.makeFromJava(schema) + iter.map(r => fromJava(r).asInstanceOf[InternalRow]) + } internalCreateDataFrame(rowRdd, schema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 26ee25f633ea4..f4d83e8dc7c2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -79,16 +79,19 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi } else { StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) } + + val fromJava = EvaluatePython.makeFromJava(resultType) + outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => if (udfs.length == 1) { // fast path for single UDF - mutableRow(0) = EvaluatePython.fromJava(result, resultType) + mutableRow(0) = fromJava(result) mutableRow } else { - EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] + fromJava(result).asInstanceOf[InternalRow] } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 9bbfa6018ba77..520afad287648 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -83,82 +83,134 @@ object EvaluatePython { } /** - * Converts `obj` to the type specified by the data type, or returns null if the type of obj is - * unexpected. Because Python doesn't enforce the type. + * Make a converter that converts `obj` to the type specified by the data type, or returns + * null if the type of obj is unexpected. Because Python doesn't enforce the type. */ - def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - case (null, _) => null - - case (c: Boolean, BooleanType) => c + def makeFromJava(dataType: DataType): Any => Any = dataType match { + case BooleanType => (obj: Any) => nullSafeConvert(obj) { + case b: Boolean => b + } - case (c: Byte, ByteType) => c - case (c: Short, ByteType) => c.toByte - case (c: Int, ByteType) => c.toByte - case (c: Long, ByteType) => c.toByte + case ByteType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c + case c: Short => c.toByte + case c: Int => c.toByte + case c: Long => c.toByte + } - case (c: Byte, ShortType) => c.toShort - case (c: Short, ShortType) => c - case (c: Int, ShortType) => c.toShort - case (c: Long, ShortType) => c.toShort + case ShortType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toShort + case c: Short => c + case c: Int => c.toShort + case c: Long => c.toShort + } - case (c: Byte, IntegerType) => c.toInt - case (c: Short, IntegerType) => c.toInt - case (c: Int, IntegerType) => c - case (c: Long, IntegerType) => c.toInt + case IntegerType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toInt + case c: Short => c.toInt + case c: Int => c + case c: Long => c.toInt + } - case (c: Byte, LongType) => c.toLong - case (c: Short, LongType) => c.toLong - case (c: Int, LongType) => c.toLong - case (c: Long, LongType) => c + case LongType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toLong + case c: Short => c.toLong + case c: Int => c.toLong + case c: Long => c + } - case (c: Float, FloatType) => c - case (c: Double, FloatType) => c.toFloat + case FloatType => (obj: Any) => nullSafeConvert(obj) { + case c: Float => c + case c: Double => c.toFloat + } - case (c: Float, DoubleType) => c.toDouble - case (c: Double, DoubleType) => c + case DoubleType => (obj: Any) => nullSafeConvert(obj) { + case c: Float => c.toDouble + case c: Double => c + } - case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale) + case dt: DecimalType => (obj: Any) => nullSafeConvert(obj) { + case c: java.math.BigDecimal => Decimal(c, dt.precision, dt.scale) + } - case (c: Int, DateType) => c + case DateType => (obj: Any) => nullSafeConvert(obj) { + case c: Int => c + } - case (c: Long, TimestampType) => c - // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs - case (c: Int, TimestampType) => c.toLong + case TimestampType => (obj: Any) => nullSafeConvert(obj) { + case c: Long => c + // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs + case c: Int => c.toLong + } - case (c, StringType) => UTF8String.fromString(c.toString) + case StringType => (obj: Any) => nullSafeConvert(obj) { + case _ => UTF8String.fromString(obj.toString) + } - case (c: String, BinaryType) => c.getBytes(StandardCharsets.UTF_8) - case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + case BinaryType => (obj: Any) => nullSafeConvert(obj) { + case c: String => c.getBytes(StandardCharsets.UTF_8) + case c if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + } - case (c: java.util.List[_], ArrayType(elementType, _)) => - new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray) + case ArrayType(elementType, _) => + val elementFromJava = makeFromJava(elementType) - case (c, ArrayType(elementType, _)) if c.getClass.isArray => - new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) + (obj: Any) => nullSafeConvert(obj) { + case c: java.util.List[_] => + new GenericArrayData(c.asScala.map { e => elementFromJava(e) }.toArray) + case c if c.getClass.isArray => + new GenericArrayData(c.asInstanceOf[Array[_]].map(e => elementFromJava(e))) + } - case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _)) => - ArrayBasedMapData( - javaMap, - (key: Any) => fromJava(key, keyType), - (value: Any) => fromJava(value, valueType)) + case MapType(keyType, valueType, _) => + val keyFromJava = makeFromJava(keyType) + val valueFromJava = makeFromJava(valueType) + + (obj: Any) => nullSafeConvert(obj) { + case javaMap: java.util.Map[_, _] => + ArrayBasedMapData( + javaMap, + (key: Any) => keyFromJava(key), + (value: Any) => valueFromJava(value)) + } - case (c, StructType(fields)) if c.getClass.isArray => - val array = c.asInstanceOf[Array[_]] - if (array.length != fields.length) { - throw new IllegalStateException( - s"Input row doesn't have expected number of values required by the schema. " + - s"${fields.length} fields are required while ${array.length} values are provided." - ) + case StructType(fields) => + val fieldsFromJava = fields.map(f => makeFromJava(f.dataType)).toArray + + (obj: Any) => nullSafeConvert(obj) { + case c if c.getClass.isArray => + val array = c.asInstanceOf[Array[_]] + if (array.length != fields.length) { + throw new IllegalStateException( + s"Input row doesn't have expected number of values required by the schema. " + + s"${fields.length} fields are required while ${array.length} values are provided." + ) + } + + val row = new GenericInternalRow(fields.length) + var i = 0 + while (i < fields.length) { + row(i) = fieldsFromJava(i)(array(i)) + i += 1 + } + row } - new GenericInternalRow(array.zip(fields).map { - case (e, f) => fromJava(e, f.dataType) - }) - case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) + case udt: UserDefinedType[_] => makeFromJava(udt.sqlType) + + case other => (obj: Any) => nullSafeConvert(other)(PartialFunction.empty) + } - // all other unexpected type should be null, or we will have runtime exception - // TODO(davies): we could improve this by try to cast the object to expected type - case (c, _) => null + private def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = { + if (input == null) { + null + } else { + f.applyOrElse(input, { + // all other unexpected type should be null, or we will have runtime exception + // TODO(davies): we could improve this by try to cast the object to expected type + _: Any => null + }) + } } private val module = "pyspark.sql.types" From 2c73d2a948bdde798aaf0f87c18846281deb05fd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 8 Jan 2018 16:04:03 +0800 Subject: [PATCH 038/774] [SPARK-22983] Don't push filters beneath aggregates with empty grouping expressions ## What changes were proposed in this pull request? The following SQL query should return zero rows, but in Spark it actually returns one row: ``` SELECT 1 from ( SELECT 1 AS z, MIN(a.x) FROM (select 1 as x) a WHERE false ) b where b.z != b.z ``` The problem stems from the `PushDownPredicate` rule: when this rule encounters a filter on top of an Aggregate operator, e.g. `Filter(Agg(...))`, it removes the original filter and adds a new filter onto Aggregate's child, e.g. `Agg(Filter(...))`. This is sometimes okay, but the case above is a counterexample: because there is no explicit `GROUP BY`, we are implicitly computing a global aggregate over the entire table so the original filter was not acting like a `HAVING` clause filtering the number of groups: if we push this filter then it fails to actually reduce the cardinality of the Aggregate output, leading to the wrong answer. In 2016 I fixed a similar problem involving invalid pushdowns of data-independent filters (filters which reference no columns of the filtered relation). There was additional discussion after my fix was merged which pointed out that my patch was an incomplete fix (see #15289), but it looks I must have either misunderstood the comment or forgot to follow up on the additional points raised there. This patch fixes the problem by choosing to never push down filters in cases where there are no grouping expressions. Since there are no grouping keys, the only columns are aggregate columns and we can't push filters defined over aggregate results, so this change won't cause us to miss out on any legitimate pushdown opportunities. ## How was this patch tested? New regression tests in `SQLQueryTestSuite` and `FilterPushdownSuite`. Author: Josh Rosen Closes #20180 from JoshRosen/SPARK-22983-dont-push-filters-beneath-aggs-with-empty-grouping-expressions. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 3 ++- .../catalyst/optimizer/FilterPushdownSuite.scala | 13 +++++++++++++ .../test/resources/sql-tests/inputs/group-by.sql | 9 +++++++++ .../resources/sql-tests/results/group-by.sql.out | 16 +++++++++++++++- 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0d4b02c6e7d8a..df0af8264a329 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -795,7 +795,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) case filter @ Filter(condition, aggregate: Aggregate) - if aggregate.aggregateExpressions.forall(_.deterministic) => + if aggregate.aggregateExpressions.forall(_.deterministic) + && aggregate.groupingExpressions.nonEmpty => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 85a5e979f6021..82a10254d846d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -809,6 +809,19 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("aggregate: don't push filters if the aggregate has no grouping expressions") { + val originalQuery = LocalRelation.apply(testRelation.output, Seq.empty) + .select('a, 'b) + .groupBy()(count(1)) + .where(false) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = originalQuery.analyze + + comparePlans(optimized, correctAnswer) + } + test("broadcast hint") { val originalQuery = ResolvedHint(testRelation) .where('a === 2L && 'b + Rand(10).as("rnd") === 3) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 1e1384549a410..c5070b734d521 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -60,3 +60,12 @@ SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a; -- Aggregate with empty input and empty GroupBy expressions. SELECT COUNT(1) FROM testData WHERE false; SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t; + +-- Aggregate with empty GroupBy expressions and filter on top +SELECT 1 from ( + SELECT 1 AS z, + MIN(a.x) + FROM (select 1 as x) a + WHERE false +) b +where b.z != b.z diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 986bb01c13fe4..c1abc6dff754b 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 25 +-- Number of queries: 26 -- !query 0 @@ -227,3 +227,17 @@ SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t struct<1:int> -- !query 24 output 1 + + +-- !query 25 +SELECT 1 from ( + SELECT 1 AS z, + MIN(a.x) + FROM (select 1 as x) a + WHERE false +) b +where b.z != b.z +-- !query 25 schema +struct<1:int> +-- !query 25 output + From eb45b52e826ea9cea48629760db35ef87f91fea0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 8 Jan 2018 19:41:41 +0800 Subject: [PATCH 039/774] [SPARK-21865][SQL] simplify the distribution semantic of Spark SQL ## What changes were proposed in this pull request? **The current shuffle planning logic** 1. Each operator specifies the distribution requirements for its children, via the `Distribution` interface. 2. Each operator specifies its output partitioning, via the `Partitioning` interface. 3. `Partitioning.satisfy` determines whether a `Partitioning` can satisfy a `Distribution`. 4. For each operator, check each child of it, add a shuffle node above the child if the child partitioning can not satisfy the required distribution. 5. For each operator, check if its children's output partitionings are compatible with each other, via the `Partitioning.compatibleWith`. 6. If the check in 5 failed, add a shuffle above each child. 7. try to eliminate the shuffles added in 6, via `Partitioning.guarantees`. This design has a major problem with the definition of "compatible". `Partitioning.compatibleWith` is not well defined, ideally a `Partitioning` can't know if it's compatible with other `Partitioning`, without more information from the operator. For example, `t1 join t2 on t1.a = t2.b`, `HashPartitioning(a, 10)` should be compatible with `HashPartitioning(b, 10)` under this case, but the partitioning itself doesn't know it. As a result, currently `Partitioning.compatibleWith` always return false except for literals, which make it almost useless. This also means, if an operator has distribution requirements for multiple children, Spark always add shuffle nodes to all the children(although some of them can be eliminated). However, there is no guarantee that the children's output partitionings are compatible with each other after adding these shuffles, we just assume that the operator will only specify `ClusteredDistribution` for multiple children. I think it's very hard to guarantee children co-partition for all kinds of operators, and we can not even give a clear definition about co-partition between distributions like `ClusteredDistribution(a,b)` and `ClusteredDistribution(c)`. I think we should drop the "compatible" concept in the distribution model, and let the operator achieve the co-partition requirement by special distribution requirements. **Proposed shuffle planning logic after this PR** (The first 4 are same as before) 1. Each operator specifies the distribution requirements for its children, via the `Distribution` interface. 2. Each operator specifies its output partitioning, via the `Partitioning` interface. 3. `Partitioning.satisfy` determines whether a `Partitioning` can satisfy a `Distribution`. 4. For each operator, check each child of it, add a shuffle node above the child if the child partitioning can not satisfy the required distribution. 5. For each operator, check if its children's output partitionings have the same number of partitions. 6. If the check in 5 failed, pick the max number of partitions from children's output partitionings, and add shuffle to child whose number of partitions doesn't equal to the max one. The new distribution model is very simple, we only have one kind of relationship, which is `Partitioning.satisfy`. For multiple children, Spark only guarantees they have the same number of partitions, and it's the operator's responsibility to leverage this guarantee to achieve more complicated requirements. For example, non-broadcast joins can use the newly added `HashPartitionedDistribution` to achieve co-partition. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #19080 from cloud-fan/exchange. --- .../plans/physical/partitioning.scala | 286 +++++++----------- .../sql/catalyst/PartitioningSuite.scala | 55 ---- .../spark/sql/execution/SparkPlan.scala | 16 +- .../exchange/EnsureRequirements.scala | 120 +++----- .../joins/ShuffledHashJoinExec.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 2 +- .../apache/spark/sql/execution/objects.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 81 ++--- 8 files changed, 194 insertions(+), 370 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index e57c842ce2a36..0189bd73c56bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -30,18 +30,43 @@ import org.apache.spark.sql.types.{DataType, IntegerType} * - Intra-partition ordering of data: In this case the distribution describes guarantees made * about how tuples are distributed within a single partition. */ -sealed trait Distribution +sealed trait Distribution { + /** + * The required number of partitions for this distribution. If it's None, then any number of + * partitions is allowed for this distribution. + */ + def requiredNumPartitions: Option[Int] + + /** + * Creates a default partitioning for this distribution, which can satisfy this distribution while + * matching the given number of partitions. + */ + def createPartitioning(numPartitions: Int): Partitioning +} /** * Represents a distribution where no promises are made about co-location of data. */ -case object UnspecifiedDistribution extends Distribution +case object UnspecifiedDistribution extends Distribution { + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + throw new IllegalStateException("UnspecifiedDistribution does not have default partitioning.") + } +} /** * Represents a distribution that only has a single partition and all tuples of the dataset * are co-located. */ -case object AllTuples extends Distribution +case object AllTuples extends Distribution { + override def requiredNumPartitions: Option[Int] = Some(1) + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(numPartitions == 1, "The default partitioning of AllTuples can only have 1 partition.") + SinglePartition + } +} /** * Represents data where tuples that share the same values for the `clustering` @@ -51,12 +76,41 @@ case object AllTuples extends Distribution */ case class ClusteredDistribution( clustering: Seq[Expression], - numPartitions: Option[Int] = None) extends Distribution { + requiredNumPartitions: Option[Int] = None) extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, + s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + + s"the actual number of partitions is $numPartitions.") + HashPartitioning(clustering, numPartitions) + } +} + +/** + * Represents data where tuples have been clustered according to the hash of the given + * `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only + * [[HashPartitioning]] can satisfy this distribution. + * + * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the + * number of partitions, this distribution strictly requires which partition the tuple should be in. + */ +case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution { + require( + expressions != Nil, + "The expressions for hash of a HashPartitionedDistribution should not be Nil. " + + "An AllTuples should be used to represent a distribution that only has " + + "a single partition.") + + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + HashPartitioning(expressions, numPartitions) + } } /** @@ -73,46 +127,31 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { "An AllTuples should be used to represent a distribution that only has " + "a single partition.") - // TODO: This is not really valid... - def clustering: Set[Expression] = ordering.map(_.child).toSet + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + RangePartitioning(ordering, numPartitions) + } } /** * Represents data where tuples are broadcasted to every node. It is quite common that the * entire set of tuples is transformed into different data structure. */ -case class BroadcastDistribution(mode: BroadcastMode) extends Distribution +case class BroadcastDistribution(mode: BroadcastMode) extends Distribution { + override def requiredNumPartitions: Option[Int] = Some(1) + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(numPartitions == 1, + "The default partitioning of BroadcastDistribution can only have 1 partition.") + BroadcastPartitioning(mode) + } +} /** - * Describes how an operator's output is split across partitions. The `compatibleWith`, - * `guarantees`, and `satisfies` methods describe relationships between child partitionings, - * target partitionings, and [[Distribution]]s. These relations are described more precisely in - * their individual method docs, but at a high level: - * - * - `satisfies` is a relationship between partitionings and distributions. - * - `compatibleWith` is relationships between an operator's child output partitionings. - * - `guarantees` is a relationship between a child's existing output partitioning and a target - * output partitioning. - * - * Diagrammatically: - * - * +--------------+ - * | Distribution | - * +--------------+ - * ^ - * | - * satisfies - * | - * +--------------+ +--------------+ - * | Child | | Target | - * +----| Partitioning |----guarantees--->| Partitioning | - * | +--------------+ +--------------+ - * | ^ - * | | - * | compatibleWith - * | | - * +------------+ - * + * Describes how an operator's output is split across partitions. It has 2 major properties: + * 1. number of partitions. + * 2. if it can satisfy a given distribution. */ sealed trait Partitioning { /** Returns the number of partitions that the data is split across */ @@ -123,113 +162,35 @@ sealed trait Partitioning { * to satisfy the partitioning scheme mandated by the `required` [[Distribution]], * i.e. the current dataset does not need to be re-partitioned for the `required` * Distribution (it is possible that tuples within a partition need to be reorganized). - */ - def satisfies(required: Distribution): Boolean - - /** - * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] - * guarantees the same partitioning scheme described by `other`. - * - * Compatibility of partitionings is only checked for operators that have multiple children - * and that require a specific child output [[Distribution]], such as joins. - * - * Intuitively, partitionings are compatible if they route the same partitioning key to the same - * partition. For instance, two hash partitionings are only compatible if they produce the same - * number of output partitionings and hash records according to the same hash function and - * same partitioning key schema. - * - * Put another way, two partitionings are compatible with each other if they satisfy all of the - * same distribution guarantees. - */ - def compatibleWith(other: Partitioning): Boolean - - /** - * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] guarantees - * the same partitioning scheme described by `other`. If a `A.guarantees(B)`, then repartitioning - * the child's output according to `B` will be unnecessary. `guarantees` is used as a performance - * optimization to allow the exchange planner to avoid redundant repartitionings. By default, - * a partitioning only guarantees partitionings that are equal to itself (i.e. the same number - * of partitions, same strategy (range or hash), etc). - * - * In order to enable more aggressive optimization, this strict equality check can be relaxed. - * For example, say that the planner needs to repartition all of an operator's children so that - * they satisfy the [[AllTuples]] distribution. One way to do this is to repartition all children - * to have the [[SinglePartition]] partitioning. If one of the operator's children already happens - * to be hash-partitioned with a single partition then we do not need to re-shuffle this child; - * this repartitioning can be avoided if a single-partition [[HashPartitioning]] `guarantees` - * [[SinglePartition]]. - * - * The SinglePartition example given above is not particularly interesting; guarantees' real - * value occurs for more advanced partitioning strategies. SPARK-7871 will introduce a notion - * of null-safe partitionings, under which partitionings can specify whether rows whose - * partitioning keys contain null values will be grouped into the same partition or whether they - * will have an unknown / random distribution. If a partitioning does not require nulls to be - * clustered then a partitioning which _does_ cluster nulls will guarantee the null clustered - * partitioning. The converse is not true, however: a partitioning which clusters nulls cannot - * be guaranteed by one which does not cluster them. Thus, in general `guarantees` is not a - * symmetric relation. * - * Another way to think about `guarantees`: if `A.guarantees(B)`, then any partitioning of rows - * produced by `A` could have also been produced by `B`. + * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if + * the [[Partitioning]] only have one partition. Implementations can overwrite this method with + * special logic. */ - def guarantees(other: Partitioning): Boolean = this == other -} - -object Partitioning { - def allCompatible(partitionings: Seq[Partitioning]): Boolean = { - // Note: this assumes transitivity - partitionings.sliding(2).map { - case Seq(a) => true - case Seq(a, b) => - if (a.numPartitions != b.numPartitions) { - assert(!a.compatibleWith(b) && !b.compatibleWith(a)) - false - } else { - a.compatibleWith(b) && b.compatibleWith(a) - } - }.forall(_ == true) - } -} - -case class UnknownPartitioning(numPartitions: Int) extends Partitioning { - override def satisfies(required: Distribution): Boolean = required match { + def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true + case AllTuples => numPartitions == 1 case _ => false } - - override def compatibleWith(other: Partitioning): Boolean = false - - override def guarantees(other: Partitioning): Boolean = false } +case class UnknownPartitioning(numPartitions: Int) extends Partitioning + /** * Represents a partitioning where rows are distributed evenly across output partitions * by starting from a random target partition number and distributing rows in a round-robin * fashion. This partitioning is used when implementing the DataFrame.repartition() operator. */ -case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning { - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = false - - override def guarantees(other: Partitioning): Boolean = false -} +case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning case object SinglePartition extends Partitioning { val numPartitions = 1 override def satisfies(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false - case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1) + case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1 case _ => true } - - override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1 - - override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1 } /** @@ -244,22 +205,19 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case ClusteredDistribution(requiredClustering, desiredPartitions) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && - desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this.semanticEquals(o) - case _ => false - } - - override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this.semanticEquals(o) - case _ => false + override def satisfies(required: Distribution): Boolean = { + super.satisfies(required) || { + required match { + case h: HashClusteredDistribution => + expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { + case (l, r) => l.semanticEquals(r) + } + case ClusteredDistribution(requiredClustering, requiredNumPartitions) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && + (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case _ => false + } + } } /** @@ -288,25 +246,18 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case OrderedDistribution(requiredOrdering) => - val minSize = Seq(requiredOrdering.size, ordering.size).min - requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, desiredPartitions) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && - desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this.semanticEquals(o) - case _ => false - } - - override def guarantees(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this.semanticEquals(o) - case _ => false + override def satisfies(required: Distribution): Boolean = { + super.satisfies(required) || { + required match { + case OrderedDistribution(requiredOrdering) => + val minSize = Seq(requiredOrdering.size, ordering.size).min + requiredOrdering.take(minSize) == ordering.take(minSize) + case ClusteredDistribution(requiredClustering, requiredNumPartitions) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && + (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case _ => false + } + } } } @@ -347,20 +298,6 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) override def satisfies(required: Distribution): Boolean = partitionings.exists(_.satisfies(required)) - /** - * Returns true if any `partitioning` of this collection is compatible with - * the given [[Partitioning]]. - */ - override def compatibleWith(other: Partitioning): Boolean = - partitionings.exists(_.compatibleWith(other)) - - /** - * Returns true if any `partitioning` of this collection guarantees - * the given [[Partitioning]]. - */ - override def guarantees(other: Partitioning): Boolean = - partitionings.exists(_.guarantees(other)) - override def toString: String = { partitionings.map(_.toString).mkString("(", " or ", ")") } @@ -377,9 +314,4 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { case BroadcastDistribution(m) if m == mode => true case _ => false } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning(m) if m == mode => true - case _ => false - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala deleted file mode 100644 index 5b802ccc637dd..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} - -class PartitioningSuite extends SparkFunSuite { - test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { - val expressions = Seq(Literal(2), Literal(3)) - // Consider two HashPartitionings that have the same _set_ of hash expressions but which are - // created with different orderings of those expressions: - val partitioningA = HashPartitioning(expressions, 100) - val partitioningB = HashPartitioning(expressions.reverse, 100) - // These partitionings are not considered equal: - assert(partitioningA != partitioningB) - // However, they both satisfy the same clustered distribution: - val distribution = ClusteredDistribution(expressions) - assert(partitioningA.satisfies(distribution)) - assert(partitioningB.satisfies(distribution)) - // These partitionings compute different hashcodes for the same input row: - def computeHashCode(partitioning: HashPartitioning): Int = { - val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) - hashExprProj.apply(InternalRow.empty).hashCode() - } - assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) - // Thus, these partitionings are incompatible: - assert(!partitioningA.compatibleWith(partitioningB)) - assert(!partitioningB.compatibleWith(partitioningA)) - assert(!partitioningA.guarantees(partitioningB)) - assert(!partitioningB.guarantees(partitioningA)) - - // Just to be sure that we haven't cheated by having these methods always return false, - // check that identical partitionings are still compatible with and guarantee each other: - assert(partitioningA === partitioningA) - assert(partitioningA.guarantees(partitioningA)) - assert(partitioningA.compatibleWith(partitioningA)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 787c1cfbfb3d8..82300efc01632 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -94,7 +94,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies how data is partitioned across different nodes in the cluster. */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! - /** Specifies any partition requirements on the input data for this operator. */ + /** + * Specifies the data distribution requirements of all the children for this operator. By default + * it's [[UnspecifiedDistribution]] for each child, which means each child can have any + * distribution. + * + * If an operator overwrites this method, and specifies distribution requirements(excluding + * [[UnspecifiedDistribution]] and [[BroadcastDistribution]]) for more than one child, Spark + * guarantees that the outputs of these children will have same number of partitions, so that the + * operator can safely zip partitions of these children's result RDDs. Some operators can leverage + * this guarantee to satisfy some interesting requirement, e.g., non-broadcast joins can specify + * HashClusteredDistribution(a,b) for its left child, and specify HashClusteredDistribution(c,d) + * for its right child, then it's guaranteed that left and right child are co-partitioned by + * a,b/c,d, which means tuples of same value are in the partitions of same index, e.g., + * (a=1,b=2) and (c=1,d=2) are both in the second partition of left and right child. + */ def requiredChildDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index c8e236be28b42..e3d28388c5470 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -46,23 +46,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None } - /** - * Given a required distribution, returns a partitioning that satisfies that distribution. - * @param requiredDistribution The distribution that is required by the operator - * @param numPartitions Used when the distribution doesn't require a specific number of partitions - */ - private def createPartitioning( - requiredDistribution: Distribution, - numPartitions: Int): Partitioning = { - requiredDistribution match { - case AllTuples => SinglePartition - case ClusteredDistribution(clustering, desiredPartitions) => - HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions)) - case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) - case dist => sys.error(s"Do not know how to satisfy distribution $dist") - } - } - /** * Adds [[ExchangeCoordinator]] to [[ShuffleExchangeExec]]s if adaptive query execution is enabled * and partitioning schemes of these [[ShuffleExchangeExec]]s support [[ExchangeCoordinator]]. @@ -88,8 +71,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // shuffle data when we have more than one children because data generated by // these children may not be partitioned in the same way. // Please see the comment in withCoordinator for more details. - val supportsDistribution = - requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution]) + val supportsDistribution = requiredChildDistributions.forall { dist => + dist.isInstanceOf[ClusteredDistribution] || dist.isInstanceOf[HashClusteredDistribution] + } children.length > 1 && supportsDistribution } @@ -142,8 +126,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // // It will be great to introduce a new Partitioning to represent the post-shuffle // partitions when one post-shuffle partition includes multiple pre-shuffle partitions. - val targetPartitioning = - createPartitioning(distribution, defaultNumPreShufflePartitions) + val targetPartitioning = distribution.createPartitioning(defaultNumPreShufflePartitions) assert(targetPartitioning.isInstanceOf[HashPartitioning]) ShuffleExchangeExec(targetPartitioning, child, Some(coordinator)) } @@ -162,71 +145,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { assert(requiredChildDistributions.length == children.length) assert(requiredChildOrderings.length == children.length) - // Ensure that the operator's children satisfy their output distribution requirements: + // Ensure that the operator's children satisfy their output distribution requirements. children = children.zip(requiredChildDistributions).map { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => child case (child, BroadcastDistribution(mode)) => BroadcastExchangeExec(mode, child) case (child, distribution) => - ShuffleExchangeExec(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + val numPartitions = distribution.requiredNumPartitions + .getOrElse(defaultNumPreShufflePartitions) + ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child) } - // If the operator has multiple children and specifies child output distributions (e.g. join), - // then the children's output partitionings must be compatible: - def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match { - case UnspecifiedDistribution => false - case BroadcastDistribution(_) => false + // Get the indexes of children which have specified distribution requirements and need to have + // same number of partitions. + val childrenIndexes = requiredChildDistributions.zipWithIndex.filter { + case (UnspecifiedDistribution, _) => false + case (_: BroadcastDistribution, _) => false case _ => true - } - if (children.length > 1 - && requiredChildDistributions.exists(requireCompatiblePartitioning) - && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { + }.map(_._2) - // First check if the existing partitions of the children all match. This means they are - // partitioned by the same partitioning into the same number of partitions. In that case, - // don't try to make them match `defaultPartitions`, just use the existing partitioning. - val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max - val useExistingPartitioning = children.zip(requiredChildDistributions).forall { - case (child, distribution) => - child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) + val childrenNumPartitions = + childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet + + if (childrenNumPartitions.size > 1) { + // Get the number of partitions which is explicitly required by the distributions. + val requiredNumPartitions = { + val numPartitionsSet = childrenIndexes.flatMap { + index => requiredChildDistributions(index).requiredNumPartitions + }.toSet + assert(numPartitionsSet.size <= 1, + s"$operator have incompatible requirements of the number of partitions for its children") + numPartitionsSet.headOption } - children = if (useExistingPartitioning) { - // We do not need to shuffle any child's output. - children - } else { - // We need to shuffle at least one child's output. - // Now, we will determine the number of partitions that will be used by created - // partitioning schemes. - val numPartitions = { - // Let's see if we need to shuffle all child's outputs when we use - // maxChildrenNumPartitions. - val shufflesAllChildren = children.zip(requiredChildDistributions).forall { - case (child, distribution) => - !child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) - } - // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the - // number of partitions. Otherwise, we use maxChildrenNumPartitions. - if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions - } + val targetNumPartitions = requiredNumPartitions.getOrElse(childrenNumPartitions.max) - children.zip(requiredChildDistributions).map { - case (child, distribution) => - val targetPartitioning = createPartitioning(distribution, numPartitions) - if (child.outputPartitioning.guarantees(targetPartitioning)) { - child - } else { - child match { - // If child is an exchange, we replace it with - // a new one having targetPartitioning. - case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(targetPartitioning, c) - case _ => ShuffleExchangeExec(targetPartitioning, child) - } + children = children.zip(requiredChildDistributions).zipWithIndex.map { + case ((child, distribution), index) if childrenIndexes.contains(index) => + if (child.outputPartitioning.numPartitions == targetNumPartitions) { + child + } else { + val defaultPartitioning = distribution.createPartitioning(targetNumPartitions) + child match { + // If child is an exchange, we replace it with a new one having defaultPartitioning. + case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(defaultPartitioning, c) + case _ => ShuffleExchangeExec(defaultPartitioning, child) + } } - } + + case ((child, _), _) => child } } @@ -324,10 +292,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case operator @ ShuffleExchangeExec(partitioning, child, _) => - child.children match { - case ShuffleExchangeExec(childPartitioning, baseChild, _)::Nil => - if (childPartitioning.guarantees(partitioning)) child else operator + // TODO: remove this after we create a physical operator for `RepartitionByExpression`. + case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) => + child.outputPartitioning match { + case lower: HashPartitioning if upper.semanticEquals(lower) => child case _ => operator } case operator: SparkPlan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 66e8031bb5191..897a4dae39f32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -46,7 +46,7 @@ case class ShuffledHashJoinExec( "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { val buildDataSize = longMetric("buildDataSize") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 94405410cce90..2de2f30eb05d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -78,7 +78,7 @@ case class SortMergeJoinExec( } override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil override def outputOrdering: Seq[SortOrder] = joinType match { // For inner join, orders of both sides keys should be kept. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index d1bd8a7076863..03d1bbf2ab882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -456,7 +456,7 @@ case class CoGroupExec( right: SparkPlan) extends BinaryExecNode with ObjectProducerExec { override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil + HashClusteredDistribution(leftGroup) :: HashClusteredDistribution(rightGroup) :: Nil override def requiredChildOrdering: Seq[Seq[SortOrder]] = leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index b50642d275ba8..f8b26f5b28cc7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -260,11 +260,16 @@ class PlannerSuite extends SharedSQLContext { // do they satisfy the distribution requirements? As a result, we need at least four test cases. private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = { - if (outputPlan.children.length > 1 - && outputPlan.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) { - val childPartitionings = outputPlan.children.map(_.outputPartitioning) - if (!Partitioning.allCompatible(childPartitionings)) { - fail(s"Partitionings are not compatible: $childPartitionings") + if (outputPlan.children.length > 1) { + val childPartitionings = outputPlan.children.zip(outputPlan.requiredChildDistribution) + .filter { + case (_, UnspecifiedDistribution) => false + case (_, _: BroadcastDistribution) => false + case _ => true + }.map(_._1.outputPartitioning) + + if (childPartitionings.map(_.numPartitions).toSet.size > 1) { + fail(s"Partitionings doesn't have same number of partitions: $childPartitionings") } } outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach { @@ -274,40 +279,7 @@ class PlannerSuite extends SharedSQLContext { } } - test("EnsureRequirements with incompatible child partitionings which satisfy distribution") { - // Consider an operator that requires inputs that are clustered by two expressions (e.g. - // sort merge join where there are multiple columns in the equi-join condition) - val clusteringA = Literal(1) :: Nil - val clusteringB = Literal(2) :: Nil - val distribution = ClusteredDistribution(clusteringA ++ clusteringB) - // Say that the left and right inputs are each partitioned by _one_ of the two join columns: - val leftPartitioning = HashPartitioning(clusteringA, 1) - val rightPartitioning = HashPartitioning(clusteringB, 1) - // Individually, each input's partitioning satisfies the clustering distribution: - assert(leftPartitioning.satisfies(distribution)) - assert(rightPartitioning.satisfies(distribution)) - // However, these partitionings are not compatible with each other, so we still need to - // repartition both inputs prior to performing the join: - assert(!leftPartitioning.compatibleWith(rightPartitioning)) - assert(!rightPartitioning.compatibleWith(leftPartitioning)) - val inputPlan = DummySparkPlan( - children = Seq( - DummySparkPlan(outputPartitioning = leftPartitioning), - DummySparkPlan(outputPartitioning = rightPartitioning) - ), - requiredChildDistribution = Seq(distribution, distribution), - requiredChildOrdering = Seq(Seq.empty, Seq.empty) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) { - fail(s"Exchange should have been added:\n$outputPlan") - } - } - test("EnsureRequirements with child partitionings with different numbers of output partitions") { - // This is similar to the previous test, except it checks that partitionings are not compatible - // unless they produce the same number of partitions. val clustering = Literal(1) :: Nil val distribution = ClusteredDistribution(clustering) val inputPlan = DummySparkPlan( @@ -386,18 +358,15 @@ class PlannerSuite extends SharedSQLContext { } } - test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") { + test("EnsureRequirements eliminates Exchange if child has same partitioning") { val distribution = ClusteredDistribution(Literal(1) :: Nil) - val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) - val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) - assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchangeExec(finalPartitioning, - DummySparkPlan( - children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, - requiredChildDistribution = Seq(distribution), - requiredChildOrdering = Seq(Seq.empty)), - None) + val partitioning = HashPartitioning(Literal(1) :: Nil, 5) + assert(partitioning.satisfies(distribution)) + val inputPlan = ShuffleExchangeExec( + partitioning, + DummySparkPlan(outputPartitioning = partitioning), + None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 2) { @@ -407,17 +376,13 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements does not eliminate Exchange with different partitioning") { val distribution = ClusteredDistribution(Literal(1) :: Nil) - // Number of partitions differ - val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) - val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) - assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchangeExec(finalPartitioning, - DummySparkPlan( - children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, - requiredChildDistribution = Seq(distribution), - requiredChildOrdering = Seq(Seq.empty)), - None) + val partitioning = HashPartitioning(Literal(2) :: Nil, 5) + assert(!partitioning.satisfies(distribution)) + val inputPlan = ShuffleExchangeExec( + partitioning, + DummySparkPlan(outputPartitioning = partitioning), + None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) { From 40b983c3b44b6771f07302ce87987fa4716b5ebf Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Mon, 8 Jan 2018 23:49:07 +0800 Subject: [PATCH 040/774] [SPARK-22952][CORE] Deprecate stageAttemptId in favour of stageAttemptNumber ## What changes were proposed in this pull request? 1. Deprecate attemptId in StageInfo and add `def attemptNumber() = attemptId` 2. Replace usage of stageAttemptId with stageAttemptNumber ## How was this patch tested? I manually checked the compiler warning info Author: Xianjin YE Closes #20178 from advancedxy/SPARK-22952. --- .../apache/spark/scheduler/DAGScheduler.scala | 15 +++--- .../apache/spark/scheduler/StageInfo.scala | 4 +- .../spark/scheduler/StatsReportListener.scala | 2 +- .../spark/status/AppStatusListener.scala | 7 +-- .../org/apache/spark/status/LiveEntity.scala | 4 +- .../spark/ui/scope/RDDOperationGraph.scala | 2 +- .../org/apache/spark/util/JsonProtocol.scala | 2 +- .../spark/status/AppStatusListenerSuite.scala | 54 ++++++++++--------- .../execution/ui/SQLAppStatusListener.scala | 2 +- 9 files changed, 51 insertions(+), 41 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c2498d4808e91..199937b8c27af 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -815,7 +815,8 @@ class DAGScheduler( private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) { // Note that there is a chance that this task is launched after the stage is cancelled. // In that case, we wouldn't have the stage anymore in stageIdToStage. - val stageAttemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1) + val stageAttemptId = + stageIdToStage.get(task.stageId).map(_.latestInfo.attemptNumber).getOrElse(-1) listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo)) } @@ -1050,7 +1051,7 @@ class DAGScheduler( val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) stage.pendingPartitions += id - new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, + new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } @@ -1060,7 +1061,7 @@ class DAGScheduler( val p: Int = stage.partitions(id) val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) - new ResultTask(stage.id, stage.latestInfo.attemptId, + new ResultTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, id, properties, serializedTaskMetrics, Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } @@ -1076,7 +1077,7 @@ class DAGScheduler( logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " + s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})") taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) + tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties)) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark // the stage as completed here in case there are no tasks to run @@ -1245,7 +1246,7 @@ class DAGScheduler( val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) - if (stageIdToStage(task.stageId).latestInfo.attemptId == task.stageAttemptId) { + if (stageIdToStage(task.stageId).latestInfo.attemptNumber == task.stageAttemptId) { // This task was for the currently running attempt of the stage. Since the task // completed successfully from the perspective of the TaskSetManager, mark it as // no longer pending (the TaskSetManager may consider the task complete even @@ -1324,10 +1325,10 @@ class DAGScheduler( val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) - if (failedStage.latestInfo.attemptId != task.stageAttemptId) { + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + - s"(attempt ID ${failedStage.latestInfo.attemptId}) running") + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") } else { // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index c513ed36d1680..903e25b7986f2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.RDDInfo @DeveloperApi class StageInfo( val stageId: Int, - val attemptId: Int, + @deprecated("Use attemptNumber instead", "2.3.0") val attemptId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo], @@ -56,6 +56,8 @@ class StageInfo( completionTime = Some(System.currentTimeMillis) } + def attemptNumber(): Int = attemptId + private[spark] def getStatusString: String = { if (completionTime.isDefined) { if (failureReason.isDefined) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala index 3c8cab7504c17..3c7af4f6146fa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala @@ -79,7 +79,7 @@ class StatsReportListener extends SparkListener with Logging { x => info.completionTime.getOrElse(System.currentTimeMillis()) - x ).getOrElse("-") - s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " + + s"Stage(${info.stageId}, ${info.attemptNumber}); Name: '${info.name}'; " + s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " + s"Took: $timeTaken msec" } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 487a782e865e8..88b75ddd5993a 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -529,7 +529,8 @@ private[spark] class AppStatusListener( } override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { - val maybeStage = Option(liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId))) + val maybeStage = + Option(liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptNumber))) maybeStage.foreach { stage => val now = System.nanoTime() stage.info = event.stageInfo @@ -785,7 +786,7 @@ private[spark] class AppStatusListener( } private def getOrCreateStage(info: StageInfo): LiveStage = { - val stage = liveStages.computeIfAbsent((info.stageId, info.attemptId), + val stage = liveStages.computeIfAbsent((info.stageId, info.attemptNumber), new Function[(Int, Int), LiveStage]() { override def apply(key: (Int, Int)): LiveStage = new LiveStage() }) @@ -912,7 +913,7 @@ private[spark] class AppStatusListener( private def cleanupTasks(stage: LiveStage): Unit = { val countToDelete = calculateNumberToRemove(stage.savedTasks.get(), maxTasksPerStage).toInt if (countToDelete > 0) { - val stageKey = Array(stage.info.stageId, stage.info.attemptId) + val stageKey = Array(stage.info.stageId, stage.info.attemptNumber) val view = kvstore.view(classOf[TaskDataWrapper]).index("stage").first(stageKey) .last(stageKey) diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 52e83f250d34e..305c2fafa6aac 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -412,14 +412,14 @@ private class LiveStage extends LiveEntity { def executorSummary(executorId: String): LiveExecutorStageSummary = { executorSummaries.getOrElseUpdate(executorId, - new LiveExecutorStageSummary(info.stageId, info.attemptId, executorId)) + new LiveExecutorStageSummary(info.stageId, info.attemptNumber, executorId)) } def toApi(): v1.StageData = { new v1.StageData( status, info.stageId, - info.attemptId, + info.attemptNumber, info.numTasks, activeTasks, diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 827a8637b9bd2..948858224d724 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -116,7 +116,7 @@ private[spark] object RDDOperationGraph extends Logging { // Use a special prefix here to differentiate this cluster from other operation clusters val stageClusterId = STAGE_CLUSTER_PREFIX + stage.stageId val stageClusterName = s"Stage ${stage.stageId}" + - { if (stage.attemptId == 0) "" else s" (attempt ${stage.attemptId})" } + { if (stage.attemptNumber == 0) "" else s" (attempt ${stage.attemptNumber})" } val rootCluster = new RDDOperationCluster(stageClusterId, stageClusterName) var rootNodeCount = 0 diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 5e60218c5740b..ff83301d631c4 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -263,7 +263,7 @@ private[spark] object JsonProtocol { val completionTime = stageInfo.completionTime.map(JInt(_)).getOrElse(JNothing) val failureReason = stageInfo.failureReason.map(JString(_)).getOrElse(JNothing) ("Stage ID" -> stageInfo.stageId) ~ - ("Stage Attempt ID" -> stageInfo.attemptId) ~ + ("Stage Attempt ID" -> stageInfo.attemptNumber) ~ ("Stage Name" -> stageInfo.name) ~ ("Number of Tasks" -> stageInfo.numTasks) ~ ("RDD Info" -> rddInfo) ~ diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 997c7de8dd02b..b8c84e24c2c3f 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -195,7 +195,9 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val s1Tasks = createTasks(4, execIds) s1Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, + stages.head.attemptNumber, + task)) } assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size) @@ -213,10 +215,11 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { check[TaskDataWrapper](task.taskId) { wrapper => assert(wrapper.info.taskId === task.taskId) assert(wrapper.stageId === stages.head.stageId) - assert(wrapper.stageAttemptId === stages.head.attemptId) - assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptId))) + assert(wrapper.stageAttemptId === stages.head.attemptNumber) + assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptNumber))) - val runtime = Array[AnyRef](stages.head.stageId: JInteger, stages.head.attemptId: JInteger, + val runtime = Array[AnyRef](stages.head.stageId: JInteger, + stages.head.attemptNumber: JInteger, -1L: JLong) assert(Arrays.equals(wrapper.runtime, runtime)) @@ -237,7 +240,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { Some(1L), None, true, false, None) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( task.executorId, - Seq((task.taskId, stages.head.stageId, stages.head.attemptId, Seq(accum))))) + Seq((task.taskId, stages.head.stageId, stages.head.attemptNumber, Seq(accum))))) } check[StageDataWrapper](key(stages.head)) { stage => @@ -254,12 +257,12 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Fail one of the tasks, re-start it. time += 1 s1Tasks.head.markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", TaskResultLost, s1Tasks.head, null)) time += 1 val reattempt = newAttempt(s1Tasks.head, nextTaskId()) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, reattempt)) assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size + 1) @@ -289,7 +292,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val killed = s1Tasks.drop(1).head killed.finishTime = time killed.failed = true - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", TaskKilled("killed"), killed, null)) check[JobDataWrapper](1) { job => @@ -311,13 +314,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 val denied = newAttempt(killed, nextTaskId()) val denyReason = TaskCommitDenied(1, 1, 1) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, denied)) time += 1 denied.finishTime = time denied.failed = true - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", denyReason, denied, null)) check[JobDataWrapper](1) { job => @@ -337,7 +340,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Start a new attempt. val reattempt2 = newAttempt(denied, nextTaskId()) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, reattempt2)) // Succeed all tasks in stage 1. @@ -350,7 +353,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 pending.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", Success, task, s1Metrics)) } @@ -414,13 +417,15 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 val s2Tasks = createTasks(4, execIds) s2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, stages.last.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, + stages.last.attemptNumber, + task)) } time += 1 s2Tasks.foreach { task => task.markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptNumber, "taskType", TaskResultLost, task, null)) } @@ -455,7 +460,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // - Re-submit stage 2, all tasks, and succeed them and the stage. val oldS2 = stages.last - val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptId + 1, oldS2.name, oldS2.numTasks, + val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptNumber + 1, oldS2.name, oldS2.numTasks, oldS2.rddInfos, oldS2.parentIds, oldS2.details, oldS2.taskMetrics) time += 1 @@ -466,14 +471,14 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val newS2Tasks = createTasks(4, execIds) newS2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptNumber, task)) } time += 1 newS2Tasks.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptId, "taskType", Success, - task, null)) + listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptNumber, "taskType", + Success, task, null)) } time += 1 @@ -522,14 +527,15 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val j2s2Tasks = createTasks(4, execIds) j2s2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, j2Stages.last.attemptId, + listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, + j2Stages.last.attemptNumber, task)) } time += 1 j2s2Tasks.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptNumber, "taskType", Success, task, null)) } @@ -919,13 +925,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 val tasks = createTasks(2, Array("1")) tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) // Start a 3rd task. The finished tasks should be deleted. createTasks(1, Array("1")).foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) intercept[NoSuchElementException] { @@ -934,7 +940,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Start a 4th task. The first task should be deleted, even if it's still running. createTasks(1, Array("1")).foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) intercept[NoSuchElementException] { @@ -960,7 +966,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } - private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId) + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptNumber) private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = { val value = store.read(classTag[T].runtimeClass, key).asInstanceOf[T] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index d8adbe7bee13e..73a105266e1c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -99,7 +99,7 @@ class SQLAppStatusListener( // Reset the metrics tracking object for the new attempt. Option(stageMetrics.get(event.stageInfo.stageId)).foreach { metrics => metrics.taskMetrics.clear() - metrics.attemptId = event.stageInfo.attemptId + metrics.attemptId = event.stageInfo.attemptNumber } } From eed82a0b211352215316ec70dc48aefc013ad0b2 Mon Sep 17 00:00:00 2001 From: foxish Date: Mon, 8 Jan 2018 13:01:45 -0800 Subject: [PATCH 041/774] [SPARK-22992][K8S] Remove assumption of the DNS domain ## What changes were proposed in this pull request? Remove the use of FQDN to access the driver because it assumes that it's set up in a DNS zone - `cluster.local` which is common but not ubiquitous Note that we already access the in-cluster API server through `kubernetes.default.svc`, so, by extension, this should work as well. The alternative is to introduce DNS zones for both of those addresses. ## How was this patch tested? Unit tests cc vanzin liyinan926 mridulm mccheah Author: foxish Closes #20187 from foxish/cluster.local. --- .../deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala | 2 +- .../k8s/submit/steps/DriverServiceBootstrapStepSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala index eb594e4f16ec0..34af7cde6c1a9 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala @@ -83,7 +83,7 @@ private[spark] class DriverServiceBootstrapStep( .build() val namespace = sparkConf.get(KUBERNETES_NAMESPACE) - val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc.cluster.local" + val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc" val resolvedSparkConf = driverSpec.driverSparkConf.clone() .set(DRIVER_HOST_KEY, driverHostname) .set("spark.driver.port", driverPort.toString) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala index 006ce2668f8a0..78c8c3ba1afbd 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala @@ -85,7 +85,7 @@ class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX - val expectedHostName = s"$expectedServiceName.my-namespace.svc.cluster.local" + val expectedHostName = s"$expectedServiceName.my-namespace.svc" verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) } @@ -120,7 +120,7 @@ class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] val expectedServiceName = s"spark-10000${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}" assert(driverService.getMetadata.getName === expectedServiceName) - val expectedHostName = s"$expectedServiceName.my-namespace.svc.cluster.local" + val expectedHostName = s"$expectedServiceName.my-namespace.svc" verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) } From 4f7e75883436069c2d9028c4cd5daa78e8d59560 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 8 Jan 2018 13:24:08 -0800 Subject: [PATCH 042/774] [SPARK-22912] v2 data source support in MicroBatchExecution ## What changes were proposed in this pull request? Support for v2 data sources in microbatch streaming. ## How was this patch tested? A very basic new unit test on the toy v2 implementation of rate source. Once we have a v1 source fully migrated to v2, we'll need to do more detailed compatibility testing. Author: Jose Torres Closes #20097 from jose-torres/v2-impl. --- ...pache.spark.sql.sources.DataSourceRegister | 1 + .../datasources/v2/DataSourceV2Relation.scala | 10 ++ .../streaming/MicroBatchExecution.scala | 112 ++++++++++++++---- .../streaming/ProgressReporter.scala | 6 +- .../streaming/RateSourceProvider.scala | 10 +- .../execution/streaming/StreamExecution.scala | 4 +- .../streaming/StreamingRelation.scala | 4 +- .../continuous/ContinuousExecution.scala | 4 +- .../ContinuousRateStreamSource.scala | 17 +-- .../sources/RateStreamSourceV2.scala | 31 ++++- .../sql/streaming/DataStreamReader.scala | 25 +++- .../sql/streaming/StreamingQueryManager.scala | 24 ++-- .../streaming/RateSourceV2Suite.scala | 68 +++++++++-- .../spark/sql/streaming/StreamTest.scala | 2 +- .../continuous/ContinuousSuite.scala | 2 +- 15 files changed, 241 insertions(+), 79 deletions(-) diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 6cdfe2fae5642..0259c774bbf4a 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -7,3 +7,4 @@ org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider org.apache.spark.sql.execution.streaming.TextSocketSourceProvider org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 7eb99a645001a..cba20dd902007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -35,6 +35,16 @@ case class DataSourceV2Relation( } } +/** + * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical + * to the non-streaming relation. + */ +class StreamingDataSourceV2Relation( + fullOutput: Seq[AttributeReference], + reader: DataSourceV2Reader) extends DataSourceV2Relation(fullOutput, reader) { + override def isStreaming: Boolean = true +} + object DataSourceV2Relation { def apply(reader: DataSourceV2Reader): DataSourceV2Relation = { new DataSourceV2Relation(reader.readSchema().toAttributes, reader) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 9a7a13fcc5806..42240eeb58d4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.streaming +import java.util.Optional + +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} @@ -24,7 +27,10 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -33,10 +39,11 @@ class MicroBatchExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: Sink, + sink: BaseStreamingSink, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, + extraOptions: Map[String, String], deleteCheckpointOnStop: Boolean) extends StreamExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, @@ -57,6 +64,13 @@ class MicroBatchExecution( var nextSourceId = 0L val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() val v2ToExecutionRelationMap = MutableMap[StreamingRelationV2, StreamingExecutionRelation]() + // We transform each distinct streaming relation into a StreamingExecutionRelation, keeping a + // map as we go to ensure each identical relation gets the same StreamingExecutionRelation + // object. For each microbatch, the StreamingExecutionRelation will be replaced with a logical + // plan for the data within that batch. + // Note that we have to use the previous `output` as attributes in StreamingExecutionRelation, + // since the existing logical plan has already used those attributes. The per-microbatch + // transformation is responsible for replacing attributes with their final values. val _logicalPlan = analyzedPlan.transform { case streamingRelation@StreamingRelation(dataSource, _, output) => toExecutionRelationMap.getOrElseUpdate(streamingRelation, { @@ -64,19 +78,26 @@ class MicroBatchExecution( val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" val source = dataSource.createSource(metadataPath) nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. StreamingExecutionRelation(source, output)(sparkSession) }) - case s @ StreamingRelationV2(v2DataSource, _, _, output, v1DataSource) - if !v2DataSource.isInstanceOf[MicroBatchReadSupport] => + case s @ StreamingRelationV2(source: MicroBatchReadSupport, _, options, output, _) => + v2ToExecutionRelationMap.getOrElseUpdate(s, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + val reader = source.createMicroBatchReader( + Optional.empty(), // user specified schema + metadataPath, + new DataSourceV2Options(options.asJava)) + nextSourceId += 1 + StreamingExecutionRelation(reader, output)(sparkSession) + }) + case s @ StreamingRelationV2(_, _, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val source = v1DataSource.createSource(metadataPath) + assert(v1Relation.isDefined, "v2 execution didn't match but v1 was unavailable") + val source = v1Relation.get.dataSource.createSource(metadataPath) nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. StreamingExecutionRelation(source, output)(sparkSession) }) } @@ -192,7 +213,8 @@ class MicroBatchExecution( source.getBatch(start, end) } case nonV1Tuple => - throw new IllegalStateException(s"Unexpected V2 source in $nonV1Tuple") + // The V2 API does not have the same edge case requiring getBatch to be called + // here, so we do nothing here. } currentBatchId = latestCommittedBatchId + 1 committedOffsets ++= availableOffsets @@ -236,14 +258,27 @@ class MicroBatchExecution( val hasNewData = { awaitProgressLock.lock() try { - val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { + // Generate a map from each unique source to the next available offset. + val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { case s: Source => updateStatusMessage(s"Getting offsets from $s") reportTimeTaken("getOffset") { (s, s.getOffset) } + case s: MicroBatchReader => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("getOffset") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. + s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) + + (s, Some(s.getEndOffset)) + } }.toMap - availableOffsets ++= latestOffsets.filter { case (s, o) => o.nonEmpty }.mapValues(_.get) + availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) if (dataAvailable) { true @@ -317,6 +352,8 @@ class MicroBatchExecution( if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { case (src: Source, off) => src.commit(off) + case (reader: MicroBatchReader, off) => + reader.commit(reader.deserializeOffset(off.json)) } } else { throw new IllegalStateException(s"batch $currentBatchId doesn't exist") @@ -357,7 +394,16 @@ class MicroBatchExecution( s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" + s"${batch.queryExecution.logical}") logDebug(s"Retrieving data from $source: $current -> $available") - Some(source -> batch) + Some(source -> batch.logicalPlan) + case (reader: MicroBatchReader, available) + if committedOffsets.get(reader).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) + reader.setOffsetRange( + toJava(current), + Optional.of(available.asInstanceOf[OffsetV2])) + logDebug(s"Retrieving data from $reader: $current -> $available") + Some(reader -> + new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) case _ => None } } @@ -365,15 +411,14 @@ class MicroBatchExecution( // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Replace sources in the logical plan with data that has arrived since the last batch. - val withNewSources = logicalPlan transform { + val newBatchesPlan = logicalPlan transform { case StreamingExecutionRelation(source, output) => - newData.get(source).map { data => - val newPlan = data.logicalPlan - assert(output.size == newPlan.output.size, + newData.get(source).map { dataPlan => + assert(output.size == dataPlan.output.size, s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + - s"${Utils.truncatedString(newPlan.output, ",")}") - replacements ++= output.zip(newPlan.output) - newPlan + s"${Utils.truncatedString(dataPlan.output, ",")}") + replacements ++= output.zip(dataPlan.output) + dataPlan }.getOrElse { LocalRelation(output, isStreaming = true) } @@ -381,7 +426,7 @@ class MicroBatchExecution( // Rewire the plan to use the new attributes that were returned by the source. val replacementMap = AttributeMap(replacements) - val triggerLogicalPlan = withNewSources transformAllExpressions { + val newAttributePlan = newBatchesPlan transformAllExpressions { case a: Attribute if replacementMap.contains(a) => replacementMap(a).withMetadata(a.metadata) case ct: CurrentTimestamp => @@ -392,6 +437,20 @@ class MicroBatchExecution( cd.dataType, cd.timeZoneId) } + val triggerLogicalPlan = sink match { + case _: Sink => newAttributePlan + case s: MicroBatchWriteSupport => + val writer = s.createMicroBatchWriter( + s"$runId", + currentBatchId, + newAttributePlan.schema, + outputMode, + new DataSourceV2Options(extraOptions.asJava)) + assert(writer.isPresent, "microbatch writer must always be present") + WriteToDataSourceV2(writer.get, newAttributePlan) + case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") + } + reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionToRunBatch, @@ -409,7 +468,12 @@ class MicroBatchExecution( reportTimeTaken("addBatch") { SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { - sink.addBatch(currentBatchId, nextBatch) + sink match { + case s: Sink => s.addBatch(currentBatchId, nextBatch) + case s: MicroBatchWriteSupport => + // This doesn't accumulate any data - it just forces execution of the microbatch writer. + nextBatch.collect() + } } } @@ -421,4 +485,8 @@ class MicroBatchExecution( awaitProgressLock.unlock() } } + + private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { + Optional.ofNullable(scalaOption.orNull) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 1c9043613cb69..d1e5be9c12762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -53,7 +53,7 @@ trait ProgressReporter extends Logging { protected def triggerClock: Clock protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution - protected def newData: Map[BaseStreamingSource, DataFrame] + protected def newData: Map[BaseStreamingSource, LogicalPlan] protected def availableOffsets: StreamProgress protected def committedOffsets: StreamProgress protected def sources: Seq[BaseStreamingSource] @@ -225,8 +225,8 @@ trait ProgressReporter extends Logging { // // 3. For each source, we sum the metrics of the associated execution plan leaves. // - val logicalPlanLeafToSource = newData.flatMap { case (source, df) => - df.logicalPlan.collectLeaves().map { leaf => leaf -> source } + val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => + logicalPlan.collectLeaves().map { leaf => leaf -> source } } val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index d02cf882b61ac..66eb0169ac1ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -29,12 +29,12 @@ import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.streaming.continuous.ContinuousRateStreamReader -import org.apache.spark.sql.execution.streaming.sources.RateStreamV2Reader +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.execution.streaming.sources.RateStreamMicroBatchReader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport -import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader} import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} @@ -112,7 +112,7 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister schema: Optional[StructType], checkpointLocation: String, options: DataSourceV2Options): ContinuousReader = { - new ContinuousRateStreamReader(options) + new RateStreamContinuousReader(options) } override def shortName(): String = "rate" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 3e76bf7b7ca8f..24a8b000df0c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -163,7 +163,7 @@ abstract class StreamExecution( var lastExecution: IncrementalExecution = _ /** Holds the most recent input data for each source. */ - protected var newData: Map[BaseStreamingSource, DataFrame] = _ + protected var newData: Map[BaseStreamingSource, LogicalPlan] = _ @volatile protected var streamDeathCause: StreamingQueryException = null @@ -418,7 +418,7 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(source: Source, newOffset: Offset): Unit = { + private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index a9d50e3a112e7..a0ee683a895d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -61,7 +61,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. */ case class StreamingExecutionRelation( - source: Source, + source: BaseStreamingSource, output: Seq[Attribute])(session: SparkSession) extends LeafNode { @@ -92,7 +92,7 @@ case class StreamingRelationV2( sourceName: String, extraOptions: Map[String, String], output: Seq[Attribute], - v1DataSource: DataSource)(session: SparkSession) + v1Relation: Option[StreamingRelation])(session: SparkSession) extends LeafNode { override def isStreaming: Boolean = true override def toString: String = sourceName diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 2843ab13bde2b..9657b5e26d770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.DataSourceV2Options import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} @@ -174,7 +174,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setOffset(java.util.Optional.ofNullable(realOffset.orNull)) - DataSourceV2Relation(newOutput, reader) + new StreamingDataSourceV2Relation(newOutput, reader) } // Rewire the plan to use the new attributes that were returned by the source. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index c9aa78a5a2e28..b4b21e7d2052f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -32,10 +32,10 @@ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -case class ContinuousRateStreamPartitionOffset( +case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class ContinuousRateStreamReader(options: DataSourceV2Options) +class RateStreamContinuousReader(options: DataSourceV2Options) extends ContinuousReader { implicit val defaultFormats: DefaultFormats = DefaultFormats @@ -48,7 +48,7 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { assert(offsets.length == numPartitions) val tuples = offsets.map { - case ContinuousRateStreamPartitionOffset(i, currVal, nextRead) => + case RateStreamPartitionOffset(i, currVal, nextRead) => (i, ValueRunTimeMsPair(currVal, nextRead)) } RateStreamOffset(Map(tuples: _*)) @@ -86,7 +86,7 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) val start = partitionStartMap(i) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. - RateStreamReadTask( + RateStreamContinuousReadTask( start.value, start.runTimeMs, i, @@ -101,7 +101,7 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) } -case class RateStreamReadTask( +case class RateStreamContinuousReadTask( startValue: Long, startTimeMs: Long, partitionIndex: Int, @@ -109,10 +109,11 @@ case class RateStreamReadTask( rowsPerSecond: Double) extends ReadTask[Row] { override def createDataReader(): DataReader[Row] = - new RateStreamDataReader(startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) + new RateStreamContinuousDataReader( + startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) } -class RateStreamDataReader( +class RateStreamContinuousDataReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, @@ -151,5 +152,5 @@ class RateStreamDataReader( override def close(): Unit = {} override def getOffset(): PartitionOffset = - ContinuousRateStreamPartitionOffset(partitionIndex, currentValue, nextReadTime) + RateStreamPartitionOffset(partitionIndex, currentValue, nextReadTime) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 97bada08bcd2b..c0ed12cec25ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -28,17 +28,38 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -import org.apache.spark.util.SystemClock +import org.apache.spark.util.{ManualClock, SystemClock} -class RateStreamV2Reader(options: DataSourceV2Options) +/** + * This is a temporary register as we build out v2 migration. Microbatch read support should + * be implemented in the same register as v1. + */ +class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): MicroBatchReader = { + new RateStreamMicroBatchReader(options) + } + + override def shortName(): String = "ratev2" +} + +class RateStreamMicroBatchReader(options: DataSourceV2Options) extends MicroBatchReader { implicit val defaultFormats: DefaultFormats = DefaultFormats - val clock = new SystemClock + val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock + else new SystemClock + } private val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt @@ -111,7 +132,7 @@ class RateStreamV2Reader(options: DataSourceV2Options) val packedRows = mutable.ListBuffer[(Long, Long)]() var outVal = startVal + numPartitions - var outTimeMs = startTimeMs + msPerPartitionBetweenRows + var outTimeMs = startTimeMs while (outVal <= endVal) { packedRows.append((outTimeMs, outVal)) outVal += numPartitions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 2e92beecf2c17..52f2e2639cd86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.Locale +import java.util.{Locale, Optional} import scala.collection.JavaConverters._ @@ -27,8 +27,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} +import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -166,19 +167,31 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo userSpecifiedSchema = userSpecifiedSchema, className = source, options = extraOptions.toMap) + val v1Relation = ds match { + case _: StreamSourceProvider => Some(StreamingRelation(v1DataSource)) + case _ => None + } ds match { + case s: MicroBatchReadSupport => + val tempReader = s.createMicroBatchReader( + Optional.ofNullable(userSpecifiedSchema.orNull), + Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) + Dataset.ofRows( + sparkSession, + StreamingRelationV2( + s, source, extraOptions.toMap, + tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) case s: ContinuousReadSupport => val tempReader = s.createContinuousReader( - java.util.Optional.ofNullable(userSpecifiedSchema.orNull), + Optional.ofNullable(userSpecifiedSchema.orNull), Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, options) - // Generate the V1 node to catch errors thrown within generation. - StreamingRelation(v1DataSource) Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - tempReader.readSchema().toAttributes, v1DataSource)(sparkSession)) + tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) case _ => // Code path for data source v1. Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index b508f4406138f..4b27e0d4ef47b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -29,10 +29,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -240,31 +240,35 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo "is not supported in streaming DataFrames/Datasets and will be disabled.") } - sink match { - case v1Sink: Sink => - new StreamingQueryWrapper(new MicroBatchExecution( + (sink, trigger) match { + case (v2Sink: ContinuousWriteSupport, trigger: ContinuousTrigger) => + UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + new StreamingQueryWrapper(new ContinuousExecution( sparkSession, userSpecifiedName.orNull, checkpointLocation, analyzedPlan, - v1Sink, + v2Sink, trigger, triggerClock, outputMode, + extraOptions, deleteCheckpointOnStop)) - case v2Sink: ContinuousWriteSupport => - UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) - new StreamingQueryWrapper(new ContinuousExecution( + case (_: MicroBatchWriteSupport, _) | (_: Sink, _) => + new StreamingQueryWrapper(new MicroBatchExecution( sparkSession, userSpecifiedName.orNull, checkpointLocation, analyzedPlan, - v2Sink, + sink, trigger, triggerClock, outputMode, extraOptions, deleteCheckpointOnStop)) + case (_: ContinuousWriteSupport, t) if !t.isInstanceOf[ContinuousTrigger] => + throw new AnalysisException( + "Sink only supports continuous writes, but a continuous trigger was not specified.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index e11705a227f48..85085d43061bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -18,20 +18,64 @@ package org.apache.spark.sql.execution.streaming import java.util.Optional +import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamSourceV2, RateStreamV2Reader} +import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.util.ManualClock class RateSourceV2Suite extends StreamTest { + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + rateSource.setOffsetRange(Optional.empty(), Optional.empty()) + (rateSource, rateSource.getEndOffset()) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceV2Options.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find v2 read support for rate") + } + } + + test("basic microbatch execution") { + val input = spark.readStream + .format("rateV2") + .option("numPartitions", "1") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input, useV2Sink = true)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + test("microbatch - numPartitions propagated") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) reader.setOffsetRange(Optional.empty(), Optional.empty()) val tasks = reader.createReadTasks() @@ -39,7 +83,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - set offset") { - val reader = new RateStreamV2Reader(DataSourceV2Options.empty()) + val reader = new RateStreamMicroBatchReader(DataSourceV2Options.empty()) val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) @@ -48,7 +92,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - infer offsets") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) reader.setOffsetRange(Optional.empty(), Optional.empty()) @@ -69,7 +113,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - predetermined batch size") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) @@ -80,7 +124,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - data read") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { @@ -107,14 +151,14 @@ class RateSourceV2Suite extends StreamTest { DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { case ds: ContinuousReadSupport => val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceV2Options.empty()) - assert(reader.isInstanceOf[ContinuousRateStreamReader]) + assert(reader.isInstanceOf[RateStreamContinuousReader]) case _ => throw new IllegalStateException("Could not find v2 read support for rate") } } test("continuous data") { - val reader = new ContinuousRateStreamReader( + val reader = new RateStreamContinuousReader( new DataSourceV2Options(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setOffset(Optional.empty()) val tasks = reader.createReadTasks() @@ -122,17 +166,17 @@ class RateSourceV2Suite extends StreamTest { val data = scala.collection.mutable.ListBuffer[Row]() tasks.asScala.foreach { - case t: RateStreamReadTask => + case t: RateStreamContinuousReadTask => val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamDataReader] + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) assert(r.getOffset() == - ContinuousRateStreamPartitionOffset( + RateStreamPartitionOffset( t.partitionIndex, t.partitionIndex + rowIndex * 2, startTimeMs + (rowIndex + 1) * 100)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4b7f0fbe97d4e..d46461fa9bf6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -105,7 +105,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be * the active query, and then return the source object the data was added, as well as the * offset of added data. */ - def addData(query: Option[StreamExecution]): (Source, Offset) + def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) } /** A trait that can be extended when testing a source. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index eda0d8ad48313..9562c10feafe9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -61,7 +61,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, r: ContinuousRateStreamReader) => r + case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 68ce792b5857f0291154f524ac651036db868bb9 Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Tue, 9 Jan 2018 10:15:01 +0800 Subject: [PATCH 043/774] [SPARK-22972] Couldn't find corresponding Hive SerDe for data source provider org.apache.spark.sql.hive.orc ## What changes were proposed in this pull request? Fix the warning: Couldn't find corresponding Hive SerDe for data source provider org.apache.spark.sql.hive.orc. ## How was this patch tested? test("SPARK-22972: hive orc source") assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.hive.orc") .equals(HiveSerDe.sourceToSerDe("orc"))) Author: xubo245 <601450868@qq.com> Closes #20165 from xubo245/HiveSerDe. --- .../apache/spark/sql/internal/HiveSerDe.scala | 1 + .../sql/hive/orc/HiveOrcSourceSuite.scala | 29 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index b9515ec7bca2a..dac463641cfab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -73,6 +73,7 @@ object HiveSerDe { val key = source.toLowerCase(Locale.ROOT) match { case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" case s if s.startsWith("org.apache.spark.sql.orc") => "orc" + case s if s.startsWith("org.apache.spark.sql.hive.orc") => "orc" case s if s.equals("orcfile") => "orc" case s if s.equals("parquetfile") => "parquet" case s if s.equals("avrofile") => "avro" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index 17b7d8cfe127e..d556a030e2186 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.hive.orc import java.io.File import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.orc.OrcSuite import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.util.Utils class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { @@ -62,6 +64,33 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { """.stripMargin) } + test("SPARK-22972: hive orc source") { + val tableName = "normal_orc_as_source_hive" + withTable(tableName) { + sql( + s""" + |CREATE TABLE $tableName + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' + |) + """.stripMargin) + + val tableMetadata = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(tableName)) + assert(tableMetadata.storage.inputFormat == + Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(tableMetadata.storage.outputFormat == + Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(tableMetadata.storage.serde == + Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.hive.orc") + .equals(HiveSerDe.sourceToSerDe("orc"))) + assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.orc") + .equals(HiveSerDe.sourceToSerDe("orc"))) + } + } + test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { val location = Utils.createTempDir() val uri = location.toURI From 849043ce1d28a976659278d29368da0799329db8 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Tue, 9 Jan 2018 10:44:21 +0800 Subject: [PATCH 044/774] [SPARK-22990][CORE] Fix method isFairScheduler in JobsTab and StagesTab ## What changes were proposed in this pull request? In current implementation, the function `isFairScheduler` is always false, since it is comparing String with `SchedulingMode` Author: Wang Gengliang Closes #20186 from gengliangwang/isFairScheduler. --- .../src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala | 8 ++++---- .../main/scala/org/apache/spark/ui/jobs/StagesTab.scala | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 99eab1b2a27d8..ff1b75e5c5065 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -34,10 +34,10 @@ private[ui] class JobsTab(parent: SparkUI, store: AppStatusStore) val killEnabled = parent.killEnabled def isFairScheduler: Boolean = { - store.environmentInfo().sparkProperties.toMap - .get("spark.scheduler.mode") - .map { mode => mode == SchedulingMode.FAIR } - .getOrElse(false) + store + .environmentInfo() + .sparkProperties + .contains(("spark.scheduler.mode", SchedulingMode.FAIR.toString)) } def getSparkUser: String = parent.getSparkUser diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index be05a963f0e68..10b032084ce4f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -37,10 +37,10 @@ private[ui] class StagesTab(val parent: SparkUI, val store: AppStatusStore) attachPage(new PoolPage(this)) def isFairScheduler: Boolean = { - store.environmentInfo().sparkProperties.toMap - .get("spark.scheduler.mode") - .map { mode => mode == SchedulingMode.FAIR } - .getOrElse(false) + store + .environmentInfo() + .sparkProperties + .contains(("spark.scheduler.mode", SchedulingMode.FAIR.toString)) } def handleKillRequest(request: HttpServletRequest): Unit = { From f20131dd35939734fe16b0005a086aa72400893b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 9 Jan 2018 11:49:10 +0800 Subject: [PATCH 045/774] [SPARK-22984] Fix incorrect bitmap copying and offset adjustment in GenerateUnsafeRowJoiner ## What changes were proposed in this pull request? This PR fixes a longstanding correctness bug in `GenerateUnsafeRowJoiner`. This class was introduced in https://github.com/apache/spark/pull/7821 (July 2015 / Spark 1.5.0+) and is used to combine pairs of UnsafeRows in TungstenAggregationIterator, CartesianProductExec, and AppendColumns. ### Bugs fixed by this patch 1. **Incorrect combining of null-tracking bitmaps**: when concatenating two UnsafeRows, the implementation "Concatenate the two bitsets together into a single one, taking padding into account". If one row has no columns then it has a bitset size of 0, but the code was incorrectly assuming that if the left row had a non-zero number of fields then the right row would also have at least one field, so it was copying invalid bytes and and treating them as part of the bitset. I'm not sure whether this bug was also present in the original implementation or whether it was introduced in https://github.com/apache/spark/pull/7892 (which fixed another bug in this code). 2. **Incorrect updating of data offsets for null variable-length fields**: after updating the bitsets and copying fixed-length and variable-length data, we need to perform adjustments to the offsets pointing the start of variable length fields's data. The existing code was _conditionally_ adding a fixed offset to correct for the new length of the combined row, but it is unsafe to do this if the variable-length field has a null value: we always represent nulls by storing `0` in the fixed-length slot, but this code was incorrectly incrementing those values. This bug was present since the original version of `GenerateUnsafeRowJoiner`. ### Why this bug remained latent for so long The PR which introduced `GenerateUnsafeRowJoiner` features several randomized tests, including tests of the cases where one side of the join has no fields and where string-valued fields are null. However, the existing assertions were too weak to uncover this bug: - If a null field has a non-zero value in its fixed-length data slot then this will not cause problems for field accesses because the null-tracking bitmap should still be correct and we will not try to use the incorrect offset for anything. - If the null tracking bitmap is corrupted by joining against a row with no fields then the corruption occurs in field numbers past the actual field numbers contained in the row. Thus valid `isNullAt()` calls will not read the incorrectly-set bits. The existing `GenerateUnsafeRowJoinerSuite` tests only exercised `.get()` and `isNullAt()`, but didn't actually check the UnsafeRows for bit-for-bit equality, preventing these bugs from failing assertions. It turns out that there was even a [GenerateUnsafeRowJoinerBitsetSuite](https://github.com/apache/spark/blob/03377d2522776267a07b7d6ae9bddf79a4e0f516/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala) but it looks like it also didn't catch this problem because it only tested the bitsets in an end-to-end fashion by accessing them through the `UnsafeRow` interface instead of actually comparing the bitsets' bytes. ### Impact of these bugs - This bug will cause `equals()` and `hashCode()` to be incorrect for these rows, which will be problematic in case`GenerateUnsafeRowJoiner`'s results are used as join or grouping keys. - Chained / repeated invocations of `GenerateUnsafeRowJoiner` may result in reads from invalid null bitmap positions causing fields to incorrectly become NULL (see the end-to-end example below). - It looks like this generally only happens in `CartesianProductExec`, which our query optimizer often avoids executing (usually we try to plan a `BroadcastNestedLoopJoin` instead). ### End-to-end test case demonstrating the problem The following query demonstrates how this bug may result in incorrect query results: ```sql set spark.sql.autoBroadcastJoinThreshold=-1; -- Needed to trigger CartesianProductExec create table a as select * from values 1; create table b as select * from values 2; SELECT t3.col1, t1.col1 FROM a t1 CROSS JOIN b t2 CROSS JOIN b t3 ``` This should return `(2, 1)` but instead was returning `(null, 1)`. Column pruning ends up trimming off all columns from `t2`, so when `t2` joins with another table this triggers the bitmap-copying bug. This incorrect bitmap is subsequently copied again when performing the final join, causing the final output to have an incorrectly-set null bit for the first field. ## How was this patch tested? Strengthened the assertions in existing tests in GenerateUnsafeRowJoinerSuite. Also verified that the end-to-end test case which uncovered this now passes. Author: Josh Rosen Closes #20181 from JoshRosen/SPARK-22984-fix-generate-unsaferow-joiner-bitmap-bugs. --- .../codegen/GenerateUnsafeRowJoiner.scala | 52 +++++++++- .../GenerateUnsafeRowJoinerSuite.scala | 95 ++++++++++++++++++- 2 files changed, 138 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index be5f5a73b5d47..febf7b0c96c2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -70,7 +70,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // --------------------- copy bitset from row 1 and row 2 --------------------------- // val copyBitset = Seq.tabulate(outputBitsetWords) { i => - val bits = if (bitset1Remainder > 0) { + val bits = if (bitset1Remainder > 0 && bitset2Words != 0) { if (i < bitset1Words - 1) { s"$getLong(obj1, offset1 + ${i * 8})" } else if (i == bitset1Words - 1) { @@ -152,7 +152,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } else { // Number of bytes to increase for the offset. Note that since in UnsafeRow we store the // offset in the upper 32 bit of the words, we can just shift the offset to the left by - // 32 and increment that amount in place. + // 32 and increment that amount in place. However, we need to handle the important special + // case of a null field, in which case the offset should be zero and should not have a + // shift added to it. val shift = if (i < schema1.size) { s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L" @@ -160,14 +162,55 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)" } val cursor = offset + outputBitsetWords * 8 + i * 8 - s"$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));\n" + // UnsafeRow is a little underspecified, so in what follows we'll treat UnsafeRowWriter's + // output as a de-facto specification for the internal layout of data. + // + // Null-valued fields will always have a data offset of 0 because + // UnsafeRowWriter.setNullAt(ordinal) sets the null bit and stores 0 to in field's + // position in the fixed-length section of the row. As a result, we must NOT add + // `shift` to the offset for null fields. + // + // We could perform a null-check here by inspecting the null-tracking bitmap, but doing + // so could be expensive and will add significant bloat to the generated code. Instead, + // we'll rely on the invariant "stored offset == 0 for variable-length data type implies + // that the field's value is null." + // + // To establish that this invariant holds, we'll prove that a non-null field can never + // have a stored offset of 0. There are two cases to consider: + // + // 1. The non-null field's data is of non-zero length: reading this field's value + // must read data from the variable-length section of the row, so the stored offset + // will actually be used in address calculation and must be correct. The offsets + // count bytes from the start of the UnsafeRow so these offsets will always be + // non-zero because the storage of the offsets themselves takes up space at the + // start of the row. + // 2. The non-null field's data is of zero length (i.e. its data is empty). In this + // case, we have to worry about the possibility that an arbitrary offset value was + // stored because we never actually read any bytes using this offset and therefore + // would not crash if it was incorrect. The variable-sized data writing paths in + // UnsafeRowWriter unconditionally calls setOffsetAndSize(ordinal, numBytes) with + // no special handling for the case where `numBytes == 0`. Internally, + // setOffsetAndSize computes the offset without taking the size into account. Thus + // the stored offset is the same non-zero offset that would be used if the field's + // dataSize was non-zero (and in (1) above we've shown that case behaves as we + // expect). + // + // Thus it is safe to perform `existingOffset != 0` checks here in the place of + // more expensive null-bit checks. + s""" + |existingOffset = $getLong(buf, $cursor); + |if (existingOffset != 0) { + | $putLong(buf, $cursor, existingOffset + ($shift << 32)); + |} + """.stripMargin } } val updateOffsets = ctx.splitExpressions( expressions = updateOffset, funcName = "copyBitsetFunc", - arguments = ("long", "numBytesVariableRow1") :: Nil) + arguments = ("long", "numBytesVariableRow1") :: Nil, + makeSplitFunction = (s: String) => "long existingOffset;\n" + s) // ------------------------ Finally, put everything together --------------------------- // val codeBody = s""" @@ -200,6 +243,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | $copyFixedLengthRow2 | $copyVariableLengthRow1 | $copyVariableLengthRow2 + | long existingOffset; | $updateOffsets | | out.pointTo(buf, sizeInBytes); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala index f203f25ad10d4..75c6beeb32150 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala @@ -22,8 +22,10 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Test suite for [[GenerateUnsafeRowJoiner]]. @@ -45,6 +47,32 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { testConcat(64, 64, fixed) } + test("rows with all empty strings") { + val schema = StructType(Seq( + StructField("f1", StringType), StructField("f2", StringType))) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(UTF8String.EMPTY_UTF8, UTF8String.EMPTY_UTF8)) + testConcat(schema, row, schema, row) + } + + test("rows with all empty int arrays") { + val schema = StructType(Seq( + StructField("f1", ArrayType(IntegerType)), StructField("f2", ArrayType(IntegerType)))) + val emptyIntArray = + ExpressionEncoder[Array[Int]]().resolveAndBind().toRow(Array.emptyIntArray).getArray(0) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(emptyIntArray, emptyIntArray)) + testConcat(schema, row, schema, row) + } + + test("alternating empty and non-empty strings") { + val schema = StructType(Seq( + StructField("f1", StringType), StructField("f2", StringType))) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(UTF8String.EMPTY_UTF8, UTF8String.fromString("foo"))) + testConcat(schema, row, schema, row) + } + test("randomized fix width types") { for (i <- 0 until 20) { testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed) @@ -94,27 +122,84 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply() val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow]) val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow]) + testConcat(schema1, row1, schema2, row2) + } + + private def testConcat( + schema1: StructType, + row1: UnsafeRow, + schema2: StructType, + row2: UnsafeRow) { // Run the joiner. val mergedSchema = StructType(schema1 ++ schema2) val concater = GenerateUnsafeRowJoiner.create(schema1, schema2) - val output = concater.join(row1, row2) + val output: UnsafeRow = concater.join(row1, row2) + + // We'll also compare to an UnsafeRow produced with JoinedRow + UnsafeProjection. This ensures + // that unused space in the row (e.g. leftover bits in the null-tracking bitmap) is written + // correctly. + val expectedOutput: UnsafeRow = { + val joinedRowProjection = UnsafeProjection.create(mergedSchema) + val joined = new JoinedRow() + joinedRowProjection.apply(joined.apply(row1, row2)) + } // Test everything equals ... for (i <- mergedSchema.indices) { + val dataType = mergedSchema(i).dataType if (i < schema1.size) { assert(output.isNullAt(i) === row1.isNullAt(i)) if (!output.isNullAt(i)) { - assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType)) + assert(output.get(i, dataType) === row1.get(i, dataType)) + assert(output.get(i, dataType) === expectedOutput.get(i, dataType)) } } else { assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size)) if (!output.isNullAt(i)) { - assert(output.get(i, mergedSchema(i).dataType) === - row2.get(i - schema1.size, mergedSchema(i).dataType)) + assert(output.get(i, dataType) === row2.get(i - schema1.size, dataType)) + assert(output.get(i, dataType) === expectedOutput.get(i, dataType)) } } } + + + assert( + expectedOutput.getSizeInBytes == output.getSizeInBytes, + "output isn't same size in bytes as slow path") + + // Compare the UnsafeRows byte-by-byte so that we can print more useful debug information in + // case this assertion fails: + val actualBytes = output.getBaseObject.asInstanceOf[Array[Byte]] + .take(output.getSizeInBytes) + val expectedBytes = expectedOutput.getBaseObject.asInstanceOf[Array[Byte]] + .take(expectedOutput.getSizeInBytes) + + val bitsetWidth = UnsafeRow.calculateBitSetWidthInBytes(expectedOutput.numFields()) + val actualBitset = actualBytes.take(bitsetWidth) + val expectedBitset = expectedBytes.take(bitsetWidth) + assert(actualBitset === expectedBitset, "bitsets were not equal") + + val fixedLengthSize = expectedOutput.numFields() * 8 + val actualFixedLength = actualBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize) + val expectedFixedLength = expectedBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize) + if (actualFixedLength !== expectedFixedLength) { + actualFixedLength.grouped(8) + .zip(expectedFixedLength.grouped(8)) + .zip(mergedSchema.fields.toIterator) + .foreach { + case ((actual, expected), field) => + assert(actual === expected, s"Fixed length sections are not equal for field $field") + } + fail("Fixed length sections were not equal") + } + + val variableLengthStart = bitsetWidth + fixedLengthSize + val actualVariableLength = actualBytes.drop(variableLengthStart) + val expectedVariableLength = expectedBytes.drop(variableLengthStart) + assert(actualVariableLength === expectedVariableLength, "fixed length sections were not equal") + + assert(output.hashCode() == expectedOutput.hashCode(), "hash codes were not equal") } } From 8486ad419d8f1779e277ec71c39e1516673a83ab Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 8 Jan 2018 21:58:26 -0800 Subject: [PATCH 046/774] [SPARK-21292][DOCS] refreshtable example ## What changes were proposed in this pull request? doc update Author: Felix Cheung Closes #20198 from felixcheung/rrefreshdoc. --- docs/sql-programming-guide.md | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3ccaaf4d5b1fa..72f79d6909ecc 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -915,6 +915,14 @@ spark.catalog.refreshTable("my_table") +
+ +{% highlight r %} +refreshTable("my_table") +{% endhighlight %} + +
+
{% highlight sql %} @@ -1498,10 +1506,10 @@ that these options will be deprecated in future release as more optimizations ar ## Broadcast Hint for SQL Queries The `BROADCAST` hint guides Spark to broadcast each specified table when joining them with another table or view. -When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred, +When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred, even if the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`. When both sides of a join are specified, Spark broadcasts the one having the lower statistics. -Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join) +Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join) support BHJ. When the broadcast nested loop join is selected, we still respect the hint.
@@ -1780,7 +1788,7 @@ options. Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. - + - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. @@ -2167,7 +2175,7 @@ Not all the APIs of the Hive UDF/UDTF/UDAF are supported by Spark SQL. Below are Spark SQL currently does not support the reuse of aggregation. * `getWindowingEvaluator` (`GenericUDAFEvaluator`) is a function to optimize aggregation by evaluating an aggregate over a fixed window. - + ### Incompatible Hive UDF Below are the scenarios in which Hive and Spark generate different results: From 02214b094390e913f52e71d55c9bb8a81c9e7ef9 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 8 Jan 2018 22:08:19 -0800 Subject: [PATCH 047/774] [SPARK-21293][SPARKR][DOCS] structured streaming doc update ## What changes were proposed in this pull request? doc update Author: Felix Cheung Closes #20197 from felixcheung/rwadoc. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 2 +- docs/sparkr.md | 2 +- .../structured-streaming-programming-guide.md | 32 +++++++++++++++++-- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 2e662424b25f2..feca617c2554c 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -1042,7 +1042,7 @@ unlink(modelPath) ## Structured Streaming -SparkR supports the Structured Streaming API (experimental). +SparkR supports the Structured Streaming API. You can check the Structured Streaming Programming Guide for [an introduction](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#programming-model) to its programming model and basic concepts. diff --git a/docs/sparkr.md b/docs/sparkr.md index 997ea60fb6cf0..6685b585a393a 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -596,7 +596,7 @@ The following example shows how to save/load a MLlib model by SparkR. # Structured Streaming -SparkR supports the Structured Streaming API (experimental). Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. For more information see the R API on the [Structured Streaming Programming Guide](structured-streaming-programming-guide.html) +SparkR supports the Structured Streaming API. Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. For more information see the R API on the [Structured Streaming Programming Guide](structured-streaming-programming-guide.html) # R Function Name Conflicts diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 31fcfabb9cacc..de13e281916db 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -827,8 +827,8 @@ df.isStreaming() {% endhighlight %}
-{% highlight bash %} -Not available. +{% highlight r %} +isStreaming(df) {% endhighlight %}
@@ -885,6 +885,19 @@ windowedCounts = words.groupBy( ).count() {% endhighlight %} + +
+{% highlight r %} +words <- ... # streaming DataFrame of schema { timestamp: Timestamp, word: String } + +# Group the data by window and word and compute the count of each group +windowedCounts <- count( + groupBy( + words, + window(words$timestamp, "10 minutes", "5 minutes"), + words$word)) +{% endhighlight %} +
@@ -959,6 +972,21 @@ windowedCounts = words \ .count() {% endhighlight %} + +
+{% highlight r %} +words <- ... # streaming DataFrame of schema { timestamp: Timestamp, word: String } + +# Group the data by window and word and compute the count of each group + +words <- withWatermark(words, "timestamp", "10 minutes") +windowedCounts <- count( + groupBy( + words, + window(words$timestamp, "10 minutes", "5 minutes"), + words$word)) +{% endhighlight %} +
From 0959aa581a399279be3f94214bcdffc6a1b6d60a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 9 Jan 2018 16:31:20 +0800 Subject: [PATCH 048/774] [SPARK-23000] Fix Flaky test suite DataSourceWithHiveMetastoreCatalogSuite in Spark 2.3 ## What changes were proposed in this pull request? https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.3-test-sbt-hadoop-2.6/ The test suite DataSourceWithHiveMetastoreCatalogSuite of Branch 2.3 always failed in hadoop 2.6 The table `t` exists in `default`, but `runSQLHive` reported the table does not exist. Obviously, Hive client's default database is different. The fix is to clean the environment and use `DEFAULT` as the database. ``` org.apache.spark.sql.execution.QueryExecutionException: FAILED: SemanticException [Error 10001]: Line 1:14 Table not found 't' Stacktrace sbt.ForkMain$ForkError: org.apache.spark.sql.execution.QueryExecutionException: FAILED: SemanticException [Error 10001]: Line 1:14 Table not found 't' at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:699) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$withHiveState$1.apply(HiveClientImpl.scala:272) at org.apache.spark.sql.hive.client.HiveClientImpl.liftedTree1$1(HiveClientImpl.scala:210) at org.apache.spark.sql.hive.client.HiveClientImpl.retryLocked(HiveClientImpl.scala:209) at org.apache.spark.sql.hive.client.HiveClientImpl.withHiveState(HiveClientImpl.scala:255) at org.apache.spark.sql.hive.client.HiveClientImpl.runHive(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl.runSqlHive(HiveClientImpl.scala:673) ``` ## How was this patch tested? N/A Author: gatorsmile Closes #20196 from gatorsmile/testFix. --- .../org/apache/spark/sql/hive/client/HiveClientImpl.scala | 6 +++++- .../apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 7b7f4e0f10210..102f40bacc985 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -823,7 +823,8 @@ private[hive] class HiveClientImpl( } def reset(): Unit = withHiveState { - client.getAllTables("default").asScala.foreach { t => + try { + client.getAllTables("default").asScala.foreach { t => logDebug(s"Deleting table $t") val table = client.getTable("default", t) client.getIndexes("default", t, 255).asScala.foreach { index => @@ -837,6 +838,9 @@ private[hive] class HiveClientImpl( logDebug(s"Dropping Database: $db") client.dropDatabase(db, true, false, true) } + } finally { + runSqlHive("USE default") + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 18137e7ea1d63..cf4ce83124d88 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -146,6 +146,11 @@ class DataSourceWithHiveMetastoreCatalogSuite 'id cast StringType as 'd2 ).coalesce(1) + override def beforeAll(): Unit = { + super.beforeAll() + sparkSession.metadataHive.reset() + } + Seq( "parquet" -> (( "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", From 6a4206ff04746481d7c8e307dfd0d31ff1402555 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Tue, 9 Jan 2018 01:32:48 -0800 Subject: [PATCH 049/774] [SPARK-22998][K8S] Set missing value for SPARK_MOUNTED_CLASSPATH in the executors ## What changes were proposed in this pull request? The environment variable `SPARK_MOUNTED_CLASSPATH` is referenced in the executor's Dockerfile, where its value is added to the classpath of the executor. However, the scheduler backend code missed setting it when creating the executor pods. This PR fixes it. ## How was this patch tested? Unit tested. vanzin Can you help take a look? Thanks! foxish Author: Yinan Li Closes #20193 from liyinan926/master. --- .../spark/scheduler/cluster/k8s/ExecutorPodFactory.scala | 5 ++++- .../scheduler/cluster/k8s/ExecutorPodFactorySuite.scala | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 066d7e9f70ca5..bcacb3934d36a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -94,6 +94,8 @@ private[spark] class ExecutorPodFactory( private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) + private val executorJarsDownloadDir = sparkConf.get(JARS_DOWNLOAD_LOCATION) + /** * Configure and construct an executor pod with the given parameters. */ @@ -145,7 +147,8 @@ private[spark] class ExecutorPodFactory( (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), (ENV_EXECUTOR_MEMORY, executorMemoryString), (ENV_APPLICATION_ID, applicationId), - (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) + (ENV_EXECUTOR_ID, executorId), + (ENV_MOUNTED_CLASSPATH, s"$executorJarsDownloadDir/*")) ++ executorEnvs) .map(env => new EnvVarBuilder() .withName(env._1) .withValue(env._2) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 884da8aabd880..7cfbe54c95390 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -197,7 +197,8 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef ENV_EXECUTOR_CORES -> "1", ENV_EXECUTOR_MEMORY -> "1g", ENV_APPLICATION_ID -> "dummy", - ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + ENV_EXECUTOR_POD_IP -> null, + ENV_MOUNTED_CLASSPATH -> "/var/spark-data/spark-jars/*") ++ additionalEnvVars assert(executor.getSpec.getContainers.size() === 1) assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) From f44ba910f58083458e1133502e193a9d6f2bf766 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 9 Jan 2018 21:48:14 +0800 Subject: [PATCH 050/774] [SPARK-16060][SQL] Support Vectorized ORC Reader ## What changes were proposed in this pull request? This PR adds an ORC columnar-batch reader to native `OrcFileFormat`. Since both Spark `ColumnarBatch` and ORC `RowBatch` are used together, it is faster than the current Spark implementation. This replaces the prior PR, #17924. Also, this PR adds `OrcReadBenchmark` to show the performance improvement. ## How was this patch tested? Pass the existing test cases. Author: Dongjoon Hyun Closes #19943 from dongjoon-hyun/SPARK-16060. --- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../orc/OrcColumnarBatchReader.java | 523 ++++++++++++++++++ .../datasources/orc/OrcFileFormat.scala | 75 ++- .../execution/datasources/orc/OrcUtils.scala | 7 +- .../spark/sql/hive/orc/OrcReadBenchmark.scala | 435 +++++++++++++++ 5 files changed, 1022 insertions(+), 25 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5c61f10bb71ad..74949db883f7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -386,6 +386,11 @@ object SQLConf { .checkValues(Set("hive", "native")) .createWithDefault("native") + val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader") + .doc("Enables vectorized orc decoding.") + .booleanConf + .createWithDefault(true) + val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf @@ -1183,6 +1188,8 @@ class SQLConf extends Serializable with Logging { def orcCompressionCodec: String = getConf(ORC_COMPRESSION) + def orcVectorizedReaderEnabled: Boolean = getConf(ORC_VECTORIZED_READER_ENABLED) + def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java new file mode 100644 index 0000000000000..5c28d0e6e507a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -0,0 +1,523 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import java.io.IOException; +import java.util.stream.IntStream; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.hadoop.mapreduce.lib.input.FileSplit; +import org.apache.orc.OrcConf; +import org.apache.orc.OrcFile; +import org.apache.orc.Reader; +import org.apache.orc.TypeDescription; +import org.apache.orc.mapred.OrcInputFormat; +import org.apache.orc.storage.common.type.HiveDecimal; +import org.apache.orc.storage.ql.exec.vector.*; +import org.apache.orc.storage.serde2.io.HiveDecimalWritable; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +/** + * To support vectorization in WholeStageCodeGen, this reader returns ColumnarBatch. + * After creating, `initialize` and `initBatch` should be called sequentially. + */ +public class OrcColumnarBatchReader extends RecordReader { + + /** + * The default size of batch. We use this value for both ORC and Spark consistently + * because they have different default values like the following. + * + * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024 + * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024 + */ + public static final int DEFAULT_SIZE = 4 * 1024; + + // ORC File Reader + private Reader reader; + + // Vectorized ORC Row Batch + private VectorizedRowBatch batch; + + /** + * The column IDs of the physical ORC file schema which are required by this reader. + * -1 means this required column doesn't exist in the ORC file. + */ + private int[] requestedColIds; + + // Record reader from ORC row batch. + private org.apache.orc.RecordReader recordReader; + + private StructField[] requiredFields; + + // The result columnar batch for vectorized execution by whole-stage codegen. + private ColumnarBatch columnarBatch; + + // Writable column vectors of the result columnar batch. + private WritableColumnVector[] columnVectors; + + /** + * The memory mode of the columnarBatch + */ + private final MemoryMode MEMORY_MODE; + + public OrcColumnarBatchReader(boolean useOffHeap) { + MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; + } + + + @Override + public Void getCurrentKey() throws IOException, InterruptedException { + return null; + } + + @Override + public ColumnarBatch getCurrentValue() throws IOException, InterruptedException { + return columnarBatch; + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return recordReader.getProgress(); + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + return nextBatch(); + } + + @Override + public void close() throws IOException { + if (columnarBatch != null) { + columnarBatch.close(); + columnarBatch = null; + } + if (recordReader != null) { + recordReader.close(); + recordReader = null; + } + } + + /** + * Initialize ORC file reader and batch record reader. + * Please note that `initBatch` is needed to be called after this. + */ + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + FileSplit fileSplit = (FileSplit)inputSplit; + Configuration conf = taskAttemptContext.getConfiguration(); + reader = OrcFile.createReader( + fileSplit.getPath(), + OrcFile.readerOptions(conf) + .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)) + .filesystem(fileSplit.getPath().getFileSystem(conf))); + + Reader.Options options = + OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength()); + recordReader = reader.rows(options); + } + + /** + * Initialize columnar batch by setting required schema and partition information. + * With this information, this creates ColumnarBatch with the full schema. + */ + public void initBatch( + TypeDescription orcSchema, + int[] requestedColIds, + StructField[] requiredFields, + StructType partitionSchema, + InternalRow partitionValues) { + batch = orcSchema.createRowBatch(DEFAULT_SIZE); + assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. + + this.requiredFields = requiredFields; + this.requestedColIds = requestedColIds; + assert(requiredFields.length == requestedColIds.length); + + StructType resultSchema = new StructType(requiredFields); + for (StructField f : partitionSchema.fields()) { + resultSchema = resultSchema.add(f); + } + + int capacity = DEFAULT_SIZE; + if (MEMORY_MODE == MemoryMode.OFF_HEAP) { + columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); + } else { + columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); + } + columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); + + if (partitionValues.numFields() > 0) { + int partitionIdx = requiredFields.length; + for (int i = 0; i < partitionValues.numFields(); i++) { + ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); + columnVectors[i + partitionIdx].setIsConstant(); + } + } + + // Initialize the missing columns once. + for (int i = 0; i < requiredFields.length; i++) { + if (requestedColIds[i] == -1) { + columnVectors[i].putNulls(0, columnarBatch.capacity()); + columnVectors[i].setIsConstant(); + } + } + } + + /** + * Return true if there exists more data in the next batch. If exists, prepare the next batch + * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns. + */ + private boolean nextBatch() throws IOException { + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + columnarBatch.setNumRows(0); + + recordReader.nextBatch(batch); + int batchSize = batch.size; + if (batchSize == 0) { + return false; + } + columnarBatch.setNumRows(batchSize); + for (int i = 0; i < requiredFields.length; i++) { + StructField field = requiredFields[i]; + WritableColumnVector toColumn = columnVectors[i]; + + if (requestedColIds[i] >= 0) { + ColumnVector fromColumn = batch.cols[requestedColIds[i]]; + + if (fromColumn.isRepeating) { + putRepeatingValues(batchSize, field, fromColumn, toColumn); + } else if (fromColumn.noNulls) { + putNonNullValues(batchSize, field, fromColumn, toColumn); + } else { + putValues(batchSize, field, fromColumn, toColumn); + } + } + } + return true; + } + + private void putRepeatingValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + if (fromColumn.isNull[0]) { + toColumn.putNulls(0, batchSize); + } else { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + toColumn.putBooleans(0, batchSize, ((LongColumnVector)fromColumn).vector[0] == 1); + } else if (type instanceof ByteType) { + toColumn.putBytes(0, batchSize, (byte)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof ShortType) { + toColumn.putShorts(0, batchSize, (short)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof IntegerType || type instanceof DateType) { + toColumn.putInts(0, batchSize, (int)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof LongType) { + toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof TimestampType) { + toColumn.putLongs(0, batchSize, + fromTimestampColumnVector((TimestampColumnVector)fromColumn, 0)); + } else if (type instanceof FloatType) { + toColumn.putFloats(0, batchSize, (float)((DoubleColumnVector)fromColumn).vector[0]); + } else if (type instanceof DoubleType) { + toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]); + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector data = (BytesColumnVector)fromColumn; + WritableColumnVector arrayData = toColumn.getChildColumn(0); + int size = data.vector[0].length; + arrayData.reserve(size); + arrayData.putBytes(0, size, data.vector[0], 0); + for (int index = 0; index < batchSize; index++) { + toColumn.putArray(index, 0, size); + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + putDecimalWritables( + toColumn, + batchSize, + decimalType.precision(), + decimalType.scale(), + ((DecimalColumnVector)fromColumn).vector[0]); + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + } + + private void putNonNullValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putBoolean(index, data[index] == 1); + } + } else if (type instanceof ByteType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putByte(index, (byte)data[index]); + } + } else if (type instanceof ShortType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putShort(index, (short)data[index]); + } + } else if (type instanceof IntegerType || type instanceof DateType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putInt(index, (int)data[index]); + } + } else if (type instanceof LongType) { + toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector, 0); + } else if (type instanceof TimestampType) { + TimestampColumnVector data = ((TimestampColumnVector)fromColumn); + for (int index = 0; index < batchSize; index++) { + toColumn.putLong(index, fromTimestampColumnVector(data, index)); + } + } else if (type instanceof FloatType) { + double[] data = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putFloat(index, (float)data[index]); + } + } else if (type instanceof DoubleType) { + toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0); + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector data = ((BytesColumnVector)fromColumn); + WritableColumnVector arrayData = toColumn.getChildColumn(0); + int totalNumBytes = IntStream.of(data.length).sum(); + arrayData.reserve(totalNumBytes); + for (int index = 0, pos = 0; index < batchSize; pos += data.length[index], index++) { + arrayData.putBytes(pos, data.length[index], data.vector[index], data.start[index]); + toColumn.putArray(index, pos, data.length[index]); + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + DecimalColumnVector data = ((DecimalColumnVector)fromColumn); + if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.reserve(batchSize * 16); + } + for (int index = 0; index < batchSize; index++) { + putDecimalWritable( + toColumn, + index, + decimalType.precision(), + decimalType.scale(), + data.vector[index]); + } + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + + private void putValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putBoolean(index, vector[index] == 1); + } + } + } else if (type instanceof ByteType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putByte(index, (byte)vector[index]); + } + } + } else if (type instanceof ShortType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putShort(index, (short)vector[index]); + } + } + } else if (type instanceof IntegerType || type instanceof DateType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putInt(index, (int)vector[index]); + } + } + } else if (type instanceof LongType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putLong(index, vector[index]); + } + } + } else if (type instanceof TimestampType) { + TimestampColumnVector vector = ((TimestampColumnVector)fromColumn); + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putLong(index, fromTimestampColumnVector(vector, index)); + } + } + } else if (type instanceof FloatType) { + double[] vector = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putFloat(index, (float)vector[index]); + } + } + } else if (type instanceof DoubleType) { + double[] vector = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putDouble(index, vector[index]); + } + } + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector vector = (BytesColumnVector)fromColumn; + WritableColumnVector arrayData = toColumn.getChildColumn(0); + int totalNumBytes = IntStream.of(vector.length).sum(); + arrayData.reserve(totalNumBytes); + for (int index = 0, pos = 0; index < batchSize; pos += vector.length[index], index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + arrayData.putBytes(pos, vector.length[index], vector.vector[index], vector.start[index]); + toColumn.putArray(index, pos, vector.length[index]); + } + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector; + if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.reserve(batchSize * 16); + } + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + putDecimalWritable( + toColumn, + index, + decimalType.precision(), + decimalType.scale(), + vector[index]); + } + } + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + + /** + * Returns the number of micros since epoch from an element of TimestampColumnVector. + */ + private static long fromTimestampColumnVector(TimestampColumnVector vector, int index) { + return vector.time[index] * 1000L + vector.nanos[index] / 1000L; + } + + /** + * Put a `HiveDecimalWritable` to a `WritableColumnVector`. + */ + private static void putDecimalWritable( + WritableColumnVector toColumn, + int index, + int precision, + int scale, + HiveDecimalWritable decimalWritable) { + HiveDecimal decimal = decimalWritable.getHiveDecimal(); + Decimal value = + Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale()); + value.changePrecision(precision, scale); + + if (precision <= Decimal.MAX_INT_DIGITS()) { + toColumn.putInt(index, (int) value.toUnscaledLong()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + toColumn.putLong(index, value.toUnscaledLong()); + } else { + byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.putBytes(index * 16, bytes.length, bytes, 0); + toColumn.putArray(index, index * 16, bytes.length); + } + } + + /** + * Put `HiveDecimalWritable`s to a `WritableColumnVector`. + */ + private static void putDecimalWritables( + WritableColumnVector toColumn, + int size, + int precision, + int scale, + HiveDecimalWritable decimalWritable) { + HiveDecimal decimal = decimalWritable.getHiveDecimal(); + Decimal value = + Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale()); + value.changePrecision(precision, scale); + + if (precision <= Decimal.MAX_INT_DIGITS()) { + toColumn.putInts(0, size, (int) value.toUnscaledLong()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + toColumn.putLongs(0, size, value.toUnscaledLong()); + } else { + byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.reserve(bytes.length); + arrayData.putBytes(0, bytes.length, bytes, 0); + for (int index = 0; index < size; index++) { + toColumn.putArray(index, 0, bytes.length); + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index f7471cd7debce..b8bacfa1838ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -118,6 +118,13 @@ class OrcFileFormat } } + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { + val conf = sparkSession.sessionState.conf + conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled && + schema.length <= conf.wholeStageMaxNumFields && + schema.forall(_.dataType.isInstanceOf[AtomicType]) + } + override def isSplitable( sparkSession: SparkSession, options: Map[String, String], @@ -139,6 +146,11 @@ class OrcFileFormat } } + val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) + val sqlConf = sparkSession.sessionState.conf + val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled + val enableVectorizedReader = supportBatch(sparkSession, resultSchema) + val broadcastedConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis @@ -146,8 +158,14 @@ class OrcFileFormat (file: PartitionedFile) => { val conf = broadcastedConf.value.value + val filePath = new Path(new URI(file.filePath)) + + val fs = filePath.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(filePath, readerOptions) + val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds( - isCaseSensitive, dataSchema, requiredSchema, new Path(new URI(file.filePath)), conf) + isCaseSensitive, dataSchema, requiredSchema, reader, conf) if (requestedColIdsOrEmptyFile.isEmpty) { Iterator.empty @@ -155,29 +173,46 @@ class OrcFileFormat val requestedColIds = requestedColIdsOrEmptyFile.get assert(requestedColIds.length == requiredSchema.length, "[BUG] requested column IDs do not match required schema") - conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, + val taskConf = new Configuration(conf) + taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, requestedColIds.filter(_ != -1).sorted.mkString(",")) - val fileSplit = - new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) + val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) - val taskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - - val orcRecordReader = new OrcInputFormat[OrcStruct] - .createRecordReader(fileSplit, taskAttemptContext) - val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) - - val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes - val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) - val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) - - if (partitionSchema.length == 0) { - iter.map(value => unsafeProjection(deserializer.deserialize(value))) + val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) + + val taskContext = Option(TaskContext.get()) + if (enableVectorizedReader) { + val batchReader = + new OrcColumnarBatchReader(enableOffHeapColumnVector && taskContext.isDefined) + batchReader.initialize(fileSplit, taskAttemptContext) + batchReader.initBatch( + reader.getSchema, + requestedColIds, + requiredSchema.fields, + partitionSchema, + file.partitionValues) + + val iter = new RecordReaderIterator(batchReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + iter.asInstanceOf[Iterator[InternalRow]] } else { - val joinedRow = new JoinedRow() - iter.map(value => - unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues))) + val orcRecordReader = new OrcInputFormat[OrcStruct] + .createRecordReader(fileSplit, taskAttemptContext) + val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) + val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) + + if (partitionSchema.length == 0) { + iter.map(value => unsafeProjection(deserializer.deserialize(value))) + } else { + val joinedRow = new JoinedRow() + iter.map(value => + unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues))) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index b03ee06d04a16..13a23996f4ade 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.orc.{OrcFile, TypeDescription} +import org.apache.orc.{OrcFile, Reader, TypeDescription} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging @@ -80,11 +80,8 @@ object OrcUtils extends Logging { isCaseSensitive: Boolean, dataSchema: StructType, requiredSchema: StructType, - file: Path, + reader: Reader, conf: Configuration): Option[Array[Int]] = { - val fs = file.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val reader = OrcFile.createReader(file, readerOptions) val orcFieldNames = reader.getSchema.getFieldNames.asScala if (orcFieldNames.isEmpty) { // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala new file mode 100644 index 0000000000000..37ed846acd1eb --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure ORC read performance. + * + * This is in `sql/hive` module in order to compare `sql/core` and `sql/hive` ORC data sources. + */ +// scalastyle:off line.size.limit +object OrcReadBenchmark { + val conf = new SparkConf() + conf.set("orc.compression", "snappy") + + private val spark = SparkSession.builder() + .master("local[1]") + .appName("OrcReadBenchmark") + .config(conf) + .getOrCreate() + + // Set default configs. Individual cases will change them if necessary. + spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true") + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private val NATIVE_ORC_FORMAT = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName + private val HIVE_ORC_FORMAT = classOf[org.apache.spark.sql.hive.orc.OrcFileFormat].getCanonicalName + + private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { + val dirORC = dir.getCanonicalPath + + if (partition.isDefined) { + df.write.partitionBy(partition.get).orc(dirORC) + } else { + df.write.orc(dirORC) + } + + spark.read.format(NATIVE_ORC_FORMAT).load(dirORC).createOrReplaceTempView("nativeOrcTable") + spark.read.format(HIVE_ORC_FORMAT).load(dirORC).createOrReplaceTempView("hiveOrcTable") + } + + def numericScanBenchmark(values: Int, dataType: DataType): Unit = { + val sqlBenchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) + + sqlBenchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + sqlBenchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + + sqlBenchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1192 / 1221 13.2 75.8 1.0X + Native ORC Vectorized 161 / 170 97.5 10.3 7.4X + Hive built-in ORC 1399 / 1413 11.2 89.0 0.9X + + SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1287 / 1333 12.2 81.8 1.0X + Native ORC Vectorized 164 / 172 95.6 10.5 7.8X + Hive built-in ORC 1629 / 1650 9.7 103.6 0.8X + + SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1304 / 1388 12.1 82.9 1.0X + Native ORC Vectorized 227 / 240 69.3 14.4 5.7X + Hive built-in ORC 1866 / 1867 8.4 118.6 0.7X + + SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1331 / 1357 11.8 84.6 1.0X + Native ORC Vectorized 289 / 297 54.4 18.4 4.6X + Hive built-in ORC 1922 / 1929 8.2 122.2 0.7X + + SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1410 / 1428 11.2 89.7 1.0X + Native ORC Vectorized 328 / 335 48.0 20.8 4.3X + Hive built-in ORC 1929 / 2012 8.2 122.6 0.7X + + SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1467 / 1485 10.7 93.3 1.0X + Native ORC Vectorized 402 / 411 39.1 25.6 3.6X + Hive built-in ORC 2023 / 2042 7.8 128.6 0.7X + */ + sqlBenchmark.run() + } + } + } + + def intStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Int and String Scan", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("SELECT CAST(value AS INT) AS c1, CAST(value as STRING) AS c2 FROM t1")) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2729 / 2744 3.8 260.2 1.0X + Native ORC Vectorized 1318 / 1344 8.0 125.7 2.1X + Hive built-in ORC 3731 / 3782 2.8 355.8 0.7X + */ + benchmark.run() + } + } + } + + def partitionTableScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Partitioned Table", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) + + benchmark.addCase("Read data column - Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Read data column - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Read data column - Hive built-in ORC") { _ => + spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() + } + + benchmark.addCase("Read partition column - Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Read partition column - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Read partition column - Hive built-in ORC") { _ => + spark.sql("SELECT sum(p) FROM hiveOrcTable").collect() + } + + benchmark.addCase("Read both columns - Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Read both columns - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Read both columns - Hive built-in ORC") { _ => + spark.sql("SELECT sum(p), sum(id) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Read data column - Native ORC MR 1531 / 1536 10.3 97.4 1.0X + Read data column - Native ORC Vectorized 295 / 298 53.3 18.8 5.2X + Read data column - Hive built-in ORC 2125 / 2126 7.4 135.1 0.7X + Read partition column - Native ORC MR 1049 / 1062 15.0 66.7 1.5X + Read partition column - Native ORC Vectorized 54 / 57 290.1 3.4 28.2X + Read partition column - Hive built-in ORC 1282 / 1291 12.3 81.5 1.2X + Read both columns - Native ORC MR 1594 / 1598 9.9 101.3 1.0X + Read both columns - Native ORC Vectorized 332 / 336 47.4 21.1 4.6X + Read both columns - Hive built-in ORC 2145 / 2187 7.3 136.4 0.7X + */ + benchmark.run() + } + } + } + + def repeatedStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Repeated String", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT CAST((id % 200) + 10000 as STRING) AS c1 FROM t1")) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(length(c1)) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1325 / 1328 7.9 126.4 1.0X + Native ORC Vectorized 320 / 330 32.8 30.5 4.1X + Hive built-in ORC 1971 / 1972 5.3 188.0 0.7X + */ + benchmark.run() + } + } + } + + def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql( + s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) + + val benchmark = new Benchmark(s"String with Nulls Scan ($fractionOfNulls%)", values) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM hiveOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2553 / 2554 4.1 243.4 1.0X + Native ORC Vectorized 953 / 954 11.0 90.9 2.7X + Hive built-in ORC 3875 / 3898 2.7 369.6 0.7X + + String with Nulls Scan (0.5%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2389 / 2408 4.4 227.8 1.0X + Native ORC Vectorized 1208 / 1209 8.7 115.2 2.0X + Hive built-in ORC 2940 / 2952 3.6 280.4 0.8X + + String with Nulls Scan (0.95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1295 / 1311 8.1 123.5 1.0X + Native ORC Vectorized 449 / 457 23.4 42.8 2.9X + Hive built-in ORC 1649 / 1660 6.4 157.3 0.8X + */ + benchmark.run() + } + } + } + + def columnsBenchmark(values: Int, width: Int): Unit = { + val sqlBenchmark = new Benchmark(s"SQL Single Column Scan from $width columns", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + val middle = width / 2 + val selectExpr = (1 to width).map(i => s"value as c$i") + spark.range(values).map(_ => Random.nextLong).toDF() + .selectExpr(selectExpr: _*).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT * FROM t1")) + + sqlBenchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + } + + sqlBenchmark.addCase("Native ORC Vectorized") { _ => + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + + sqlBenchmark.addCase("Hive built-in ORC") { _ => + spark.sql(s"SELECT sum(c$middle) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + SQL Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1103 / 1124 1.0 1052.0 1.0X + Native ORC Vectorized 92 / 100 11.4 87.9 12.0X + Hive built-in ORC 383 / 390 2.7 365.4 2.9X + + SQL Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2245 / 2250 0.5 2141.0 1.0X + Native ORC Vectorized 157 / 165 6.7 150.2 14.3X + Hive built-in ORC 587 / 593 1.8 559.4 3.8X + + SQL Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 3343 / 3350 0.3 3188.3 1.0X + Native ORC Vectorized 265 / 280 3.9 253.2 12.6X + Hive built-in ORC 828 / 842 1.3 789.8 4.0X + */ + sqlBenchmark.run() + } + } + } + + def main(args: Array[String]): Unit = { + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => + numericScanBenchmark(1024 * 1024 * 15, dataType) + } + intStringScanBenchmark(1024 * 1024 * 10) + partitionTableScanBenchmark(1024 * 1024 * 15) + repeatedStringScanBenchmark(1024 * 1024 * 10) + for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { + stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + } + columnsBenchmark(1024 * 1024 * 1, 100) + columnsBenchmark(1024 * 1024 * 1, 200) + columnsBenchmark(1024 * 1024 * 1, 300) + } +} +// scalastyle:on line.size.limit From 2250cb75b99d257e698fe5418a51d8cddb4d5104 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 9 Jan 2018 21:58:55 +0800 Subject: [PATCH 051/774] [SPARK-22981][SQL] Fix incorrect results of Casting Struct to String ## What changes were proposed in this pull request? This pr fixed the issue when casting structs into strings; ``` scala> val df = Seq(((1, "a"), 0), ((2, "b"), 0)).toDF("a", "b") scala> df.write.saveAsTable("t") scala> sql("SELECT CAST(a AS STRING) FROM t").show +-------------------+ | a| +-------------------+ |[0,1,1800000001,61]| |[0,2,1800000001,62]| +-------------------+ ``` This pr modified the result into; ``` +------+ | a| +------+ |[1, a]| |[2, b]| +------+ ``` ## How was this patch tested? Added tests in `CastSuite`. Author: Takeshi Yamamuro Closes #20176 from maropu/SPARK-22981. --- .../spark/sql/catalyst/expressions/Cast.scala | 71 +++++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 16 +++++ 2 files changed, 87 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f2de4c8e30bec..f21aa1e9e3135 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -259,6 +259,29 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case StructType(fields) => + buildCast[InternalRow](_, row => { + val builder = new UTF8StringBuilder + builder.append("[") + if (row.numFields > 0) { + val st = fields.map(_.dataType) + val toUTF8StringFuncs = st.map(castToString) + if (!row.isNullAt(0)) { + builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < row.numFields) { + builder.append(",") + if (!row.isNullAt(i)) { + builder.append(" ") + builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -732,6 +755,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """.stripMargin } + private def writeStructToStringBuilder( + st: Seq[DataType], + row: String, + buffer: String, + ctx: CodegenContext): String = { + val structToStringCode = st.zipWithIndex.map { case (ft, i) => + val fieldToStringCode = castToStringCode(ft, ctx) + val field = ctx.freshName("field") + val fieldStr = ctx.freshName("fieldStr") + s""" + |${if (i != 0) s"""$buffer.append(",");""" else ""} + |if (!$row.isNullAt($i)) { + | ${if (i != 0) s"""$buffer.append(" ");""" else ""} + | + | // Append $i field into the string buffer + | ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")}; + | UTF8String $fieldStr = null; + | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} + | $buffer.append($fieldStr); + |} + """.stripMargin + } + + val writeStructCode = ctx.splitExpressions( + expressions = structToStringCode, + funcName = "fieldToString", + arguments = ("InternalRow", row) :: (classOf[UTF8StringBuilder].getName, buffer) :: Nil) + + s""" + |$buffer.append("["); + |$writeStructCode + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -765,6 +823,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case StructType(fields) => + (c, evPrim, evNull) => { + val row = ctx.freshName("row") + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) + s""" + |InternalRow $row = $c; + |$bufferClass $buffer = new $bufferClass(); + |$writeStructCode + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1445bb8a97d40..5b25bdf907c3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -906,4 +906,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType) checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]") } + + test("SPARK-22981 Cast struct to string") { + val ret1 = cast(Literal.create((1, "a", 0.1)), StringType) + checkEvaluation(ret1, "[1, a, 0.1]") + val ret2 = cast(Literal.create(Tuple3[Int, String, String](1, null, "a")), StringType) + checkEvaluation(ret2, "[1,, a]") + val ret3 = cast(Literal.create( + (Date.valueOf("2014-12-03"), Timestamp.valueOf("2014-12-03 15:05:00"))), StringType) + checkEvaluation(ret3, "[2014-12-03, 2014-12-03 15:05:00]") + val ret4 = cast(Literal.create(((1, "a"), 5, 0.1)), StringType) + checkEvaluation(ret4, "[[1, a], 5, 0.1]") + val ret5 = cast(Literal.create((Seq(1, 2, 3), "a", 0.1)), StringType) + checkEvaluation(ret5, "[[1, 2, 3], a, 0.1]") + val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType) + checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]") + } } From 96ba217a06fbe1dad703447d7058cb7841653861 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Wed, 10 Jan 2018 10:15:27 +0800 Subject: [PATCH 052/774] [SPARK-23005][CORE] Improve RDD.take on small number of partitions ## What changes were proposed in this pull request? In current implementation of RDD.take, we overestimate the number of partitions we need to try by 50%: `(1.5 * num * partsScanned / buf.size).toInt` However, when the number is small, the result of `.toInt` is not what we want. E.g, 2.9 will become 2, which should be 3. Use Math.ceil to fix the problem. Also clean up the code in RDD.scala. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20200 from gengliangwang/Take. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 27 +++++++++---------- .../spark/sql/execution/SparkPlan.scala | 5 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 8798dfc925362..7859781e98223 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -150,7 +150,7 @@ abstract class RDD[T: ClassTag]( val id: Int = sc.newRddId() /** A friendly name for this RDD */ - @transient var name: String = null + @transient var name: String = _ /** Assign a name to this RDD */ def setName(_name: String): this.type = { @@ -224,8 +224,8 @@ abstract class RDD[T: ClassTag]( // Our dependencies and partitions will be gotten by calling subclass's methods below, and will // be overwritten when we're checkpointed - private var dependencies_ : Seq[Dependency[_]] = null - @transient private var partitions_ : Array[Partition] = null + private var dependencies_ : Seq[Dependency[_]] = _ + @transient private var partitions_ : Array[Partition] = _ /** An Option holding our checkpoint RDD, if we are checkpointed */ private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD) @@ -297,7 +297,7 @@ abstract class RDD[T: ClassTag]( private[spark] def getNarrowAncestors: Seq[RDD[_]] = { val ancestors = new mutable.HashSet[RDD[_]] - def visit(rdd: RDD[_]) { + def visit(rdd: RDD[_]): Unit = { val narrowDependencies = rdd.dependencies.filter(_.isInstanceOf[NarrowDependency[_]]) val narrowParents = narrowDependencies.map(_.rdd) val narrowParentsNotVisited = narrowParents.filterNot(ancestors.contains) @@ -449,7 +449,7 @@ abstract class RDD[T: ClassTag]( if (shuffle) { /** Distributes elements evenly across output partitions, starting from a random partition. */ val distributePartition = (index: Int, items: Iterator[T]) => { - var position = (new Random(hashing.byteswap32(index))).nextInt(numPartitions) + var position = new Random(hashing.byteswap32(index)).nextInt(numPartitions) items.map { t => // Note that the hash code of the key will just be the key itself. The HashPartitioner // will mod it with the number of total partitions. @@ -951,7 +951,7 @@ abstract class RDD[T: ClassTag]( def collectPartition(p: Int): Array[T] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head } - (0 until partitions.length).iterator.flatMap(i => collectPartition(i)) + partitions.indices.iterator.flatMap(i => collectPartition(i)) } /** @@ -1338,6 +1338,7 @@ abstract class RDD[T: ClassTag]( // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1L + val left = num - buf.size if (partsScanned > 0) { // If we didn't find any rows after the previous iteration, quadruple and retry. // Otherwise, interpolate the number of partitions we need to try, but overestimate @@ -1345,13 +1346,12 @@ abstract class RDD[T: ClassTag]( if (buf.isEmpty) { numPartsToTry = partsScanned * scaleUpFactor } else { - // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1) + // As left > 0, numPartsToTry is always >= 1 + numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor) } } - val left = num - buf.size val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) @@ -1677,8 +1677,7 @@ abstract class RDD[T: ClassTag]( // an RDD and its parent in every batch, in which case the parent may never be checkpointed // and its lineage never truncated, leading to OOMs in the long run (SPARK-6847). private val checkpointAllMarkedAncestors = - Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)) - .map(_.toBoolean).getOrElse(false) + Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)).exists(_.toBoolean) /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassTag]: RDD[U] = { @@ -1686,7 +1685,7 @@ abstract class RDD[T: ClassTag]( } /** Returns the jth parent RDD: e.g. rdd.parent[T](0) is equivalent to rdd.firstParent[T] */ - protected[spark] def parent[U: ClassTag](j: Int) = { + protected[spark] def parent[U: ClassTag](j: Int): RDD[U] = { dependencies(j).rdd.asInstanceOf[RDD[U]] } @@ -1754,7 +1753,7 @@ abstract class RDD[T: ClassTag]( * collected. Subclasses of RDD may override this method for implementing their own cleaning * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example. */ - protected def clearDependencies() { + protected def clearDependencies(): Unit = { dependencies_ = null } @@ -1790,7 +1789,7 @@ abstract class RDD[T: ClassTag]( val lastDepStrings = debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_, _, _]], true) - (frontDepStrings ++ lastDepStrings) + frontDepStrings ++ lastDepStrings } } // The first RDD in the dependency stack has no parents, so no need for a +- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 82300efc01632..398758a3331b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -351,8 +351,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ if (buf.isEmpty) { numPartsToTry = partsScanned * limitScaleUpFactor } else { - // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max((1.5 * n * partsScanned / buf.size).toInt - partsScanned, 1) + val left = n - buf.size + // As left > 0, numPartsToTry is always >= 1 + numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor) } } From 6f169ca9e1444fe8fd1ab6f3fbf0a8be1670f1b5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 10 Jan 2018 10:20:34 +0800 Subject: [PATCH 053/774] [MINOR] fix a typo in BroadcastJoinSuite ## What changes were proposed in this pull request? `BroadcastNestedLoopJoinExec` should be `BroadcastHashJoinExec` ## How was this patch tested? N/A Author: Wenchen Fan Closes #20202 from cloud-fan/typo. --- .../apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 6da46ea3480b3..0bcd54e1fceab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -318,7 +318,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { case b: BroadcastNestedLoopJoinExec => assert(b.getClass.getSimpleName === joinMethod) assert(b.buildSide === buildSide) - case b: BroadcastNestedLoopJoinExec => + case b: BroadcastHashJoinExec => assert(b.getClass.getSimpleName === joinMethod) assert(b.buildSide === buildSide) case w: WholeStageCodegenExec => From 7bcc2666810cefc85dfa0d6679ac7a0de9e23154 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 10 Jan 2018 14:00:07 +0900 Subject: [PATCH 054/774] [SPARK-23018][PYTHON] Fix createDataFrame from Pandas timestamp series assignment ## What changes were proposed in this pull request? This fixes createDataFrame from Pandas to only assign modified timestamp series back to a copied version of the Pandas DataFrame. Previously, if the Pandas DataFrame was only a reference (e.g. a slice of another) each series will still get assigned back to the reference even if it is not a modified timestamp column. This caused the following warning "SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame." ## How was this patch tested? existing tests Author: Bryan Cutler Closes #20213 from BryanCutler/pyspark-createDataFrame-copy-slice-warn-SPARK-23018. --- python/pyspark/sql/session.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 6052fa9e84096..3e4574729a631 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -459,21 +459,23 @@ def _convert_from_pandas(self, pdf, schema, timezone): # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if isinstance(field.dataType, TimestampType): s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone) - if not copied and s is not pdf[field.name]: - # Copy once if the series is modified to prevent the original Pandas - # DataFrame from being updated - pdf = pdf.copy() - copied = True - pdf[field.name] = s + if s is not pdf[field.name]: + if not copied: + # Copy once if the series is modified to prevent the original + # Pandas DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[field.name] = s else: for column, series in pdf.iteritems(): - s = _check_series_convert_timestamps_tz_local(pdf[column], timezone) - if not copied and s is not pdf[column]: - # Copy once if the series is modified to prevent the original Pandas - # DataFrame from being updated - pdf = pdf.copy() - copied = True - pdf[column] = s + s = _check_series_convert_timestamps_tz_local(series, timezone) + if s is not series: + if not copied: + # Copy once if the series is modified to prevent the original + # Pandas DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[column] = s # Convert pandas.DataFrame to list of numpy records np_records = pdf.to_records(index=False) From e5998372487af20114e160264a594957344ff433 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 10 Jan 2018 14:55:24 +0900 Subject: [PATCH 055/774] [SPARK-23009][PYTHON] Fix for non-str col names to createDataFrame from Pandas ## What changes were proposed in this pull request? This the case when calling `SparkSession.createDataFrame` using a Pandas DataFrame that has non-str column labels. The column name conversion logic to handle non-string or unicode in python2 is: ``` if column is not any type of string: name = str(column) else if column is unicode in Python 2: name = column.encode('utf-8') ``` ## How was this patch tested? Added a new test with a Pandas DataFrame that has int column labels Author: Bryan Cutler Closes #20210 from BryanCutler/python-createDataFrame-int-col-error-SPARK-23009. --- python/pyspark/sql/session.py | 4 +++- python/pyspark/sql/tests.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 3e4574729a631..604021c1f45cc 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -648,7 +648,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr # If no schema supplied by user then get the names of columns only if schema is None: - schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in data.columns] + schema = [str(x) if not isinstance(x, basestring) else + (x.encode('utf-8') if not isinstance(x, str) else x) + for x in data.columns] if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ and len(data) > 0: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 13576ff57001b..80a94a91a87b3 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3532,6 +3532,15 @@ def test_toPandas_with_array_type(self): self.assertTrue(expected[r][e] == result_arrow[r][e] and result[r][e] == result_arrow[r][e]) + def test_createDataFrame_with_int_col_names(self): + import numpy as np + import pandas as pd + pdf = pd.DataFrame(np.random.rand(4, 2)) + df, df_arrow = self._createDataFrame_toggle(pdf) + pdf_col_names = [str(c) for c in pdf.columns] + self.assertEqual(pdf_col_names, df.columns) + self.assertEqual(pdf_col_names, df_arrow.columns) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class PandasUDFTests(ReusedSQLTestCase): From edf0a48c2ec696b92ed6a96dcee6eeb1a046b20b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 Jan 2018 15:01:11 +0800 Subject: [PATCH 056/774] [SPARK-22982] Remove unsafe asynchronous close() call from FileDownloadChannel ## What changes were proposed in this pull request? This patch fixes a severe asynchronous IO bug in Spark's Netty-based file transfer code. At a high-level, the problem is that an unsafe asynchronous `close()` of a pipe's source channel creates a race condition where file transfer code closes a file descriptor then attempts to read from it. If the closed file descriptor's number has been reused by an `open()` call then this invalid read may cause unrelated file operations to return incorrect results. **One manifestation of this problem is incorrect query results.** For a high-level overview of how file download works, take a look at the control flow in `NettyRpcEnv.openChannel()`: this code creates a pipe to buffer results, then submits an asynchronous stream request to a lower-level TransportClient. The callback passes received data to the sink end of the pipe. The source end of the pipe is passed back to the caller of `openChannel()`. Thus `openChannel()` returns immediately and callers interact with the returned pipe source channel. Because the underlying stream request is asynchronous, errors may occur after `openChannel()` has returned and after that method's caller has started to `read()` from the returned channel. For example, if a client requests an invalid stream from a remote server then the "stream does not exist" error may not be received from the remote server until after `openChannel()` has returned. In order to be able to propagate the "stream does not exist" error to the file-fetching application thread, this code wraps the pipe's source channel in a special `FileDownloadChannel` which adds an `setError(t: Throwable)` method, then calls this `setError()` method in the FileDownloadCallback's `onFailure` method. It is possible for `FileDownloadChannel`'s `read()` and `setError()` methods to be called concurrently from different threads: the `setError()` method is called from within the Netty RPC system's stream callback handlers, while the `read()` methods are called from higher-level application code performing remote stream reads. The problem lies in `setError()`: the existing code closed the wrapped pipe source channel. Because `read()` and `setError()` occur in different threads, this means it is possible for one thread to be calling `source.read()` while another asynchronously calls `source.close()`. Java's IO libraries do not guarantee that this will be safe and, in fact, it's possible for these operations to interleave in such a way that a lower-level `read()` system call occurs right after a `close()` call. In the best-case, this fails as a read of a closed file descriptor; in the worst-case, the file descriptor number has been re-used by an intervening `open()` operation and the read corrupts the result of an unrelated file IO operation being performed by a different thread. The solution here is to remove the `stream.close()` call in `onError()`: the thread that is performing the `read()` calls is responsible for closing the stream in a `finally` block, so there's no need to close it here. If that thread is blocked in a `read()` then it will become unblocked when the sink end of the pipe is closed in `FileDownloadCallback.onFailure()`. After making this change, we also need to refine the `read()` method to always check for a `setError()` result, even if the underlying channel `read()` call has succeeded. This patch also makes a slight cleanup to a dodgy-looking `catch e: Exception` block to use a safer `try-finally` error handling idiom. This bug was introduced in SPARK-11956 / #9941 and is present in Spark 1.6.0+. ## How was this patch tested? This fix was tested manually against a workload which non-deterministically hit this bug. Author: Josh Rosen Closes #20179 from JoshRosen/SPARK-22982-fix-unsafe-async-io-in-file-download-channel. --- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 37 +++++++++++-------- .../shuffle/IndexShuffleBlockResolver.scala | 21 +++++++++-- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index f951591e02a5c..a2936d6ad539c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -332,16 +332,14 @@ private[netty] class NettyRpcEnv( val pipe = Pipe.open() val source = new FileDownloadChannel(pipe.source()) - try { + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { val client = downloadClient(parsedUri.getHost(), parsedUri.getPort()) val callback = new FileDownloadCallback(pipe.sink(), source, client) client.stream(parsedUri.getPath(), callback) - } catch { - case e: Exception => - pipe.sink().close() - source.close() - throw e - } + })(catchBlock = { + pipe.sink().close() + source.close() + }) source } @@ -370,24 +368,33 @@ private[netty] class NettyRpcEnv( fileDownloadFactory.createClient(host, port) } - private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel { + private class FileDownloadChannel(source: Pipe.SourceChannel) extends ReadableByteChannel { @volatile private var error: Throwable = _ def setError(e: Throwable): Unit = { + // This setError callback is invoked by internal RPC threads in order to propagate remote + // exceptions to application-level threads which are reading from this channel. When an + // RPC error occurs, the RPC system will call setError() and then will close the + // Pipe.SinkChannel corresponding to the other end of the `source` pipe. Closing of the pipe + // sink will cause `source.read()` operations to return EOF, unblocking the application-level + // reading thread. Thus there is no need to actually call `source.close()` here in the + // onError() callback and, in fact, calling it here would be dangerous because the close() + // would be asynchronous with respect to the read() call and could trigger race-conditions + // that lead to data corruption. See the PR for SPARK-22982 for more details on this topic. error = e - source.close() } override def read(dst: ByteBuffer): Int = { Try(source.read(dst)) match { + // See the documentation above in setError(): if an RPC error has occurred then setError() + // will be called to propagate the RPC error and then `source`'s corresponding + // Pipe.SinkChannel will be closed, unblocking this read. In that case, we want to propagate + // the remote RPC exception (and not any exceptions triggered by the pipe close, such as + // ChannelClosedException), hence this `error != null` check: + case _ if error != null => throw error case Success(bytesRead) => bytesRead - case Failure(readErr) => - if (error != null) { - throw error - } else { - throw readErr - } + case Failure(readErr) => throw readErr } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 15540485170d0..266ee42e39cca 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -18,8 +18,8 @@ package org.apache.spark.shuffle import java.io._ - -import com.google.common.io.ByteStreams +import java.nio.channels.Channels +import java.nio.file.Files import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging @@ -196,11 +196,24 @@ private[spark] class IndexShuffleBlockResolver( // find out the consolidated file, then the offset within that from our index val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) - val in = new DataInputStream(new FileInputStream(indexFile)) + // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code + // which is incorrectly using our file descriptor then this code will fetch the wrong offsets + // (which may cause a reducer to be sent a different reducer's data). The explicit position + // checks added here were a useful debugging aid during SPARK-22982 and may help prevent this + // class of issue from re-occurring in the future which is why they are left here even though + // SPARK-22982 is fixed. + val channel = Files.newByteChannel(indexFile.toPath) + channel.position(blockId.reduceId * 8) + val in = new DataInputStream(Channels.newInputStream(channel)) try { - ByteStreams.skipFully(in, blockId.reduceId * 8) val offset = in.readLong() val nextOffset = in.readLong() + val actualPosition = channel.position() + val expectedPosition = blockId.reduceId * 8 + 16 + if (actualPosition != expectedPosition) { + throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + + s"expected $expectedPosition but actual position was $actualPosition.") + } new FileSegmentManagedBuffer( transportConf, getDataFile(blockId.shuffleId, blockId.mapId), From eaac60a1e20e29084b7151ffca964cfaa5ba99d1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 10 Jan 2018 15:16:27 +0800 Subject: [PATCH 057/774] [SPARK-16060][SQL][FOLLOW-UP] add a wrapper solution for vectorized orc reader ## What changes were proposed in this pull request? This is mostly from https://github.com/apache/spark/pull/13775 The wrapper solution is pretty good for string/binary type, as the ORC column vector doesn't keep bytes in a continuous memory region, and has a significant overhead when copying the data to Spark columnar batch. For other cases, the wrapper solution is almost same with the current solution. I think we can treat the wrapper solution as a baseline and keep improving the writing to Spark solution. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20205 from cloud-fan/orc. --- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../datasources/orc/OrcColumnVector.java | 251 ++++++++++++++++++ .../orc/OrcColumnarBatchReader.java | 106 ++++++-- .../datasources/orc/OrcFileFormat.scala | 6 +- .../spark/sql/hive/orc/OrcReadBenchmark.scala | 236 ++++++++++------ 5 files changed, 490 insertions(+), 116 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 74949db883f7a..36e802a9faa6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -391,6 +391,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ORC_COPY_BATCH_TO_SPARK = buildConf("spark.sql.orc.copyBatchToSpark") + .doc("Whether or not to copy the ORC columnar batch to Spark columnar batch in the " + + "vectorized ORC reader.") + .internal() + .booleanConf + .createWithDefault(false) + val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java new file mode 100644 index 0000000000000..f94c55d860304 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import java.math.BigDecimal; + +import org.apache.orc.storage.ql.exec.vector.*; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column vector class wrapping Hive's ColumnVector. Because Spark ColumnarBatch only accepts + * Spark's vectorized.ColumnVector, this column vector is used to adapt Hive ColumnVector with + * Spark ColumnarVector. + */ +public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVector { + private ColumnVector baseData; + private LongColumnVector longData; + private DoubleColumnVector doubleData; + private BytesColumnVector bytesData; + private DecimalColumnVector decimalData; + private TimestampColumnVector timestampData; + final private boolean isTimestamp; + + private int batchSize; + + OrcColumnVector(DataType type, ColumnVector vector) { + super(type); + + if (type instanceof TimestampType) { + isTimestamp = true; + } else { + isTimestamp = false; + } + + baseData = vector; + if (vector instanceof LongColumnVector) { + longData = (LongColumnVector) vector; + } else if (vector instanceof DoubleColumnVector) { + doubleData = (DoubleColumnVector) vector; + } else if (vector instanceof BytesColumnVector) { + bytesData = (BytesColumnVector) vector; + } else if (vector instanceof DecimalColumnVector) { + decimalData = (DecimalColumnVector) vector; + } else if (vector instanceof TimestampColumnVector) { + timestampData = (TimestampColumnVector) vector; + } else { + throw new UnsupportedOperationException(); + } + } + + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; + } + + @Override + public void close() { + + } + + @Override + public int numNulls() { + if (baseData.isRepeating) { + if (baseData.isNull[0]) { + return batchSize; + } else { + return 0; + } + } else if (baseData.noNulls) { + return 0; + } else { + int count = 0; + for (int i = 0; i < batchSize; i++) { + if (baseData.isNull[i]) count++; + } + return count; + } + } + + /* A helper method to get the row index in a column. */ + private int getRowIndex(int rowId) { + return baseData.isRepeating ? 0 : rowId; + } + + @Override + public boolean isNullAt(int rowId) { + return baseData.isNull[getRowIndex(rowId)]; + } + + @Override + public boolean getBoolean(int rowId) { + return longData.vector[getRowIndex(rowId)] == 1; + } + + @Override + public boolean[] getBooleans(int rowId, int count) { + boolean[] res = new boolean[count]; + for (int i = 0; i < count; i++) { + res[i] = getBoolean(rowId + i); + } + return res; + } + + @Override + public byte getByte(int rowId) { + return (byte) longData.vector[getRowIndex(rowId)]; + } + + @Override + public byte[] getBytes(int rowId, int count) { + byte[] res = new byte[count]; + for (int i = 0; i < count; i++) { + res[i] = getByte(rowId + i); + } + return res; + } + + @Override + public short getShort(int rowId) { + return (short) longData.vector[getRowIndex(rowId)]; + } + + @Override + public short[] getShorts(int rowId, int count) { + short[] res = new short[count]; + for (int i = 0; i < count; i++) { + res[i] = getShort(rowId + i); + } + return res; + } + + @Override + public int getInt(int rowId) { + return (int) longData.vector[getRowIndex(rowId)]; + } + + @Override + public int[] getInts(int rowId, int count) { + int[] res = new int[count]; + for (int i = 0; i < count; i++) { + res[i] = getInt(rowId + i); + } + return res; + } + + @Override + public long getLong(int rowId) { + int index = getRowIndex(rowId); + if (isTimestamp) { + return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000; + } else { + return longData.vector[index]; + } + } + + @Override + public long[] getLongs(int rowId, int count) { + long[] res = new long[count]; + for (int i = 0; i < count; i++) { + res[i] = getLong(rowId + i); + } + return res; + } + + @Override + public float getFloat(int rowId) { + return (float) doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public float[] getFloats(int rowId, int count) { + float[] res = new float[count]; + for (int i = 0; i < count; i++) { + res[i] = getFloat(rowId + i); + } + return res; + } + + @Override + public double getDouble(int rowId) { + return doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public double[] getDoubles(int rowId, int count) { + double[] res = new double[count]; + for (int i = 0; i < count; i++) { + res[i] = getDouble(rowId + i); + } + return res; + } + + @Override + public int getArrayLength(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getArrayOffset(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); + return Decimal.apply(data, precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + int index = getRowIndex(rowId); + BytesColumnVector col = bytesData; + return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); + } + + @Override + public byte[] getBinary(int rowId) { + int index = getRowIndex(rowId); + byte[] binary = new byte[bytesData.length[index]]; + System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length); + return binary; + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector arrayData() { + throw new UnsupportedOperationException(); + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector getChildColumn(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 5c28d0e6e507a..36fdf2bdf84d2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -51,13 +51,13 @@ public class OrcColumnarBatchReader extends RecordReader { /** - * The default size of batch. We use this value for both ORC and Spark consistently - * because they have different default values like the following. + * The default size of batch. We use this value for ORC reader to make it consistent with Spark's + * columnar batch, because their default batch sizes are different like the following: * * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024 * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024 */ - public static final int DEFAULT_SIZE = 4 * 1024; + private static final int DEFAULT_SIZE = 4 * 1024; // ORC File Reader private Reader reader; @@ -82,13 +82,18 @@ public class OrcColumnarBatchReader extends RecordReader { // Writable column vectors of the result columnar batch. private WritableColumnVector[] columnVectors; - /** - * The memory mode of the columnarBatch - */ + // The wrapped ORC column vectors. It should be null if `copyToSpark` is true. + private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers; + + // The memory mode of the columnarBatch private final MemoryMode MEMORY_MODE; - public OrcColumnarBatchReader(boolean useOffHeap) { + // Whether or not to copy the ORC columnar batch to Spark columnar batch. + private final boolean copyToSpark; + + public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark) { MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; + this.copyToSpark = copyToSpark; } @@ -167,27 +172,61 @@ public void initBatch( } int capacity = DEFAULT_SIZE; - if (MEMORY_MODE == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); - } else { - columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); - } - columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); - if (partitionValues.numFields() > 0) { - int partitionIdx = requiredFields.length; - for (int i = 0; i < partitionValues.numFields(); i++) { - ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); - columnVectors[i + partitionIdx].setIsConstant(); + if (copyToSpark) { + if (MEMORY_MODE == MemoryMode.OFF_HEAP) { + columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); + } else { + columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); } - } - // Initialize the missing columns once. - for (int i = 0; i < requiredFields.length; i++) { - if (requestedColIds[i] == -1) { - columnVectors[i].putNulls(0, columnarBatch.capacity()); - columnVectors[i].setIsConstant(); + // Initialize the missing columns once. + for (int i = 0; i < requiredFields.length; i++) { + if (requestedColIds[i] == -1) { + columnVectors[i].putNulls(0, capacity); + columnVectors[i].setIsConstant(); + } + } + + if (partitionValues.numFields() > 0) { + int partitionIdx = requiredFields.length; + for (int i = 0; i < partitionValues.numFields(); i++) { + ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); + columnVectors[i + partitionIdx].setIsConstant(); + } + } + + columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); + } else { + // Just wrap the ORC column vector instead of copying it to Spark column vector. + orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; + + for (int i = 0; i < requiredFields.length; i++) { + DataType dt = requiredFields[i].dataType(); + int colId = requestedColIds[i]; + // Initialize the missing columns once. + if (colId == -1) { + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + missingCol.putNulls(0, capacity); + missingCol.setIsConstant(); + orcVectorWrappers[i] = missingCol; + } else { + orcVectorWrappers[i] = new OrcColumnVector(dt, batch.cols[colId]); + } } + + if (partitionValues.numFields() > 0) { + int partitionIdx = requiredFields.length; + for (int i = 0; i < partitionValues.numFields(); i++) { + DataType dt = partitionSchema.fields()[i].dataType(); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); + ColumnVectorUtils.populate(partitionCol, partitionValues, i); + partitionCol.setIsConstant(); + orcVectorWrappers[partitionIdx + i] = partitionCol; + } + } + + columnarBatch = new ColumnarBatch(resultSchema, orcVectorWrappers, capacity); } } @@ -196,17 +235,26 @@ public void initBatch( * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns. */ private boolean nextBatch() throws IOException { - for (WritableColumnVector vector : columnVectors) { - vector.reset(); - } - columnarBatch.setNumRows(0); - recordReader.nextBatch(batch); int batchSize = batch.size; if (batchSize == 0) { return false; } columnarBatch.setNumRows(batchSize); + + if (!copyToSpark) { + for (int i = 0; i < requiredFields.length; i++) { + if (requestedColIds[i] != -1) { + ((OrcColumnVector) orcVectorWrappers[i]).setBatchSize(batchSize); + } + } + return true; + } + + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + for (int i = 0; i < requiredFields.length; i++) { StructField field = requiredFields[i]; WritableColumnVector toColumn = columnVectors[i]; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index b8bacfa1838ae..2dd314d165348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -150,6 +151,7 @@ class OrcFileFormat val sqlConf = sparkSession.sessionState.conf val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled val enableVectorizedReader = supportBatch(sparkSession, resultSchema) + val copyToSpark = sparkSession.sessionState.conf.getConf(SQLConf.ORC_COPY_BATCH_TO_SPARK) val broadcastedConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -183,8 +185,8 @@ class OrcFileFormat val taskContext = Option(TaskContext.get()) if (enableVectorizedReader) { - val batchReader = - new OrcColumnarBatchReader(enableOffHeapColumnVector && taskContext.isDefined) + val batchReader = new OrcColumnarBatchReader( + enableOffHeapColumnVector && taskContext.isDefined, copyToSpark) batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( reader.getSchema, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala index 37ed846acd1eb..bf6efa7c4c08c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -86,7 +86,7 @@ object OrcReadBenchmark { } def numericScanBenchmark(values: Int, dataType: DataType): Unit = { - val sqlBenchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + val benchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) withTempPath { dir => withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { @@ -95,61 +95,73 @@ object OrcReadBenchmark { prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) - sqlBenchmark.addCase("Native ORC MR") { _ => + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } } - sqlBenchmark.addCase("Native ORC Vectorized") { _ => + benchmark.addCase("Native ORC Vectorized") { _ => spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } - sqlBenchmark.addCase("Hive built-in ORC") { _ => + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1192 / 1221 13.2 75.8 1.0X - Native ORC Vectorized 161 / 170 97.5 10.3 7.4X - Hive built-in ORC 1399 / 1413 11.2 89.0 0.9X + Native ORC MR 1135 / 1171 13.9 72.2 1.0X + Native ORC Vectorized 152 / 163 103.4 9.7 7.5X + Native ORC Vectorized with copy 149 / 162 105.4 9.5 7.6X + Hive built-in ORC 1380 / 1384 11.4 87.7 0.8X SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1287 / 1333 12.2 81.8 1.0X - Native ORC Vectorized 164 / 172 95.6 10.5 7.8X - Hive built-in ORC 1629 / 1650 9.7 103.6 0.8X + Native ORC MR 1182 / 1244 13.3 75.2 1.0X + Native ORC Vectorized 145 / 156 108.7 9.2 8.2X + Native ORC Vectorized with copy 148 / 158 106.4 9.4 8.0X + Hive built-in ORC 1591 / 1636 9.9 101.2 0.7X SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1304 / 1388 12.1 82.9 1.0X - Native ORC Vectorized 227 / 240 69.3 14.4 5.7X - Hive built-in ORC 1866 / 1867 8.4 118.6 0.7X + Native ORC MR 1271 / 1271 12.4 80.8 1.0X + Native ORC Vectorized 206 / 212 76.3 13.1 6.2X + Native ORC Vectorized with copy 200 / 213 78.8 12.7 6.4X + Hive built-in ORC 1776 / 1787 8.9 112.9 0.7X SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1331 / 1357 11.8 84.6 1.0X - Native ORC Vectorized 289 / 297 54.4 18.4 4.6X - Hive built-in ORC 1922 / 1929 8.2 122.2 0.7X + Native ORC MR 1344 / 1355 11.7 85.4 1.0X + Native ORC Vectorized 258 / 268 61.0 16.4 5.2X + Native ORC Vectorized with copy 252 / 257 62.4 16.0 5.3X + Hive built-in ORC 1818 / 1823 8.7 115.6 0.7X SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1410 / 1428 11.2 89.7 1.0X - Native ORC Vectorized 328 / 335 48.0 20.8 4.3X - Hive built-in ORC 1929 / 2012 8.2 122.6 0.7X + Native ORC MR 1333 / 1352 11.8 84.8 1.0X + Native ORC Vectorized 310 / 324 50.7 19.7 4.3X + Native ORC Vectorized with copy 312 / 320 50.4 19.9 4.3X + Hive built-in ORC 1904 / 1918 8.3 121.0 0.7X SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1467 / 1485 10.7 93.3 1.0X - Native ORC Vectorized 402 / 411 39.1 25.6 3.6X - Hive built-in ORC 2023 / 2042 7.8 128.6 0.7X + Native ORC MR 1408 / 1585 11.2 89.5 1.0X + Native ORC Vectorized 359 / 368 43.8 22.8 3.9X + Native ORC Vectorized with copy 364 / 371 43.2 23.2 3.9X + Hive built-in ORC 1881 / 1954 8.4 119.6 0.7X */ - sqlBenchmark.run() + benchmark.run() } } } @@ -176,19 +188,26 @@ object OrcReadBenchmark { spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() } + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + } + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT sum(c1), sum(length(c2)) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2729 / 2744 3.8 260.2 1.0X - Native ORC Vectorized 1318 / 1344 8.0 125.7 2.1X - Hive built-in ORC 3731 / 3782 2.8 355.8 0.7X + Native ORC MR 2566 / 2592 4.1 244.7 1.0X + Native ORC Vectorized 1098 / 1113 9.6 104.7 2.3X + Native ORC Vectorized with copy 1527 / 1593 6.9 145.6 1.7X + Hive built-in ORC 3561 / 3705 2.9 339.6 0.7X */ benchmark.run() } @@ -205,63 +224,84 @@ object OrcReadBenchmark { prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) - benchmark.addCase("Read data column - Native ORC MR") { _ => + benchmark.addCase("Data column - Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } } - benchmark.addCase("Read data column - Native ORC Vectorized") { _ => + benchmark.addCase("Data column - Native ORC Vectorized") { _ => spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } - benchmark.addCase("Read data column - Hive built-in ORC") { _ => + benchmark.addCase("Data column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Data column - Hive built-in ORC") { _ => spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() } - benchmark.addCase("Read partition column - Native ORC MR") { _ => + benchmark.addCase("Partition column - Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() } } - benchmark.addCase("Read partition column - Native ORC Vectorized") { _ => + benchmark.addCase("Partition column - Native ORC Vectorized") { _ => spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() } - benchmark.addCase("Read partition column - Hive built-in ORC") { _ => + benchmark.addCase("Partition column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Partition column - Hive built-in ORC") { _ => spark.sql("SELECT sum(p) FROM hiveOrcTable").collect() } - benchmark.addCase("Read both columns - Native ORC MR") { _ => + benchmark.addCase("Both columns - Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() } } - benchmark.addCase("Read both columns - Native ORC Vectorized") { _ => + benchmark.addCase("Both columns - Native ORC Vectorized") { _ => spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() } - benchmark.addCase("Read both columns - Hive built-in ORC") { _ => + benchmark.addCase("Both column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Both columns - Hive built-in ORC") { _ => spark.sql("SELECT sum(p), sum(id) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Read data column - Native ORC MR 1531 / 1536 10.3 97.4 1.0X - Read data column - Native ORC Vectorized 295 / 298 53.3 18.8 5.2X - Read data column - Hive built-in ORC 2125 / 2126 7.4 135.1 0.7X - Read partition column - Native ORC MR 1049 / 1062 15.0 66.7 1.5X - Read partition column - Native ORC Vectorized 54 / 57 290.1 3.4 28.2X - Read partition column - Hive built-in ORC 1282 / 1291 12.3 81.5 1.2X - Read both columns - Native ORC MR 1594 / 1598 9.9 101.3 1.0X - Read both columns - Native ORC Vectorized 332 / 336 47.4 21.1 4.6X - Read both columns - Hive built-in ORC 2145 / 2187 7.3 136.4 0.7X + Data only - Native ORC MR 1447 / 1457 10.9 92.0 1.0X + Data only - Native ORC Vectorized 256 / 266 61.4 16.3 5.6X + Data only - Native ORC Vectorized with copy 263 / 273 59.8 16.7 5.5X + Data only - Hive built-in ORC 1960 / 1988 8.0 124.6 0.7X + Partition only - Native ORC MR 1039 / 1043 15.1 66.0 1.4X + Partition only - Native ORC Vectorized 48 / 53 326.6 3.1 30.1X + Partition only - Native ORC Vectorized with copy 48 / 53 328.4 3.0 30.2X + Partition only - Hive built-in ORC 1234 / 1242 12.7 78.4 1.2X + Both columns - Native ORC MR 1465 / 1475 10.7 93.1 1.0X + Both columns - Native ORC Vectorized 292 / 301 53.9 18.6 5.0X + Both column - Native ORC Vectorized with copy 348 / 354 45.1 22.2 4.2X + Both columns - Hive built-in ORC 2051 / 2060 7.7 130.4 0.7X */ benchmark.run() } @@ -287,19 +327,26 @@ object OrcReadBenchmark { spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() } + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() + } + } + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT sum(length(c1)) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1325 / 1328 7.9 126.4 1.0X - Native ORC Vectorized 320 / 330 32.8 30.5 4.1X - Hive built-in ORC 1971 / 1972 5.3 188.0 0.7X + Native ORC MR 1271 / 1278 8.3 121.2 1.0X + Native ORC Vectorized 200 / 212 52.4 19.1 6.4X + Native ORC Vectorized with copy 342 / 347 30.7 32.6 3.7X + Hive built-in ORC 1874 / 2105 5.6 178.7 0.7X */ benchmark.run() } @@ -331,32 +378,42 @@ object OrcReadBenchmark { "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() } + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT SUM(LENGTH(c2)) FROM hiveOrcTable " + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2553 / 2554 4.1 243.4 1.0X - Native ORC Vectorized 953 / 954 11.0 90.9 2.7X - Hive built-in ORC 3875 / 3898 2.7 369.6 0.7X + Native ORC MR 2394 / 2886 4.4 228.3 1.0X + Native ORC Vectorized 699 / 729 15.0 66.7 3.4X + Native ORC Vectorized with copy 959 / 1025 10.9 91.5 2.5X + Hive built-in ORC 3899 / 3901 2.7 371.9 0.6X String with Nulls Scan (0.5%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2389 / 2408 4.4 227.8 1.0X - Native ORC Vectorized 1208 / 1209 8.7 115.2 2.0X - Hive built-in ORC 2940 / 2952 3.6 280.4 0.8X + Native ORC MR 2234 / 2255 4.7 213.1 1.0X + Native ORC Vectorized 854 / 869 12.3 81.4 2.6X + Native ORC Vectorized with copy 1099 / 1128 9.5 104.8 2.0X + Hive built-in ORC 2767 / 2793 3.8 263.9 0.8X String with Nulls Scan (0.95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1295 / 1311 8.1 123.5 1.0X - Native ORC Vectorized 449 / 457 23.4 42.8 2.9X - Hive built-in ORC 1649 / 1660 6.4 157.3 0.8X + Native ORC MR 1166 / 1202 9.0 111.2 1.0X + Native ORC Vectorized 338 / 345 31.1 32.2 3.5X + Native ORC Vectorized with copy 418 / 428 25.1 39.9 2.8X + Hive built-in ORC 1730 / 1761 6.1 164.9 0.7X */ benchmark.run() } @@ -364,7 +421,7 @@ object OrcReadBenchmark { } def columnsBenchmark(values: Int, width: Int): Unit = { - val sqlBenchmark = new Benchmark(s"SQL Single Column Scan from $width columns", values) + val benchmark = new Benchmark(s"Single Column Scan from $width columns", values) withTempPath { dir => withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { @@ -376,43 +433,52 @@ object OrcReadBenchmark { prepareTable(dir, spark.sql("SELECT * FROM t1")) - sqlBenchmark.addCase("Native ORC MR") { _ => + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() } } - sqlBenchmark.addCase("Native ORC Vectorized") { _ => + benchmark.addCase("Native ORC Vectorized") { _ => spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() } - sqlBenchmark.addCase("Hive built-in ORC") { _ => + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Hive built-in ORC") { _ => spark.sql(s"SELECT sum(c$middle) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - SQL Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1103 / 1124 1.0 1052.0 1.0X - Native ORC Vectorized 92 / 100 11.4 87.9 12.0X - Hive built-in ORC 383 / 390 2.7 365.4 2.9X + Native ORC MR 1050 / 1053 1.0 1001.1 1.0X + Native ORC Vectorized 95 / 101 11.0 90.9 11.0X + Native ORC Vectorized with copy 95 / 102 11.0 90.9 11.0X + Hive built-in ORC 348 / 358 3.0 331.8 3.0X - SQL Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2245 / 2250 0.5 2141.0 1.0X - Native ORC Vectorized 157 / 165 6.7 150.2 14.3X - Hive built-in ORC 587 / 593 1.8 559.4 3.8X + Native ORC MR 2099 / 2108 0.5 2002.1 1.0X + Native ORC Vectorized 179 / 187 5.8 171.1 11.7X + Native ORC Vectorized with copy 176 / 188 6.0 167.6 11.9X + Hive built-in ORC 562 / 581 1.9 535.9 3.7X - SQL Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 3343 / 3350 0.3 3188.3 1.0X - Native ORC Vectorized 265 / 280 3.9 253.2 12.6X - Hive built-in ORC 828 / 842 1.3 789.8 4.0X + Native ORC MR 3221 / 3246 0.3 3071.4 1.0X + Native ORC Vectorized 312 / 322 3.4 298.0 10.3X + Native ORC Vectorized with copy 306 / 320 3.4 291.6 10.5X + Hive built-in ORC 815 / 824 1.3 777.3 4.0X */ - sqlBenchmark.run() + benchmark.run() } } } From 70bcc9d5ae33d6669bb5c97db29087ccead770fb Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 9 Jan 2018 23:32:47 -0800 Subject: [PATCH 058/774] [SPARK-22993][ML] Clarify HasCheckpointInterval param doc ## What changes were proposed in this pull request? Add a note to the `HasCheckpointInterval` parameter doc that clarifies that this setting is ignored when no checkpoint directory has been set on the spark context. ## How was this patch tested? No tests necessary, just a doc update. Author: sethah Closes #20188 from sethah/als_checkpoint_doc. --- R/pkg/R/mllib_recommendation.R | 2 ++ R/pkg/R/mllib_tree.R | 6 ++++++ .../apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 4 +++- .../org/apache/spark/ml/param/shared/sharedParams.scala | 4 ++-- python/pyspark/ml/param/_shared_params_code_gen.py | 5 +++-- python/pyspark/ml/param/shared.py | 4 ++-- 6 files changed, 18 insertions(+), 7 deletions(-) diff --git a/R/pkg/R/mllib_recommendation.R b/R/pkg/R/mllib_recommendation.R index fa794249085d7..5441c4a4022a9 100644 --- a/R/pkg/R/mllib_recommendation.R +++ b/R/pkg/R/mllib_recommendation.R @@ -48,6 +48,8 @@ setClass("ALSModel", representation(jobj = "jobj")) #' @param numUserBlocks number of user blocks used to parallelize computation (> 0). #' @param numItemBlocks number of item blocks used to parallelize computation (> 0). #' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.als} returns a fitted ALS model. #' @rdname spark.als diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 89a58bf0aadae..4e5ddf22ee16d 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -161,6 +161,8 @@ print.summary.decisionTree <- function(x) { #' >= 1. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching @@ -382,6 +384,8 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' @param minInstancesPerNode Minimum number of instances each child must have after split. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching @@ -595,6 +599,8 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' @param minInstancesPerNode Minimum number of instances each child must have after split. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a5d57a15317e6..6ad44af9ef7eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -63,7 +63,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Array[String]]("outputCols", "output column names"), ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " + "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " + - "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), + "every 10 iterations. Note: this setting will be ignored if the checkpoint directory " + + "is not set in the SparkContext", + isValid = "(interval: Int) => interval == -1 || interval >= 1"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + "will filter out rows with bad values), or error (which will throw an error). More " + diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 13425dacc9f18..be8b2f273164b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -282,10 +282,10 @@ trait HasOutputCols extends Params { trait HasCheckpointInterval extends Params { /** - * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. + * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext. * @group param */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations", (interval: Int) => interval == -1 || interval >= 1) + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext", (interval: Int) => interval == -1 || interval >= 1) /** @group getParam */ final def getCheckpointInterval: Int = $(checkpointInterval) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index d55d209d09398..1d0f60acc6983 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -121,8 +121,9 @@ def get$Name(self): ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"), ("numFeatures", "number of features.", None, "TypeConverters.toInt"), ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + - "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, - "TypeConverters.toInt"), + "E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: " + + "this setting will be ignored if the checkpoint directory is not set in the SparkContext.", + None, "TypeConverters.toInt"), ("seed", "random seed.", "hash(type(self).__name__)", "TypeConverters.toInt"), ("tol", "the convergence tolerance for iterative algorithms (>= 0).", None, "TypeConverters.toFloat"), diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index e5c5ddfba6c1f..813f7a59f3fd1 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -281,10 +281,10 @@ def getNumFeatures(self): class HasCheckpointInterval(Params): """ - Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. + Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext. """ - checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", typeConverter=TypeConverters.toInt) + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.", typeConverter=TypeConverters.toInt) def __init__(self): super(HasCheckpointInterval, self).__init__() From f340b6b3066033d40b7e163fd5fb68e9820adfb1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 Jan 2018 00:45:47 -0800 Subject: [PATCH 059/774] [SPARK-22997] Add additional defenses against use of freed MemoryBlocks ## What changes were proposed in this pull request? This patch modifies Spark's `MemoryAllocator` implementations so that `free(MemoryBlock)` mutates the passed block to clear pointers (in the off-heap case) or null out references to backing `long[]` arrays (in the on-heap case). The goal of this change is to add an extra layer of defense against use-after-free bugs because currently it's hard to detect corruption caused by blind writes to freed memory blocks. ## How was this patch tested? New unit tests in `PlatformSuite`, including new tests for existing functionality because we did not have sufficient mutation coverage of the on-heap memory allocator's pooling logic. Author: Josh Rosen Closes #20191 from JoshRosen/SPARK-22997-add-defenses-against-use-after-free-bugs-in-memory-allocator. --- .../unsafe/memory/HeapMemoryAllocator.java | 35 +++++++++---- .../spark/unsafe/memory/MemoryBlock.java | 21 +++++++- .../unsafe/memory/UnsafeMemoryAllocator.java | 11 ++++ .../spark/unsafe/PlatformUtilSuite.java | 50 ++++++++++++++++++- .../spark/memory/TaskMemoryManager.java | 13 ++++- .../spark/memory/TaskMemoryManagerSuite.java | 29 +++++++++++ 6 files changed, 146 insertions(+), 13 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index cc9cc429643ad..3acfe3696cb1e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -31,8 +31,7 @@ public class HeapMemoryAllocator implements MemoryAllocator { @GuardedBy("this") - private final Map>> bufferPoolsBySize = - new HashMap<>(); + private final Map>> bufferPoolsBySize = new HashMap<>(); private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; @@ -49,13 +48,14 @@ private boolean shouldPool(long size) { public MemoryBlock allocate(long size) throws OutOfMemoryError { if (shouldPool(size)) { synchronized (this) { - final LinkedList> pool = bufferPoolsBySize.get(size); + final LinkedList> pool = bufferPoolsBySize.get(size); if (pool != null) { while (!pool.isEmpty()) { - final WeakReference blockReference = pool.pop(); - final MemoryBlock memory = blockReference.get(); - if (memory != null) { - assert (memory.size() == size); + final WeakReference arrayReference = pool.pop(); + final long[] array = arrayReference.get(); + if (array != null) { + assert (array.length * 8L >= size); + MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -76,18 +76,35 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { + assert (memory.obj != null) : + "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?"; + assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "page has already been freed"; + assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) + || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator free()"; + final long size = memory.size(); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } + + // Mark the page as freed (so we can detect double-frees). + memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; + + // As an additional layer of defense against use-after-free bugs, we mutate the + // MemoryBlock to null out its reference to the long[] array. + long[] array = (long[]) memory.obj; + memory.setObjAndOffset(null, 0); + if (shouldPool(size)) { synchronized (this) { - LinkedList> pool = bufferPoolsBySize.get(size); + LinkedList> pool = bufferPoolsBySize.get(size); if (pool == null) { pool = new LinkedList<>(); bufferPoolsBySize.put(size, pool); } - pool.add(new WeakReference<>(memory)); + pool.add(new WeakReference<>(array)); } } else { // Do nothing diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index cd1d378bc1470..c333857358d30 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -26,6 +26,25 @@ */ public class MemoryBlock extends MemoryLocation { + /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */ + public static final int NO_PAGE_NUMBER = -1; + + /** + * Special `pageNumber` value for marking pages that have been freed in the TaskMemoryManager. + * We set `pageNumber` to this value in TaskMemoryManager.freePage() so that MemoryAllocator + * can detect if pages which were allocated by TaskMemoryManager have been freed in the TMM + * before being passed to MemoryAllocator.free() (it is an error to allocate a page in + * TaskMemoryManager and then directly free it in a MemoryAllocator without going through + * the TMM freePage() call). + */ + public static final int FREED_IN_TMM_PAGE_NUMBER = -2; + + /** + * Special `pageNumber` value for pages that have been freed by the MemoryAllocator. This allows + * us to detect double-frees. + */ + public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3; + private final long length; /** @@ -33,7 +52,7 @@ public class MemoryBlock extends MemoryLocation { * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, * which lives in a different package. */ - public int pageNumber = -1; + public int pageNumber = NO_PAGE_NUMBER; public MemoryBlock(@Nullable Object obj, long offset, long length) { super(obj, offset); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 55bcdf1ed7b06..4368fb615ba1e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -38,9 +38,20 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { public void free(MemoryBlock memory) { assert (memory.obj == null) : "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; + assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "page has already been freed"; + assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) + || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()"; + if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } Platform.freeMemory(memory.offset); + // As an additional layer of defense against use-after-free bugs, we mutate the + // MemoryBlock to reset its pointer. + memory.offset = 0; + // Mark the page as freed (so we can detect double-frees). + memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 4b141339ec816..62854837b05ed 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -62,6 +62,52 @@ public void overlappingCopyMemory() { } } + @Test + public void onHeapMemoryAllocatorPoolingReUsesLongArrays() { + MemoryBlock block1 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object baseObject1 = block1.getBaseObject(); + MemoryAllocator.HEAP.free(block1); + MemoryBlock block2 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object baseObject2 = block2.getBaseObject(); + Assert.assertSame(baseObject1, baseObject2); + MemoryAllocator.HEAP.free(block2); + } + + @Test + public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() { + MemoryBlock block = MemoryAllocator.HEAP.allocate(1024); + Assert.assertNotNull(block.getBaseObject()); + MemoryAllocator.HEAP.free(block); + Assert.assertNull(block.getBaseObject()); + Assert.assertEquals(0, block.getBaseOffset()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + } + + @Test + public void freeingOffHeapMemoryBlockResetsOffset() { + MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024); + Assert.assertNull(block.getBaseObject()); + Assert.assertNotEquals(0, block.getBaseOffset()); + MemoryAllocator.UNSAFE.free(block); + Assert.assertNull(block.getBaseObject()); + Assert.assertEquals(0, block.getBaseOffset()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + } + + @Test(expected = AssertionError.class) + public void onHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() { + MemoryBlock block = MemoryAllocator.HEAP.allocate(1024); + MemoryAllocator.HEAP.free(block); + MemoryAllocator.HEAP.free(block); + } + + @Test(expected = AssertionError.class) + public void offHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() { + MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024); + MemoryAllocator.UNSAFE.free(block); + MemoryAllocator.UNSAFE.free(block); + } + @Test public void memoryDebugFillEnabledInTest() { Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED); @@ -71,9 +117,11 @@ public void memoryDebugFillEnabledInTest() { MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object onheap1BaseObject = onheap1.getBaseObject(); + long onheap1BaseOffset = onheap1.getBaseOffset(); MemoryAllocator.HEAP.free(onheap1); Assert.assertEquals( - Platform.getByte(onheap1.getBaseObject(), onheap1.getBaseOffset()), + Platform.getByte(onheap1BaseObject, onheap1BaseOffset), MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024); Assert.assertEquals( diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index e8d3730daa7a4..632d718062212 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -321,8 +321,12 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ public void freePage(MemoryBlock page, MemoryConsumer consumer) { - assert (page.pageNumber != -1) : + assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) : "Called freePage() on memory that wasn't allocated with allocatePage()"; + assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "Called freePage() on a memory block that has already been freed"; + assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : + "Called freePage() on a memory block that has already been freed"; assert(allocatedPages.get(page.pageNumber)); pageTable[page.pageNumber] = null; synchronized (this) { @@ -332,6 +336,10 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) { logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } long pageSize = page.size(); + // Clear the page number before passing the block to the MemoryAllocator's free(). + // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed + // page has been inappropriately directly freed without calling TMM.freePage(). + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; memoryManager.tungstenMemoryAllocator().free(page); releaseExecutionMemory(pageSize, consumer); } @@ -358,7 +366,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { @VisibleForTesting public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { - assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + assert (pageNumber >= 0) : "encodePageNumberAndOffset called with invalid page"; return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); } @@ -424,6 +432,7 @@ public long cleanUpAllAllocatedMemory() { for (MemoryBlock page : pageTable) { if (page != null) { logger.debug("unreleased page: " + page + " in task " + taskAttemptId); + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; memoryManager.tungstenMemoryAllocator().free(page); } } diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index 46b0516e36141..a0664b30d6cc2 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -21,6 +21,7 @@ import org.junit.Test; import org.apache.spark.SparkConf; +import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; public class TaskMemoryManagerSuite { @@ -68,6 +69,34 @@ public void encodePageNumberAndOffsetOnHeap() { Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); } + @Test + public void freeingPageSetsPageNumberToSpecialConstant() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = manager.allocatePage(256, c); + c.freePage(dataPage); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber); + } + + @Test(expected = AssertionError.class) + public void freeingPageDirectlyInAllocatorTriggersAssertionError() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = manager.allocatePage(256, c); + MemoryAllocator.HEAP.free(dataPage); + } + + @Test(expected = AssertionError.class) + public void callingFreePageOnDirectlyAllocatedPageTriggersAssertionError() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = MemoryAllocator.HEAP.allocate(256); + manager.freePage(dataPage, c); + } + @Test public void cooperativeSpilling() { final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); From 344e3aab87178e45957333479a07e07f202ca1fd Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Wed, 10 Jan 2018 09:44:30 -0800 Subject: [PATCH 060/774] [SPARK-23019][CORE] Wait until SparkContext.stop() finished in SparkLauncherSuite ## What changes were proposed in this pull request? In current code ,the function `waitFor` call https://github.com/apache/spark/blob/cfcd746689c2b84824745fa6d327ffb584c7a17d/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java#L155 only wait until DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. https://github.com/apache/spark/blob/1c9f95cb771ac78775a77edd1abfeb2d8ae2a124/core/src/main/scala/org/apache/spark/SparkContext.scala#L1924 Thus, in the Jenkins test https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.3-test-maven-hadoop-2.6/ , `JdbcRDDSuite` failed because the previous test `SparkLauncherSuite` exit before SparkContext.stop() is finished. To repo: ``` $ build/sbt > project core > testOnly *SparkLauncherSuite *JavaJdbcRDDSuite ``` To Fix: Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM in SparkLauncherSuite. Can' come up with any better solution for now. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20221 from gengliangwang/SPARK-23019. --- .../java/org/apache/spark/launcher/SparkLauncherSuite.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index c2261c204cd45..9d2f563b2e367 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.Properties; +import java.util.concurrent.TimeUnit; import org.junit.Test; import static org.junit.Assert.*; @@ -133,6 +134,10 @@ public void testInProcessLauncher() throws Exception { p.put(e.getKey(), e.getValue()); } System.setProperties(p); + // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. + // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. + // See SPARK-23019 and SparkContext.stop() for details. + TimeUnit.MILLISECONDS.sleep(500); } } From 9b33dfc408de986f4203bb0ac0c3f5c56effd69d Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Wed, 10 Jan 2018 14:25:04 -0800 Subject: [PATCH 061/774] [SPARK-22951][SQL] fix aggregation after dropDuplicates on empty data frames ## What changes were proposed in this pull request? (courtesy of liancheng) Spark SQL supports both global aggregation and grouping aggregation. Global aggregation always return a single row with the initial aggregation state as the output, even there are zero input rows. Spark implements this by simply checking the number of grouping keys and treats an aggregation as a global aggregation if it has zero grouping keys. However, this simple principle drops the ball in the following case: ```scala spark.emptyDataFrame.dropDuplicates().agg(count($"*") as "c").show() // +---+ // | c | // +---+ // | 1 | // +---+ ``` The reason is that: 1. `df.dropDuplicates()` is roughly translated into something equivalent to: ```scala val allColumns = df.columns.map { col } df.groupBy(allColumns: _*).agg(allColumns.head, allColumns.tail: _*) ``` This translation is implemented in the rule `ReplaceDeduplicateWithAggregate`. 2. `spark.emptyDataFrame` contains zero columns and zero rows. Therefore, rule `ReplaceDeduplicateWithAggregate` makes a confusing transformation roughly equivalent to the following one: ```scala spark.emptyDataFrame.dropDuplicates() => spark.emptyDataFrame.groupBy().agg(Map.empty[String, String]) ``` The above transformation is confusing because the resulting aggregate operator contains no grouping keys (because `emptyDataFrame` contains no columns), and gets recognized as a global aggregation. As a result, Spark SQL allocates a single row filled by the initial aggregation state and uses it as the output, and returns a wrong result. To fix this issue, this PR tweaks `ReplaceDeduplicateWithAggregate` by appending a literal `1` to the grouping key list of the resulting `Aggregate` operator when the input plan contains zero output columns. In this way, `spark.emptyDataFrame.dropDuplicates()` is now translated into a grouping aggregation, roughly depicted as: ```scala spark.emptyDataFrame.dropDuplicates() => spark.emptyDataFrame.groupBy(lit(1)).agg(Map.empty[String, String]) ``` Which is now properly treated as a grouping aggregation and returns the correct answer. ## How was this patch tested? New unit tests added Author: Feng Liu Closes #20174 from liufengdb/fix-duplicate. --- .../sql/catalyst/optimizer/Optimizer.scala | 8 ++++++- .../optimizer/ReplaceOperatorSuite.scala | 10 +++++++- .../spark/sql/DataFrameAggregateSuite.scala | 24 +++++++++++++++++-- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index df0af8264a329..c794ba8619322 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1222,7 +1222,13 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId) } } - Aggregate(keys, aggCols, child) + // SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping + // aggregations by checking the number of grouping keys. The key difference here is that a + // global aggregation always returns at least one row even if there are no input rows. Here + // we append a literal when the grouping key list is empty so that the result aggregate + // operator is properly treated as a grouping aggregation. + val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys + Aggregate(nonemptyKeys, aggCols, child) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 0fa1aaeb9e164..e9701ffd2c54b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Not} +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not} import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ @@ -198,6 +198,14 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("add one grouping key if necessary when replace Deduplicate with Aggregate") { + val input = LocalRelation() + val query = Deduplicate(Seq.empty, input) // dropDuplicates() + val optimized = Optimize.execute(query.analyze) + val correctAnswer = Aggregate(Seq(Literal(1)), input.output, input) + comparePlans(optimized, correctAnswer) + } + test("don't replace streaming Deduplicate") { val input = LocalRelation(Seq('a.int, 'b.int), isStreaming = true) val attrA = input.output(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 06848e4d2b297..e7776e36702ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import scala.util.Random +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -27,7 +29,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.DecimalData -import org.apache.spark.sql.types.{Decimal, DecimalType} +import org.apache.spark.sql.types.DecimalType case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) @@ -456,7 +458,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("null moments") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - checkAnswer( emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), Row(null, null, null, null, null)) @@ -666,4 +667,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { assert(exchangePlans.length == 1) } } + + Seq(true, false).foreach { codegen => + test("SPARK-22951: dropDuplicates on empty dataFrames should produce correct aggregate " + + s"results when codegen is enabled: $codegen") { + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegen.toString)) { + // explicit global aggregations + val emptyAgg = Map.empty[String, String] + checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), Seq(Row(0))) + checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), Seq(Row(0))) + + // global aggregation is converted to grouping aggregation: + assert(spark.emptyDataFrame.dropDuplicates().count() == 0) + } + } + } } From a6647ffbf7a312a3e119a9beef90880cc915aa60 Mon Sep 17 00:00:00 2001 From: Mingjie Tang Date: Thu, 11 Jan 2018 11:51:03 +0800 Subject: [PATCH 062/774] [SPARK-22587] Spark job fails if fs.defaultFS and application jar are different url ## What changes were proposed in this pull request? Two filesystems comparing does not consider the authority of URI. This is specific for WASB file storage system, where userInfo is honored to differentiate filesystems. For example: wasbs://user1xyz.net, wasbs://user2xyz.net would consider as two filesystem. Therefore, we have to add the authority to compare two filesystem, and two filesystem with different authority can not be the same FS. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Mingjie Tang Closes #19885 from merlintang/EAR-7377. --- .../org/apache/spark/deploy/yarn/Client.scala | 24 +++++++++++--- .../spark/deploy/yarn/ClientSuite.scala | 33 +++++++++++++++++++ 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 15328d08b3b5c..8cd3cd9746a3a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1421,15 +1421,20 @@ private object Client extends Logging { } /** - * Return whether the two file systems are the same. + * Return whether two URI represent file system are the same */ - private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { - val srcUri = srcFs.getUri() - val dstUri = destFs.getUri() + private[spark] def compareUri(srcUri: URI, dstUri: URI): Boolean = { + if (srcUri.getScheme() == null || srcUri.getScheme() != dstUri.getScheme()) { return false } + val srcAuthority = srcUri.getAuthority() + val dstAuthority = dstUri.getAuthority() + if (srcAuthority != null && !srcAuthority.equalsIgnoreCase(dstAuthority)) { + return false + } + var srcHost = srcUri.getHost() var dstHost = dstUri.getHost() @@ -1447,6 +1452,17 @@ private object Client extends Logging { } Objects.equal(srcHost, dstHost) && srcUri.getPort() == dstUri.getPort() + + } + + /** + * Return whether the two file systems are the same. + */ + protected def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { + val srcUri = srcFs.getUri() + val dstUri = destFs.getUri() + + compareUri(srcUri, dstUri) } /** diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 9d5f5eb621118..7fa597167f3f0 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -357,6 +357,39 @@ class ClientSuite extends SparkFunSuite with Matchers { sparkConf.get(SECONDARY_JARS) should be (Some(Seq(new File(jar2.toURI).getName))) } + private val matching = Seq( + ("files URI match test1", "file:///file1", "file:///file2"), + ("files URI match test2", "file:///c:file1", "file://c:file2"), + ("files URI match test3", "file://host/file1", "file://host/file2"), + ("wasb URI match test", "wasb://bucket1@user", "wasb://bucket1@user/"), + ("hdfs URI match test", "hdfs:/path1", "hdfs:/path1") + ) + + matching.foreach { t => + test(t._1) { + assert(Client.compareUri(new URI(t._2), new URI(t._3)), + s"No match between ${t._2} and ${t._3}") + } + } + + private val unmatching = Seq( + ("files URI unmatch test1", "file:///file1", "file://host/file2"), + ("files URI unmatch test2", "file://host/file1", "file:///file2"), + ("files URI unmatch test3", "file://host/file1", "file://host2/file2"), + ("wasb URI unmatch test1", "wasb://bucket1@user", "wasb://bucket2@user/"), + ("wasb URI unmatch test2", "wasb://bucket1@user", "wasb://bucket1@user2/"), + ("s3 URI unmatch test", "s3a://user@pass:bucket1/", "s3a://user2@pass2:bucket1/"), + ("hdfs URI unmatch test1", "hdfs://namenode1/path1", "hdfs://namenode1:8080/path2"), + ("hdfs URI unmatch test2", "hdfs://namenode1:8020/path1", "hdfs://namenode1:8080/path2") + ) + + unmatching.foreach { t => + test(t._1) { + assert(!Client.compareUri(new URI(t._2), new URI(t._3)), + s"match between ${t._2} and ${t._3}") + } + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = From 87c98de8b23f0e978958fc83677fdc4c339b7e6a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Jan 2018 18:17:34 +0800 Subject: [PATCH 063/774] [SPARK-23001][SQL] Fix NullPointerException when DESC a database with NULL description ## What changes were proposed in this pull request? When users' DB description is NULL, users might hit `NullPointerException`. This PR is to fix the issue. ## How was this patch tested? Added test cases Author: gatorsmile Closes #20215 from gatorsmile/SPARK-23001. --- .../apache/spark/sql/hive/client/HiveClientImpl.scala | 2 +- .../apache/spark/sql/hive/HiveExternalCatalogSuite.scala | 6 ++++++ .../org/apache/spark/sql/hive/client/VersionsSuite.scala | 9 +++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 102f40bacc985..4b923f5235a90 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -330,7 +330,7 @@ private[hive] class HiveClientImpl( Option(client.getDatabase(dbName)).map { d => CatalogDatabase( name = d.getName, - description = d.getDescription, + description = Option(d.getDescription).getOrElse(""), locationUri = CatalogUtils.stringToURI(d.getLocationUri), properties = Option(d.getParameters).map(_.asScala.toMap).orNull) }.getOrElse(throw new NoSuchDatabaseException(dbName)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 2e35fdeba464d..0a522b6a11c80 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -107,4 +107,10 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { .filter(_.contains("Num Buckets")).head assert(bucketString.contains("10")) } + + test("SPARK-23001: NullPointerException when running desc database") { + val catalog = newBasicCatalog() + catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false) + assert(catalog.getDatabase("dbWithNullDesc").description == "") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 94473a08dd317..ff90e9dda5f7c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -163,6 +163,15 @@ class VersionsSuite extends SparkFunSuite with Logging { client.createDatabase(tempDB, ignoreIfExists = true) } + test(s"$version: createDatabase with null description") { + withTempDir { tmpDir => + val dbWithNullDesc = + CatalogDatabase("dbWithNullDesc", description = null, tmpDir.toURI, Map()) + client.createDatabase(dbWithNullDesc, ignoreIfExists = true) + assert(client.getDatabase("dbWithNullDesc").description == "") + } + } + test(s"$version: setCurrentDatabase") { client.setCurrentDatabase("default") } From 1c70da3bfbb4016e394de2c73eb0db7cdd9a6968 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 11 Jan 2018 19:41:48 +0800 Subject: [PATCH 064/774] [SPARK-20657][CORE] Speed up rendering of the stages page. There are two main changes to speed up rendering of the tasks list when rendering the stage page. The first one makes the code only load the tasks being shown in the current page of the tasks table, and information related to only those tasks. One side-effect of this change is that the graph that shows task-related events now only shows events for the tasks in the current page, instead of the previously hardcoded limit of "events for the first 1000 tasks". That ends up helping with readability, though. To make sorting efficient when using a disk store, the task wrapper was extended to include many new indices, one for each of the sortable columns in the UI, and metrics for which quantiles are calculated. The second changes the way metric quantiles are calculated for stages. Instead of using the "Distribution" class to process data for all task metrics, which requires scanning all tasks of a stage, the code now uses the KVStore "skip()" functionality to only read tasks that contain interesting information for the quantiles that are desired. This is still not cheap; because there are many metrics that the UI and API track, the code needs to scan the index for each metric to gather the information. Savings come mainly from skipping deserialization when using the disk store, but the in-memory code also seems to be faster than before (most probably because of other changes in this patch). To make subsequent calls faster, some quantiles are cached in the status store. This makes UIs much faster after the first time a stage has been loaded. With the above changes, a lot of code in the UI layer could be simplified. Author: Marcelo Vanzin Closes #20013 from vanzin/SPARK-20657. --- .../apache/spark/util/kvstore/LevelDB.java | 1 + .../spark/status/AppStatusListener.scala | 57 +- .../apache/spark/status/AppStatusStore.scala | 389 +++++--- .../apache/spark/status/AppStatusUtils.scala | 68 ++ .../org/apache/spark/status/LiveEntity.scala | 344 ++++--- .../spark/status/api/v1/StagesResource.scala | 3 +- .../org/apache/spark/status/api/v1/api.scala | 3 + .../org/apache/spark/status/storeTypes.scala | 327 ++++++- .../apache/spark/ui/jobs/ExecutorTable.scala | 4 +- .../org/apache/spark/ui/jobs/JobPage.scala | 2 +- .../org/apache/spark/ui/jobs/StagePage.scala | 919 ++++++------------ ...mmary_w__custom_quantiles_expectation.json | 3 + ...sk_summary_w_shuffle_read_expectation.json | 3 + ...k_summary_w_shuffle_write_expectation.json | 3 + .../spark/status/AppStatusListenerSuite.scala | 105 +- .../spark/status/AppStatusStoreSuite.scala | 104 ++ .../org/apache/spark/ui/StagePageSuite.scala | 10 +- scalastyle-config.xml | 2 +- 18 files changed, 1361 insertions(+), 986 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala create mode 100644 core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index 4f9e10ca20066..0e491efac9181 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -83,6 +83,7 @@ public LevelDB(File path, KVStoreSerializer serializer) throws Exception { if (versionData != null) { long version = serializer.deserializeLong(versionData); if (version != STORE_VERSION) { + close(); throw new UnsupportedStoreVersionException(); } } else { diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 88b75ddd5993a..b4edcf23abc09 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -377,6 +377,10 @@ private[spark] class AppStatusListener( Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => stage.activeTasks += 1 stage.firstLaunchTime = math.min(stage.firstLaunchTime, event.taskInfo.launchTime) + + val locality = event.taskInfo.taskLocality.toString() + val count = stage.localitySummary.getOrElse(locality, 0L) + 1L + stage.localitySummary = stage.localitySummary ++ Map(locality -> count) maybeUpdate(stage, now) stage.jobs.foreach { job => @@ -433,7 +437,7 @@ private[spark] class AppStatusListener( } task.errorMessage = errorMessage val delta = task.updateMetrics(event.taskMetrics) - update(task, now) + update(task, now, last = true) delta }.orNull @@ -450,7 +454,7 @@ private[spark] class AppStatusListener( Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => if (metricsDelta != null) { - stage.metrics.update(metricsDelta) + stage.metrics = LiveEntityHelpers.addMetrics(stage.metrics, metricsDelta) } stage.activeTasks -= 1 stage.completedTasks += completedDelta @@ -486,7 +490,7 @@ private[spark] class AppStatusListener( esummary.failedTasks += failedDelta esummary.killedTasks += killedDelta if (metricsDelta != null) { - esummary.metrics.update(metricsDelta) + esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, metricsDelta) } maybeUpdate(esummary, now) @@ -604,11 +608,11 @@ private[spark] class AppStatusListener( maybeUpdate(task, now) Option(liveStages.get((sid, sAttempt))).foreach { stage => - stage.metrics.update(delta) + stage.metrics = LiveEntityHelpers.addMetrics(stage.metrics, delta) maybeUpdate(stage, now) val esummary = stage.executorSummary(event.execId) - esummary.metrics.update(delta) + esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, delta) maybeUpdate(esummary, now) } } @@ -690,7 +694,7 @@ private[spark] class AppStatusListener( // can update the executor information too. liveRDDs.get(block.rddId).foreach { rdd => if (updatedStorageLevel.isDefined) { - rdd.storageLevel = updatedStorageLevel.get + rdd.setStorageLevel(updatedStorageLevel.get) } val partition = rdd.partition(block.name) @@ -814,7 +818,7 @@ private[spark] class AppStatusListener( /** Update a live entity only if it hasn't been updated in the last configured period. */ private def maybeUpdate(entity: LiveEntity, now: Long): Unit = { - if (liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime > liveUpdatePeriodNs) { + if (live && liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime > liveUpdatePeriodNs) { update(entity, now) } } @@ -865,7 +869,7 @@ private[spark] class AppStatusListener( } stages.foreach { s => - val key = s.id + val key = Array(s.info.stageId, s.info.attemptId) kvstore.delete(s.getClass(), key) val execSummaries = kvstore.view(classOf[ExecutorStageSummaryWrapper]) @@ -885,15 +889,15 @@ private[spark] class AppStatusListener( .asScala tasks.foreach { t => - kvstore.delete(t.getClass(), t.info.taskId) + kvstore.delete(t.getClass(), t.taskId) } // Check whether there are remaining attempts for the same stage. If there aren't, then // also delete the RDD graph data. val remainingAttempts = kvstore.view(classOf[StageDataWrapper]) .index("stageId") - .first(s.stageId) - .last(s.stageId) + .first(s.info.stageId) + .last(s.info.stageId) .closeableIterator() val hasMoreAttempts = try { @@ -905,8 +909,10 @@ private[spark] class AppStatusListener( } if (!hasMoreAttempts) { - kvstore.delete(classOf[RDDOperationGraphWrapper], s.stageId) + kvstore.delete(classOf[RDDOperationGraphWrapper], s.info.stageId) } + + cleanupCachedQuantiles(key) } } @@ -919,9 +925,9 @@ private[spark] class AppStatusListener( // Try to delete finished tasks only. val toDelete = KVUtils.viewToSeq(view, countToDelete) { t => - !live || t.info.status != TaskState.RUNNING.toString() + !live || t.status != TaskState.RUNNING.toString() } - toDelete.foreach { t => kvstore.delete(t.getClass(), t.info.taskId) } + toDelete.foreach { t => kvstore.delete(t.getClass(), t.taskId) } stage.savedTasks.addAndGet(-toDelete.size) // If there are more running tasks than the configured limit, delete running tasks. This @@ -930,13 +936,34 @@ private[spark] class AppStatusListener( val remaining = countToDelete - toDelete.size if (remaining > 0) { val runningTasksToDelete = view.max(remaining).iterator().asScala.toList - runningTasksToDelete.foreach { t => kvstore.delete(t.getClass(), t.info.taskId) } + runningTasksToDelete.foreach { t => kvstore.delete(t.getClass(), t.taskId) } stage.savedTasks.addAndGet(-remaining) } + + // On live applications, cleanup any cached quantiles for the stage. This makes sure that + // quantiles will be recalculated after tasks are replaced with newer ones. + // + // This is not needed in the SHS since caching only happens after the event logs are + // completely processed. + if (live) { + cleanupCachedQuantiles(stageKey) + } } stage.cleaning = false } + private def cleanupCachedQuantiles(stageKey: Array[Int]): Unit = { + val cachedQuantiles = kvstore.view(classOf[CachedQuantile]) + .index("stage") + .first(stageKey) + .last(stageKey) + .asScala + .toList + cachedQuantiles.foreach { q => + kvstore.delete(q.getClass(), q.id) + } + } + /** * Remove at least (retainedSize / 10) items to reduce friction. Because tracking may be done * asynchronously, this method may return 0 in case enough items have been deleted already. diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 5a942f5284018..efc28538a33db 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.status.api.v1 import org.apache.spark.ui.scope._ -import org.apache.spark.util.Distribution +import org.apache.spark.util.{Distribution, Utils} import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} /** @@ -98,7 +98,11 @@ private[spark] class AppStatusStore( val it = store.view(classOf[StageDataWrapper]).index("stageId").reverse().first(stageId) .closeableIterator() try { - it.next().info + if (it.hasNext()) { + it.next().info + } else { + throw new NoSuchElementException(s"No stage with id $stageId") + } } finally { it.close() } @@ -110,107 +114,238 @@ private[spark] class AppStatusStore( if (details) stageWithDetails(stage) else stage } + def taskCount(stageId: Int, stageAttemptId: Int): Long = { + store.count(classOf[TaskDataWrapper], "stage", Array(stageId, stageAttemptId)) + } + + def localitySummary(stageId: Int, stageAttemptId: Int): Map[String, Long] = { + store.read(classOf[StageDataWrapper], Array(stageId, stageAttemptId)).locality + } + + /** + * Calculates a summary of the task metrics for the given stage attempt, returning the + * requested quantiles for the recorded metrics. + * + * This method can be expensive if the requested quantiles are not cached; the method + * will only cache certain quantiles (every 0.05 step), so it's recommended to stick to + * those to avoid expensive scans of all task data. + */ def taskSummary( stageId: Int, stageAttemptId: Int, - quantiles: Array[Double]): v1.TaskMetricDistributions = { - - val stage = Array(stageId, stageAttemptId) - - val rawMetrics = store.view(classOf[TaskDataWrapper]) - .index("stage") - .first(stage) - .last(stage) - .asScala - .flatMap(_.info.taskMetrics) - .toList - .view - - def metricQuantiles(f: v1.TaskMetrics => Double): IndexedSeq[Double] = - Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles) - - // We need to do a lot of similar munging to nested metrics here. For each one, - // we want (a) extract the values for nested metrics (b) make a distribution for each metric - // (c) shove the distribution into the right field in our return type and (d) only return - // a result if the option is defined for any of the tasks. MetricHelper is a little util - // to make it a little easier to deal w/ all of the nested options. Mostly it lets us just - // implement one "build" method, which just builds the quantiles for each field. - - val inputMetrics = - new MetricHelper[v1.InputMetrics, v1.InputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.InputMetrics = raw.inputMetrics - - def build: v1.InputMetricDistributions = new v1.InputMetricDistributions( - bytesRead = submetricQuantiles(_.bytesRead), - recordsRead = submetricQuantiles(_.recordsRead) - ) - }.build - - val outputMetrics = - new MetricHelper[v1.OutputMetrics, v1.OutputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.OutputMetrics = raw.outputMetrics - - def build: v1.OutputMetricDistributions = new v1.OutputMetricDistributions( - bytesWritten = submetricQuantiles(_.bytesWritten), - recordsWritten = submetricQuantiles(_.recordsWritten) - ) - }.build - - val shuffleReadMetrics = - new MetricHelper[v1.ShuffleReadMetrics, v1.ShuffleReadMetricDistributions](rawMetrics, - quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleReadMetrics = - raw.shuffleReadMetrics - - def build: v1.ShuffleReadMetricDistributions = new v1.ShuffleReadMetricDistributions( - readBytes = submetricQuantiles { s => s.localBytesRead + s.remoteBytesRead }, - readRecords = submetricQuantiles(_.recordsRead), - remoteBytesRead = submetricQuantiles(_.remoteBytesRead), - remoteBytesReadToDisk = submetricQuantiles(_.remoteBytesReadToDisk), - remoteBlocksFetched = submetricQuantiles(_.remoteBlocksFetched), - localBlocksFetched = submetricQuantiles(_.localBlocksFetched), - totalBlocksFetched = submetricQuantiles { s => - s.localBlocksFetched + s.remoteBlocksFetched - }, - fetchWaitTime = submetricQuantiles(_.fetchWaitTime) - ) - }.build - - val shuffleWriteMetrics = - new MetricHelper[v1.ShuffleWriteMetrics, v1.ShuffleWriteMetricDistributions](rawMetrics, - quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleWriteMetrics = - raw.shuffleWriteMetrics - - def build: v1.ShuffleWriteMetricDistributions = new v1.ShuffleWriteMetricDistributions( - writeBytes = submetricQuantiles(_.bytesWritten), - writeRecords = submetricQuantiles(_.recordsWritten), - writeTime = submetricQuantiles(_.writeTime) - ) - }.build - - new v1.TaskMetricDistributions( + unsortedQuantiles: Array[Double]): Option[v1.TaskMetricDistributions] = { + val stageKey = Array(stageId, stageAttemptId) + val quantiles = unsortedQuantiles.sorted + + // We don't know how many tasks remain in the store that actually have metrics. So scan one + // metric and count how many valid tasks there are. Use skip() instead of next() since it's + // cheaper for disk stores (avoids deserialization). + val count = { + Utils.tryWithResource( + store.view(classOf[TaskDataWrapper]) + .parent(stageKey) + .index(TaskIndexNames.EXEC_RUN_TIME) + .first(0L) + .closeableIterator() + ) { it => + var _count = 0L + while (it.hasNext()) { + _count += 1 + it.skip(1) + } + _count + } + } + + if (count <= 0) { + return None + } + + // Find out which quantiles are already cached. The data in the store must match the expected + // task count to be considered, otherwise it will be re-scanned and overwritten. + val cachedQuantiles = quantiles.filter(shouldCacheQuantile).flatMap { q => + val qkey = Array(stageId, stageAttemptId, quantileToString(q)) + asOption(store.read(classOf[CachedQuantile], qkey)).filter(_.taskCount == count) + } + + // If there are no missing quantiles, return the data. Otherwise, just compute everything + // to make the code simpler. + if (cachedQuantiles.size == quantiles.size) { + def toValues(fn: CachedQuantile => Double): IndexedSeq[Double] = cachedQuantiles.map(fn) + + val distributions = new v1.TaskMetricDistributions( + quantiles = quantiles, + executorDeserializeTime = toValues(_.executorDeserializeTime), + executorDeserializeCpuTime = toValues(_.executorDeserializeCpuTime), + executorRunTime = toValues(_.executorRunTime), + executorCpuTime = toValues(_.executorCpuTime), + resultSize = toValues(_.resultSize), + jvmGcTime = toValues(_.jvmGcTime), + resultSerializationTime = toValues(_.resultSerializationTime), + gettingResultTime = toValues(_.gettingResultTime), + schedulerDelay = toValues(_.schedulerDelay), + peakExecutionMemory = toValues(_.peakExecutionMemory), + memoryBytesSpilled = toValues(_.memoryBytesSpilled), + diskBytesSpilled = toValues(_.diskBytesSpilled), + inputMetrics = new v1.InputMetricDistributions( + toValues(_.bytesRead), + toValues(_.recordsRead)), + outputMetrics = new v1.OutputMetricDistributions( + toValues(_.bytesWritten), + toValues(_.recordsWritten)), + shuffleReadMetrics = new v1.ShuffleReadMetricDistributions( + toValues(_.shuffleReadBytes), + toValues(_.shuffleRecordsRead), + toValues(_.shuffleRemoteBlocksFetched), + toValues(_.shuffleLocalBlocksFetched), + toValues(_.shuffleFetchWaitTime), + toValues(_.shuffleRemoteBytesRead), + toValues(_.shuffleRemoteBytesReadToDisk), + toValues(_.shuffleTotalBlocksFetched)), + shuffleWriteMetrics = new v1.ShuffleWriteMetricDistributions( + toValues(_.shuffleWriteBytes), + toValues(_.shuffleWriteRecords), + toValues(_.shuffleWriteTime))) + + return Some(distributions) + } + + // Compute quantiles by scanning the tasks in the store. This is not really stable for live + // stages (e.g. the number of recorded tasks may change while this code is running), but should + // stabilize once the stage finishes. It's also slow, especially with disk stores. + val indices = quantiles.map { q => math.min((q * count).toLong, count - 1) } + + def scanTasks(index: String)(fn: TaskDataWrapper => Long): IndexedSeq[Double] = { + Utils.tryWithResource( + store.view(classOf[TaskDataWrapper]) + .parent(stageKey) + .index(index) + .first(0L) + .closeableIterator() + ) { it => + var last = Double.NaN + var currentIdx = -1L + indices.map { idx => + if (idx == currentIdx) { + last + } else { + val diff = idx - currentIdx + currentIdx = idx + if (it.skip(diff - 1)) { + last = fn(it.next()).toDouble + last + } else { + Double.NaN + } + } + }.toIndexedSeq + } + } + + val computedQuantiles = new v1.TaskMetricDistributions( quantiles = quantiles, - executorDeserializeTime = metricQuantiles(_.executorDeserializeTime), - executorDeserializeCpuTime = metricQuantiles(_.executorDeserializeCpuTime), - executorRunTime = metricQuantiles(_.executorRunTime), - executorCpuTime = metricQuantiles(_.executorCpuTime), - resultSize = metricQuantiles(_.resultSize), - jvmGcTime = metricQuantiles(_.jvmGcTime), - resultSerializationTime = metricQuantiles(_.resultSerializationTime), - memoryBytesSpilled = metricQuantiles(_.memoryBytesSpilled), - diskBytesSpilled = metricQuantiles(_.diskBytesSpilled), - inputMetrics = inputMetrics, - outputMetrics = outputMetrics, - shuffleReadMetrics = shuffleReadMetrics, - shuffleWriteMetrics = shuffleWriteMetrics - ) + executorDeserializeTime = scanTasks(TaskIndexNames.DESER_TIME) { t => + t.executorDeserializeTime + }, + executorDeserializeCpuTime = scanTasks(TaskIndexNames.DESER_CPU_TIME) { t => + t.executorDeserializeCpuTime + }, + executorRunTime = scanTasks(TaskIndexNames.EXEC_RUN_TIME) { t => t.executorRunTime }, + executorCpuTime = scanTasks(TaskIndexNames.EXEC_CPU_TIME) { t => t.executorCpuTime }, + resultSize = scanTasks(TaskIndexNames.RESULT_SIZE) { t => t.resultSize }, + jvmGcTime = scanTasks(TaskIndexNames.GC_TIME) { t => t.jvmGcTime }, + resultSerializationTime = scanTasks(TaskIndexNames.SER_TIME) { t => + t.resultSerializationTime + }, + gettingResultTime = scanTasks(TaskIndexNames.GETTING_RESULT_TIME) { t => + t.gettingResultTime + }, + schedulerDelay = scanTasks(TaskIndexNames.SCHEDULER_DELAY) { t => t.schedulerDelay }, + peakExecutionMemory = scanTasks(TaskIndexNames.PEAK_MEM) { t => t.peakExecutionMemory }, + memoryBytesSpilled = scanTasks(TaskIndexNames.MEM_SPILL) { t => t.memoryBytesSpilled }, + diskBytesSpilled = scanTasks(TaskIndexNames.DISK_SPILL) { t => t.diskBytesSpilled }, + inputMetrics = new v1.InputMetricDistributions( + scanTasks(TaskIndexNames.INPUT_SIZE) { t => t.inputBytesRead }, + scanTasks(TaskIndexNames.INPUT_RECORDS) { t => t.inputRecordsRead }), + outputMetrics = new v1.OutputMetricDistributions( + scanTasks(TaskIndexNames.OUTPUT_SIZE) { t => t.outputBytesWritten }, + scanTasks(TaskIndexNames.OUTPUT_RECORDS) { t => t.outputRecordsWritten }), + shuffleReadMetrics = new v1.ShuffleReadMetricDistributions( + scanTasks(TaskIndexNames.SHUFFLE_TOTAL_READS) { m => + m.shuffleLocalBytesRead + m.shuffleRemoteBytesRead + }, + scanTasks(TaskIndexNames.SHUFFLE_READ_RECORDS) { t => t.shuffleRecordsRead }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_BLOCKS) { t => t.shuffleRemoteBlocksFetched }, + scanTasks(TaskIndexNames.SHUFFLE_LOCAL_BLOCKS) { t => t.shuffleLocalBlocksFetched }, + scanTasks(TaskIndexNames.SHUFFLE_READ_TIME) { t => t.shuffleFetchWaitTime }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_READS) { t => t.shuffleRemoteBytesRead }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_READS_TO_DISK) { t => + t.shuffleRemoteBytesReadToDisk + }, + scanTasks(TaskIndexNames.SHUFFLE_TOTAL_BLOCKS) { m => + m.shuffleLocalBlocksFetched + m.shuffleRemoteBlocksFetched + }), + shuffleWriteMetrics = new v1.ShuffleWriteMetricDistributions( + scanTasks(TaskIndexNames.SHUFFLE_WRITE_SIZE) { t => t.shuffleBytesWritten }, + scanTasks(TaskIndexNames.SHUFFLE_WRITE_RECORDS) { t => t.shuffleRecordsWritten }, + scanTasks(TaskIndexNames.SHUFFLE_WRITE_TIME) { t => t.shuffleWriteTime })) + + // Go through the computed quantiles and cache the values that match the caching criteria. + computedQuantiles.quantiles.zipWithIndex + .filter { case (q, _) => quantiles.contains(q) && shouldCacheQuantile(q) } + .foreach { case (q, idx) => + val cached = new CachedQuantile(stageId, stageAttemptId, quantileToString(q), count, + executorDeserializeTime = computedQuantiles.executorDeserializeTime(idx), + executorDeserializeCpuTime = computedQuantiles.executorDeserializeCpuTime(idx), + executorRunTime = computedQuantiles.executorRunTime(idx), + executorCpuTime = computedQuantiles.executorCpuTime(idx), + resultSize = computedQuantiles.resultSize(idx), + jvmGcTime = computedQuantiles.jvmGcTime(idx), + resultSerializationTime = computedQuantiles.resultSerializationTime(idx), + gettingResultTime = computedQuantiles.gettingResultTime(idx), + schedulerDelay = computedQuantiles.schedulerDelay(idx), + peakExecutionMemory = computedQuantiles.peakExecutionMemory(idx), + memoryBytesSpilled = computedQuantiles.memoryBytesSpilled(idx), + diskBytesSpilled = computedQuantiles.diskBytesSpilled(idx), + + bytesRead = computedQuantiles.inputMetrics.bytesRead(idx), + recordsRead = computedQuantiles.inputMetrics.recordsRead(idx), + + bytesWritten = computedQuantiles.outputMetrics.bytesWritten(idx), + recordsWritten = computedQuantiles.outputMetrics.recordsWritten(idx), + + shuffleReadBytes = computedQuantiles.shuffleReadMetrics.readBytes(idx), + shuffleRecordsRead = computedQuantiles.shuffleReadMetrics.readRecords(idx), + shuffleRemoteBlocksFetched = + computedQuantiles.shuffleReadMetrics.remoteBlocksFetched(idx), + shuffleLocalBlocksFetched = computedQuantiles.shuffleReadMetrics.localBlocksFetched(idx), + shuffleFetchWaitTime = computedQuantiles.shuffleReadMetrics.fetchWaitTime(idx), + shuffleRemoteBytesRead = computedQuantiles.shuffleReadMetrics.remoteBytesRead(idx), + shuffleRemoteBytesReadToDisk = + computedQuantiles.shuffleReadMetrics.remoteBytesReadToDisk(idx), + shuffleTotalBlocksFetched = computedQuantiles.shuffleReadMetrics.totalBlocksFetched(idx), + + shuffleWriteBytes = computedQuantiles.shuffleWriteMetrics.writeBytes(idx), + shuffleWriteRecords = computedQuantiles.shuffleWriteMetrics.writeRecords(idx), + shuffleWriteTime = computedQuantiles.shuffleWriteMetrics.writeTime(idx)) + store.write(cached) + } + + Some(computedQuantiles) } + /** + * Whether to cache information about a specific metric quantile. We cache quantiles at every 0.05 + * step, which covers the default values used both in the API and in the stages page. + */ + private def shouldCacheQuantile(q: Double): Boolean = (math.round(q * 100) % 5) == 0 + + private def quantileToString(q: Double): String = math.round(q * 100).toString + def taskList(stageId: Int, stageAttemptId: Int, maxTasks: Int): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) store.view(classOf[TaskDataWrapper]).index("stage").first(stageKey).last(stageKey).reverse() - .max(maxTasks).asScala.map(_.info).toSeq.reverse + .max(maxTasks).asScala.map(_.toApi).toSeq.reverse } def taskList( @@ -219,18 +354,43 @@ private[spark] class AppStatusStore( offset: Int, length: Int, sortBy: v1.TaskSorting): Seq[v1.TaskData] = { + val (indexName, ascending) = sortBy match { + case v1.TaskSorting.ID => + (None, true) + case v1.TaskSorting.INCREASING_RUNTIME => + (Some(TaskIndexNames.EXEC_RUN_TIME), true) + case v1.TaskSorting.DECREASING_RUNTIME => + (Some(TaskIndexNames.EXEC_RUN_TIME), false) + } + taskList(stageId, stageAttemptId, offset, length, indexName, ascending) + } + + def taskList( + stageId: Int, + stageAttemptId: Int, + offset: Int, + length: Int, + sortBy: Option[String], + ascending: Boolean): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) val base = store.view(classOf[TaskDataWrapper]) val indexed = sortBy match { - case v1.TaskSorting.ID => + case Some(index) => + base.index(index).parent(stageKey) + + case _ => + // Sort by ID, which is the "stage" index. base.index("stage").first(stageKey).last(stageKey) - case v1.TaskSorting.INCREASING_RUNTIME => - base.index("runtime").first(stageKey ++ Array(-1L)).last(stageKey ++ Array(Long.MaxValue)) - case v1.TaskSorting.DECREASING_RUNTIME => - base.index("runtime").first(stageKey ++ Array(Long.MaxValue)).last(stageKey ++ Array(-1L)) - .reverse() } - indexed.skip(offset).max(length).asScala.map(_.info).toSeq + + val ordered = if (ascending) indexed else indexed.reverse() + ordered.skip(offset).max(length).asScala.map(_.toApi).toSeq + } + + def executorSummary(stageId: Int, attemptId: Int): Map[String, v1.ExecutorStageSummary] = { + val stageKey = Array(stageId, attemptId) + store.view(classOf[ExecutorStageSummaryWrapper]).index("stage").first(stageKey).last(stageKey) + .asScala.map { exec => (exec.executorId -> exec.info) }.toMap } def rddList(cachedOnly: Boolean = true): Seq[v1.RDDStorageInfo] = { @@ -256,12 +416,6 @@ private[spark] class AppStatusStore( .map { t => (t.taskId, t) } .toMap - val stageKey = Array(stage.stageId, stage.attemptId) - val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage").first(stageKey) - .last(stageKey).closeableIterator().asScala - .map { exec => (exec.executorId -> exec.info) } - .toMap - new v1.StageData( stage.status, stage.stageId, @@ -295,7 +449,7 @@ private[spark] class AppStatusStore( stage.rddIds, stage.accumulatorUpdates, Some(tasks), - Some(execs), + Some(executorSummary(stage.stageId, stage.attemptId)), stage.killedTasksSummary) } @@ -352,22 +506,3 @@ private[spark] object AppStatusStore { } } - -/** - * Helper for getting distributions from nested metric types. - */ -private abstract class MetricHelper[I, O]( - rawMetrics: Seq[v1.TaskMetrics], - quantiles: Array[Double]) { - - def getSubmetrics(raw: v1.TaskMetrics): I - - def build: O - - val data: Seq[I] = rawMetrics.map(getSubmetrics) - - /** applies the given function to all input metrics, and returns the quantiles */ - def submetricQuantiles(f: I => Double): IndexedSeq[Double] = { - Distribution(data.map { d => f(d) }).get.getQuantiles(quantiles) - } -} diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala new file mode 100644 index 0000000000000..341bd4e0cd016 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import org.apache.spark.status.api.v1.{TaskData, TaskMetrics} + +private[spark] object AppStatusUtils { + + def schedulerDelay(task: TaskData): Long = { + if (task.taskMetrics.isDefined && task.duration.isDefined) { + val m = task.taskMetrics.get + schedulerDelay(task.launchTime.getTime(), fetchStart(task), task.duration.get, + m.executorDeserializeTime, m.resultSerializationTime, m.executorRunTime) + } else { + 0L + } + } + + def gettingResultTime(task: TaskData): Long = { + gettingResultTime(task.launchTime.getTime(), fetchStart(task), task.duration.getOrElse(-1L)) + } + + def schedulerDelay( + launchTime: Long, + fetchStart: Long, + duration: Long, + deserializeTime: Long, + serializeTime: Long, + runTime: Long): Long = { + math.max(0, duration - runTime - deserializeTime - serializeTime - + gettingResultTime(launchTime, fetchStart, duration)) + } + + def gettingResultTime(launchTime: Long, fetchStart: Long, duration: Long): Long = { + if (fetchStart > 0) { + if (duration > 0) { + launchTime + duration - fetchStart + } else { + System.currentTimeMillis() - fetchStart + } + } else { + 0L + } + } + + private def fetchStart(task: TaskData): Long = { + if (task.resultFetchStart.isDefined) { + task.resultFetchStart.get.getTime() + } else { + -1 + } + } +} diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 305c2fafa6aac..4295e664e131c 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -22,6 +22,8 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.HashMap +import com.google.common.collect.Interners + import org.apache.spark.JobExecutionStatus import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, StageInfo, TaskInfo} @@ -119,7 +121,9 @@ private class LiveTask( import LiveEntityHelpers._ - private var recordedMetrics: v1.TaskMetrics = null + // The task metrics use a special value when no metrics have been reported. The special value is + // checked when calculating indexed values when writing to the store (see [[TaskDataWrapper]]). + private var metrics: v1.TaskMetrics = createMetrics(default = -1L) var errorMessage: Option[String] = None @@ -129,8 +133,8 @@ private class LiveTask( */ def updateMetrics(metrics: TaskMetrics): v1.TaskMetrics = { if (metrics != null) { - val old = recordedMetrics - recordedMetrics = new v1.TaskMetrics( + val old = this.metrics + val newMetrics = createMetrics( metrics.executorDeserializeTime, metrics.executorDeserializeCpuTime, metrics.executorRunTime, @@ -141,73 +145,35 @@ private class LiveTask( metrics.memoryBytesSpilled, metrics.diskBytesSpilled, metrics.peakExecutionMemory, - new v1.InputMetrics( - metrics.inputMetrics.bytesRead, - metrics.inputMetrics.recordsRead), - new v1.OutputMetrics( - metrics.outputMetrics.bytesWritten, - metrics.outputMetrics.recordsWritten), - new v1.ShuffleReadMetrics( - metrics.shuffleReadMetrics.remoteBlocksFetched, - metrics.shuffleReadMetrics.localBlocksFetched, - metrics.shuffleReadMetrics.fetchWaitTime, - metrics.shuffleReadMetrics.remoteBytesRead, - metrics.shuffleReadMetrics.remoteBytesReadToDisk, - metrics.shuffleReadMetrics.localBytesRead, - metrics.shuffleReadMetrics.recordsRead), - new v1.ShuffleWriteMetrics( - metrics.shuffleWriteMetrics.bytesWritten, - metrics.shuffleWriteMetrics.writeTime, - metrics.shuffleWriteMetrics.recordsWritten)) - if (old != null) calculateMetricsDelta(recordedMetrics, old) else recordedMetrics + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBlocksFetched, + metrics.shuffleReadMetrics.localBlocksFetched, + metrics.shuffleReadMetrics.fetchWaitTime, + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.writeTime, + metrics.shuffleWriteMetrics.recordsWritten) + + this.metrics = newMetrics + + // Only calculate the delta if the old metrics contain valid information, otherwise + // the new metrics are the delta. + if (old.executorDeserializeTime >= 0L) { + subtractMetrics(newMetrics, old) + } else { + newMetrics + } } else { null } } - /** - * Return a new TaskMetrics object containing the delta of the various fields of the given - * metrics objects. This is currently targeted at updating stage data, so it does not - * necessarily calculate deltas for all the fields. - */ - private def calculateMetricsDelta( - metrics: v1.TaskMetrics, - old: v1.TaskMetrics): v1.TaskMetrics = { - val shuffleWriteDelta = new v1.ShuffleWriteMetrics( - metrics.shuffleWriteMetrics.bytesWritten - old.shuffleWriteMetrics.bytesWritten, - 0L, - metrics.shuffleWriteMetrics.recordsWritten - old.shuffleWriteMetrics.recordsWritten) - - val shuffleReadDelta = new v1.ShuffleReadMetrics( - 0L, 0L, 0L, - metrics.shuffleReadMetrics.remoteBytesRead - old.shuffleReadMetrics.remoteBytesRead, - metrics.shuffleReadMetrics.remoteBytesReadToDisk - - old.shuffleReadMetrics.remoteBytesReadToDisk, - metrics.shuffleReadMetrics.localBytesRead - old.shuffleReadMetrics.localBytesRead, - metrics.shuffleReadMetrics.recordsRead - old.shuffleReadMetrics.recordsRead) - - val inputDelta = new v1.InputMetrics( - metrics.inputMetrics.bytesRead - old.inputMetrics.bytesRead, - metrics.inputMetrics.recordsRead - old.inputMetrics.recordsRead) - - val outputDelta = new v1.OutputMetrics( - metrics.outputMetrics.bytesWritten - old.outputMetrics.bytesWritten, - metrics.outputMetrics.recordsWritten - old.outputMetrics.recordsWritten) - - new v1.TaskMetrics( - 0L, 0L, - metrics.executorRunTime - old.executorRunTime, - metrics.executorCpuTime - old.executorCpuTime, - 0L, 0L, 0L, - metrics.memoryBytesSpilled - old.memoryBytesSpilled, - metrics.diskBytesSpilled - old.diskBytesSpilled, - 0L, - inputDelta, - outputDelta, - shuffleReadDelta, - shuffleWriteDelta) - } - override protected def doUpdate(): Any = { val duration = if (info.finished) { info.duration @@ -215,22 +181,48 @@ private class LiveTask( info.timeRunning(lastUpdateTime.getOrElse(System.currentTimeMillis())) } - val task = new v1.TaskData( + new TaskDataWrapper( info.taskId, info.index, info.attemptNumber, - new Date(info.launchTime), - if (info.gettingResult) Some(new Date(info.gettingResultTime)) else None, - Some(duration), - info.executorId, - info.host, - info.status, - info.taskLocality.toString(), + info.launchTime, + if (info.gettingResult) info.gettingResultTime else -1L, + duration, + weakIntern(info.executorId), + weakIntern(info.host), + weakIntern(info.status), + weakIntern(info.taskLocality.toString()), info.speculative, newAccumulatorInfos(info.accumulables), errorMessage, - Option(recordedMetrics)) - new TaskDataWrapper(task, stageId, stageAttemptId) + + metrics.executorDeserializeTime, + metrics.executorDeserializeCpuTime, + metrics.executorRunTime, + metrics.executorCpuTime, + metrics.resultSize, + metrics.jvmGcTime, + metrics.resultSerializationTime, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled, + metrics.peakExecutionMemory, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBlocksFetched, + metrics.shuffleReadMetrics.localBlocksFetched, + metrics.shuffleReadMetrics.fetchWaitTime, + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.writeTime, + metrics.shuffleWriteMetrics.recordsWritten, + + stageId, + stageAttemptId) } } @@ -313,50 +305,19 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE } -/** Metrics tracked per stage (both total and per executor). */ -private class MetricsTracker { - var executorRunTime = 0L - var executorCpuTime = 0L - var inputBytes = 0L - var inputRecords = 0L - var outputBytes = 0L - var outputRecords = 0L - var shuffleReadBytes = 0L - var shuffleReadRecords = 0L - var shuffleWriteBytes = 0L - var shuffleWriteRecords = 0L - var memoryBytesSpilled = 0L - var diskBytesSpilled = 0L - - def update(delta: v1.TaskMetrics): Unit = { - executorRunTime += delta.executorRunTime - executorCpuTime += delta.executorCpuTime - inputBytes += delta.inputMetrics.bytesRead - inputRecords += delta.inputMetrics.recordsRead - outputBytes += delta.outputMetrics.bytesWritten - outputRecords += delta.outputMetrics.recordsWritten - shuffleReadBytes += delta.shuffleReadMetrics.localBytesRead + - delta.shuffleReadMetrics.remoteBytesRead - shuffleReadRecords += delta.shuffleReadMetrics.recordsRead - shuffleWriteBytes += delta.shuffleWriteMetrics.bytesWritten - shuffleWriteRecords += delta.shuffleWriteMetrics.recordsWritten - memoryBytesSpilled += delta.memoryBytesSpilled - diskBytesSpilled += delta.diskBytesSpilled - } - -} - private class LiveExecutorStageSummary( stageId: Int, attemptId: Int, executorId: String) extends LiveEntity { + import LiveEntityHelpers._ + var taskTime = 0L var succeededTasks = 0 var failedTasks = 0 var killedTasks = 0 - val metrics = new MetricsTracker() + var metrics = createMetrics(default = 0L) override protected def doUpdate(): Any = { val info = new v1.ExecutorStageSummary( @@ -364,14 +325,14 @@ private class LiveExecutorStageSummary( failedTasks, succeededTasks, killedTasks, - metrics.inputBytes, - metrics.inputRecords, - metrics.outputBytes, - metrics.outputRecords, - metrics.shuffleReadBytes, - metrics.shuffleReadRecords, - metrics.shuffleWriteBytes, - metrics.shuffleWriteRecords, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBytesRead + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.recordsWritten, metrics.memoryBytesSpilled, metrics.diskBytesSpilled) new ExecutorStageSummaryWrapper(stageId, attemptId, executorId, info) @@ -402,7 +363,9 @@ private class LiveStage extends LiveEntity { var firstLaunchTime = Long.MaxValue - val metrics = new MetricsTracker() + var localitySummary: Map[String, Long] = Map() + + var metrics = createMetrics(default = 0L) val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() @@ -435,14 +398,14 @@ private class LiveStage extends LiveEntity { info.completionTime.map(new Date(_)), info.failureReason, - metrics.inputBytes, - metrics.inputRecords, - metrics.outputBytes, - metrics.outputRecords, - metrics.shuffleReadBytes, - metrics.shuffleReadRecords, - metrics.shuffleWriteBytes, - metrics.shuffleWriteRecords, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.localBytesRead + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.recordsWritten, metrics.memoryBytesSpilled, metrics.diskBytesSpilled, @@ -459,13 +422,15 @@ private class LiveStage extends LiveEntity { } override protected def doUpdate(): Any = { - new StageDataWrapper(toApi(), jobIds) + new StageDataWrapper(toApi(), jobIds, localitySummary) } } private class LiveRDDPartition(val blockName: String) { + import LiveEntityHelpers._ + // Pointers used by RDDPartitionSeq. @volatile var prev: LiveRDDPartition = null @volatile var next: LiveRDDPartition = null @@ -485,7 +450,7 @@ private class LiveRDDPartition(val blockName: String) { diskUsed: Long): Unit = { value = new v1.RDDPartitionInfo( blockName, - storageLevel, + weakIntern(storageLevel), memoryUsed, diskUsed, executors) @@ -495,6 +460,8 @@ private class LiveRDDPartition(val blockName: String) { private class LiveRDDDistribution(exec: LiveExecutor) { + import LiveEntityHelpers._ + val executorId = exec.executorId var memoryUsed = 0L var diskUsed = 0L @@ -508,7 +475,7 @@ private class LiveRDDDistribution(exec: LiveExecutor) { def toApi(): v1.RDDDataDistribution = { if (lastUpdate == null) { lastUpdate = new v1.RDDDataDistribution( - exec.hostPort, + weakIntern(exec.hostPort), memoryUsed, exec.maxMemory - exec.memoryUsed, diskUsed, @@ -524,7 +491,9 @@ private class LiveRDDDistribution(exec: LiveExecutor) { private class LiveRDD(val info: RDDInfo) extends LiveEntity { - var storageLevel: String = info.storageLevel.description + import LiveEntityHelpers._ + + var storageLevel: String = weakIntern(info.storageLevel.description) var memoryUsed = 0L var diskUsed = 0L @@ -533,6 +502,10 @@ private class LiveRDD(val info: RDDInfo) extends LiveEntity { private val distributions = new HashMap[String, LiveRDDDistribution]() + def setStorageLevel(level: String): Unit = { + this.storageLevel = weakIntern(level) + } + def partition(blockName: String): LiveRDDPartition = { partitions.getOrElseUpdate(blockName, { val part = new LiveRDDPartition(blockName) @@ -593,6 +566,9 @@ private class SchedulerPool(name: String) extends LiveEntity { private object LiveEntityHelpers { + private val stringInterner = Interners.newWeakInterner[String]() + + def newAccumulatorInfos(accums: Iterable[AccumulableInfo]): Seq[v1.AccumulableInfo] = { accums .filter { acc => @@ -604,13 +580,119 @@ private object LiveEntityHelpers { .map { acc => new v1.AccumulableInfo( acc.id, - acc.name.orNull, + acc.name.map(weakIntern).orNull, acc.update.map(_.toString()), acc.value.map(_.toString()).orNull) } .toSeq } + /** String interning to reduce the memory usage. */ + def weakIntern(s: String): String = { + stringInterner.intern(s) + } + + // scalastyle:off argcount + def createMetrics( + executorDeserializeTime: Long, + executorDeserializeCpuTime: Long, + executorRunTime: Long, + executorCpuTime: Long, + resultSize: Long, + jvmGcTime: Long, + resultSerializationTime: Long, + memoryBytesSpilled: Long, + diskBytesSpilled: Long, + peakExecutionMemory: Long, + inputBytesRead: Long, + inputRecordsRead: Long, + outputBytesWritten: Long, + outputRecordsWritten: Long, + shuffleRemoteBlocksFetched: Long, + shuffleLocalBlocksFetched: Long, + shuffleFetchWaitTime: Long, + shuffleRemoteBytesRead: Long, + shuffleRemoteBytesReadToDisk: Long, + shuffleLocalBytesRead: Long, + shuffleRecordsRead: Long, + shuffleBytesWritten: Long, + shuffleWriteTime: Long, + shuffleRecordsWritten: Long): v1.TaskMetrics = { + new v1.TaskMetrics( + executorDeserializeTime, + executorDeserializeCpuTime, + executorRunTime, + executorCpuTime, + resultSize, + jvmGcTime, + resultSerializationTime, + memoryBytesSpilled, + diskBytesSpilled, + peakExecutionMemory, + new v1.InputMetrics( + inputBytesRead, + inputRecordsRead), + new v1.OutputMetrics( + outputBytesWritten, + outputRecordsWritten), + new v1.ShuffleReadMetrics( + shuffleRemoteBlocksFetched, + shuffleLocalBlocksFetched, + shuffleFetchWaitTime, + shuffleRemoteBytesRead, + shuffleRemoteBytesReadToDisk, + shuffleLocalBytesRead, + shuffleRecordsRead), + new v1.ShuffleWriteMetrics( + shuffleBytesWritten, + shuffleWriteTime, + shuffleRecordsWritten)) + } + // scalastyle:on argcount + + def createMetrics(default: Long): v1.TaskMetrics = { + createMetrics(default, default, default, default, default, default, default, default, + default, default, default, default, default, default, default, default, + default, default, default, default, default, default, default, default) + } + + /** Add m2 values to m1. */ + def addMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics): v1.TaskMetrics = addMetrics(m1, m2, 1) + + /** Subtract m2 values from m1. */ + def subtractMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics): v1.TaskMetrics = { + addMetrics(m1, m2, -1) + } + + private def addMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics, mult: Int): v1.TaskMetrics = { + createMetrics( + m1.executorDeserializeTime + m2.executorDeserializeTime * mult, + m1.executorDeserializeCpuTime + m2.executorDeserializeCpuTime * mult, + m1.executorRunTime + m2.executorRunTime * mult, + m1.executorCpuTime + m2.executorCpuTime * mult, + m1.resultSize + m2.resultSize * mult, + m1.jvmGcTime + m2.jvmGcTime * mult, + m1.resultSerializationTime + m2.resultSerializationTime * mult, + m1.memoryBytesSpilled + m2.memoryBytesSpilled * mult, + m1.diskBytesSpilled + m2.diskBytesSpilled * mult, + m1.peakExecutionMemory + m2.peakExecutionMemory * mult, + m1.inputMetrics.bytesRead + m2.inputMetrics.bytesRead * mult, + m1.inputMetrics.recordsRead + m2.inputMetrics.recordsRead * mult, + m1.outputMetrics.bytesWritten + m2.outputMetrics.bytesWritten * mult, + m1.outputMetrics.recordsWritten + m2.outputMetrics.recordsWritten * mult, + m1.shuffleReadMetrics.remoteBlocksFetched + m2.shuffleReadMetrics.remoteBlocksFetched * mult, + m1.shuffleReadMetrics.localBlocksFetched + m2.shuffleReadMetrics.localBlocksFetched * mult, + m1.shuffleReadMetrics.fetchWaitTime + m2.shuffleReadMetrics.fetchWaitTime * mult, + m1.shuffleReadMetrics.remoteBytesRead + m2.shuffleReadMetrics.remoteBytesRead * mult, + m1.shuffleReadMetrics.remoteBytesReadToDisk + + m2.shuffleReadMetrics.remoteBytesReadToDisk * mult, + m1.shuffleReadMetrics.localBytesRead + m2.shuffleReadMetrics.localBytesRead * mult, + m1.shuffleReadMetrics.recordsRead + m2.shuffleReadMetrics.recordsRead * mult, + m1.shuffleWriteMetrics.bytesWritten + m2.shuffleWriteMetrics.bytesWritten * mult, + m1.shuffleWriteMetrics.writeTime + m2.shuffleWriteMetrics.writeTime * mult, + m1.shuffleWriteMetrics.recordsWritten + m2.shuffleWriteMetrics.recordsWritten * mult) + } + } /** diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index 3b879545b3d2e..96249e4bfd5fa 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -87,7 +87,8 @@ private[v1] class StagesResource extends BaseAppResource { } } - ui.store.taskSummary(stageId, stageAttemptId, quantiles) + ui.store.taskSummary(stageId, stageAttemptId, quantiles).getOrElse( + throw new NotFoundException(s"No tasks reported metrics for $stageId / $stageAttemptId yet.")) } @GET diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 45eaf935fb083..7d8e4de3c8efb 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -261,6 +261,9 @@ class TaskMetricDistributions private[spark]( val resultSize: IndexedSeq[Double], val jvmGcTime: IndexedSeq[Double], val resultSerializationTime: IndexedSeq[Double], + val gettingResultTime: IndexedSeq[Double], + val schedulerDelay: IndexedSeq[Double], + val peakExecutionMemory: IndexedSeq[Double], val memoryBytesSpilled: IndexedSeq[Double], val diskBytesSpilled: IndexedSeq[Double], diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 1cfd30df49091..c9cb996a55fcc 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -17,9 +17,11 @@ package org.apache.spark.status -import java.lang.{Integer => JInteger, Long => JLong} +import java.lang.{Long => JLong} +import java.util.Date import com.fasterxml.jackson.annotation.JsonIgnore +import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1._ @@ -49,10 +51,10 @@ private[spark] class ApplicationEnvironmentInfoWrapper(val info: ApplicationEnvi private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { @JsonIgnore @KVIndex - private[this] val id: String = info.id + private def id: String = info.id @JsonIgnore @KVIndex("active") - private[this] val active: Boolean = info.isActive + private def active: Boolean = info.isActive @JsonIgnore @KVIndex("host") val host: String = info.hostPort.split(":")(0) @@ -69,51 +71,271 @@ private[spark] class JobDataWrapper( val skippedStages: Set[Int]) { @JsonIgnore @KVIndex - private[this] val id: Int = info.jobId + private def id: Int = info.jobId } private[spark] class StageDataWrapper( val info: StageData, - val jobIds: Set[Int]) { + val jobIds: Set[Int], + @JsonDeserialize(contentAs = classOf[JLong]) + val locality: Map[String, Long]) { @JsonIgnore @KVIndex - def id: Array[Int] = Array(info.stageId, info.attemptId) + private[this] val id: Array[Int] = Array(info.stageId, info.attemptId) @JsonIgnore @KVIndex("stageId") - def stageId: Int = info.stageId + private def stageId: Int = info.stageId + @JsonIgnore @KVIndex("active") + private def active: Boolean = info.status == StageStatus.ACTIVE + +} + +/** + * Tasks have a lot of indices that are used in a few different places. This object keeps logical + * names for these indices, mapped to short strings to save space when using a disk store. + */ +private[spark] object TaskIndexNames { + final val ACCUMULATORS = "acc" + final val ATTEMPT = "att" + final val DESER_CPU_TIME = "dct" + final val DESER_TIME = "des" + final val DISK_SPILL = "dbs" + final val DURATION = "dur" + final val ERROR = "err" + final val EXECUTOR = "exe" + final val EXEC_CPU_TIME = "ect" + final val EXEC_RUN_TIME = "ert" + final val GC_TIME = "gc" + final val GETTING_RESULT_TIME = "grt" + final val INPUT_RECORDS = "ir" + final val INPUT_SIZE = "is" + final val LAUNCH_TIME = "lt" + final val LOCALITY = "loc" + final val MEM_SPILL = "mbs" + final val OUTPUT_RECORDS = "or" + final val OUTPUT_SIZE = "os" + final val PEAK_MEM = "pem" + final val RESULT_SIZE = "rs" + final val SCHEDULER_DELAY = "dly" + final val SER_TIME = "rst" + final val SHUFFLE_LOCAL_BLOCKS = "slbl" + final val SHUFFLE_READ_RECORDS = "srr" + final val SHUFFLE_READ_TIME = "srt" + final val SHUFFLE_REMOTE_BLOCKS = "srbl" + final val SHUFFLE_REMOTE_READS = "srby" + final val SHUFFLE_REMOTE_READS_TO_DISK = "srbd" + final val SHUFFLE_TOTAL_READS = "stby" + final val SHUFFLE_TOTAL_BLOCKS = "stbl" + final val SHUFFLE_WRITE_RECORDS = "swr" + final val SHUFFLE_WRITE_SIZE = "sws" + final val SHUFFLE_WRITE_TIME = "swt" + final val STAGE = "stage" + final val STATUS = "sta" + final val TASK_INDEX = "idx" } /** - * The task information is always indexed with the stage ID, since that is how the UI and API - * consume it. That means every indexed value has the stage ID and attempt ID included, aside - * from the actual data being indexed. + * Unlike other data types, the task data wrapper does not keep a reference to the API's TaskData. + * That is to save memory, since for large applications there can be a large number of these + * elements (by default up to 100,000 per stage), and every bit of wasted memory adds up. + * + * It also contains many secondary indices, which are used to sort data efficiently in the UI at the + * expense of storage space (and slower write times). */ private[spark] class TaskDataWrapper( - val info: TaskData, + // Storing this as an object actually saves memory; it's also used as the key in the in-memory + // store, so in that case you'd save the extra copy of the value here. + @KVIndexParam + val taskId: JLong, + @KVIndexParam(value = TaskIndexNames.TASK_INDEX, parent = TaskIndexNames.STAGE) + val index: Int, + @KVIndexParam(value = TaskIndexNames.ATTEMPT, parent = TaskIndexNames.STAGE) + val attempt: Int, + @KVIndexParam(value = TaskIndexNames.LAUNCH_TIME, parent = TaskIndexNames.STAGE) + val launchTime: Long, + val resultFetchStart: Long, + @KVIndexParam(value = TaskIndexNames.DURATION, parent = TaskIndexNames.STAGE) + val duration: Long, + @KVIndexParam(value = TaskIndexNames.EXECUTOR, parent = TaskIndexNames.STAGE) + val executorId: String, + val host: String, + @KVIndexParam(value = TaskIndexNames.STATUS, parent = TaskIndexNames.STAGE) + val status: String, + @KVIndexParam(value = TaskIndexNames.LOCALITY, parent = TaskIndexNames.STAGE) + val taskLocality: String, + val speculative: Boolean, + val accumulatorUpdates: Seq[AccumulableInfo], + val errorMessage: Option[String], + + // The following is an exploded view of a TaskMetrics API object. This saves 5 objects + // (= 80 bytes of Java object overhead) per instance of this wrapper. If the first value + // (executorDeserializeTime) is -1L, it means the metrics for this task have not been + // recorded. + @KVIndexParam(value = TaskIndexNames.DESER_TIME, parent = TaskIndexNames.STAGE) + val executorDeserializeTime: Long, + @KVIndexParam(value = TaskIndexNames.DESER_CPU_TIME, parent = TaskIndexNames.STAGE) + val executorDeserializeCpuTime: Long, + @KVIndexParam(value = TaskIndexNames.EXEC_RUN_TIME, parent = TaskIndexNames.STAGE) + val executorRunTime: Long, + @KVIndexParam(value = TaskIndexNames.EXEC_CPU_TIME, parent = TaskIndexNames.STAGE) + val executorCpuTime: Long, + @KVIndexParam(value = TaskIndexNames.RESULT_SIZE, parent = TaskIndexNames.STAGE) + val resultSize: Long, + @KVIndexParam(value = TaskIndexNames.GC_TIME, parent = TaskIndexNames.STAGE) + val jvmGcTime: Long, + @KVIndexParam(value = TaskIndexNames.SER_TIME, parent = TaskIndexNames.STAGE) + val resultSerializationTime: Long, + @KVIndexParam(value = TaskIndexNames.MEM_SPILL, parent = TaskIndexNames.STAGE) + val memoryBytesSpilled: Long, + @KVIndexParam(value = TaskIndexNames.DISK_SPILL, parent = TaskIndexNames.STAGE) + val diskBytesSpilled: Long, + @KVIndexParam(value = TaskIndexNames.PEAK_MEM, parent = TaskIndexNames.STAGE) + val peakExecutionMemory: Long, + @KVIndexParam(value = TaskIndexNames.INPUT_SIZE, parent = TaskIndexNames.STAGE) + val inputBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.INPUT_RECORDS, parent = TaskIndexNames.STAGE) + val inputRecordsRead: Long, + @KVIndexParam(value = TaskIndexNames.OUTPUT_SIZE, parent = TaskIndexNames.STAGE) + val outputBytesWritten: Long, + @KVIndexParam(value = TaskIndexNames.OUTPUT_RECORDS, parent = TaskIndexNames.STAGE) + val outputRecordsWritten: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_BLOCKS, parent = TaskIndexNames.STAGE) + val shuffleRemoteBlocksFetched: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_LOCAL_BLOCKS, parent = TaskIndexNames.STAGE) + val shuffleLocalBlocksFetched: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_READ_TIME, parent = TaskIndexNames.STAGE) + val shuffleFetchWaitTime: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_READS, parent = TaskIndexNames.STAGE) + val shuffleRemoteBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_READS_TO_DISK, + parent = TaskIndexNames.STAGE) + val shuffleRemoteBytesReadToDisk: Long, + val shuffleLocalBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_READ_RECORDS, parent = TaskIndexNames.STAGE) + val shuffleRecordsRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_SIZE, parent = TaskIndexNames.STAGE) + val shuffleBytesWritten: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_TIME, parent = TaskIndexNames.STAGE) + val shuffleWriteTime: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_RECORDS, parent = TaskIndexNames.STAGE) + val shuffleRecordsWritten: Long, + val stageId: Int, val stageAttemptId: Int) { - @JsonIgnore @KVIndex - def id: Long = info.taskId + def hasMetrics: Boolean = executorDeserializeTime >= 0 + + def toApi: TaskData = { + val metrics = if (hasMetrics) { + Some(new TaskMetrics( + executorDeserializeTime, + executorDeserializeCpuTime, + executorRunTime, + executorCpuTime, + resultSize, + jvmGcTime, + resultSerializationTime, + memoryBytesSpilled, + diskBytesSpilled, + peakExecutionMemory, + new InputMetrics( + inputBytesRead, + inputRecordsRead), + new OutputMetrics( + outputBytesWritten, + outputRecordsWritten), + new ShuffleReadMetrics( + shuffleRemoteBlocksFetched, + shuffleLocalBlocksFetched, + shuffleFetchWaitTime, + shuffleRemoteBytesRead, + shuffleRemoteBytesReadToDisk, + shuffleLocalBytesRead, + shuffleRecordsRead), + new ShuffleWriteMetrics( + shuffleBytesWritten, + shuffleWriteTime, + shuffleRecordsWritten))) + } else { + None + } - @JsonIgnore @KVIndex("stage") - def stage: Array[Int] = Array(stageId, stageAttemptId) + new TaskData( + taskId, + index, + attempt, + new Date(launchTime), + if (resultFetchStart > 0L) Some(new Date(resultFetchStart)) else None, + if (duration > 0L) Some(duration) else None, + executorId, + host, + status, + taskLocality, + speculative, + accumulatorUpdates, + errorMessage, + metrics) + } + + @JsonIgnore @KVIndex(TaskIndexNames.STAGE) + private def stage: Array[Int] = Array(stageId, stageAttemptId) - @JsonIgnore @KVIndex("runtime") - def runtime: Array[AnyRef] = { - val _runtime = info.taskMetrics.map(_.executorRunTime).getOrElse(-1L) - Array(stageId: JInteger, stageAttemptId: JInteger, _runtime: JLong) + @JsonIgnore @KVIndex(value = TaskIndexNames.SCHEDULER_DELAY, parent = TaskIndexNames.STAGE) + def schedulerDelay: Long = { + if (hasMetrics) { + AppStatusUtils.schedulerDelay(launchTime, resultFetchStart, duration, executorDeserializeTime, + resultSerializationTime, executorRunTime) + } else { + -1L + } } - @JsonIgnore @KVIndex("startTime") - def startTime: Array[AnyRef] = { - Array(stageId: JInteger, stageAttemptId: JInteger, info.launchTime.getTime(): JLong) + @JsonIgnore @KVIndex(value = TaskIndexNames.GETTING_RESULT_TIME, parent = TaskIndexNames.STAGE) + def gettingResultTime: Long = { + if (hasMetrics) { + AppStatusUtils.gettingResultTime(launchTime, resultFetchStart, duration) + } else { + -1L + } } - @JsonIgnore @KVIndex("active") - def active: Boolean = info.duration.isEmpty + /** + * Sorting by accumulators is a little weird, and the previous behavior would generate + * insanely long keys in the index. So this implementation just considers the first + * accumulator and its String representation. + */ + @JsonIgnore @KVIndex(value = TaskIndexNames.ACCUMULATORS, parent = TaskIndexNames.STAGE) + private def accumulators: String = { + if (accumulatorUpdates.nonEmpty) { + val acc = accumulatorUpdates.head + s"${acc.name}:${acc.value}" + } else { + "" + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.SHUFFLE_TOTAL_READS, parent = TaskIndexNames.STAGE) + private def shuffleTotalReads: Long = { + if (hasMetrics) { + shuffleLocalBytesRead + shuffleRemoteBytesRead + } else { + -1L + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.SHUFFLE_TOTAL_BLOCKS, parent = TaskIndexNames.STAGE) + private def shuffleTotalBlocks: Long = { + if (hasMetrics) { + shuffleLocalBlocksFetched + shuffleRemoteBlocksFetched + } else { + -1L + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.ERROR, parent = TaskIndexNames.STAGE) + private def error: String = if (errorMessage.isDefined) errorMessage.get else "" } @@ -134,10 +356,13 @@ private[spark] class ExecutorStageSummaryWrapper( val info: ExecutorStageSummary) { @JsonIgnore @KVIndex - val id: Array[Any] = Array(stageId, stageAttemptId, executorId) + private val _id: Array[Any] = Array(stageId, stageAttemptId, executorId) @JsonIgnore @KVIndex("stage") - private[this] val stage: Array[Int] = Array(stageId, stageAttemptId) + private def stage: Array[Int] = Array(stageId, stageAttemptId) + + @JsonIgnore + def id: Array[Any] = _id } @@ -203,3 +428,53 @@ private[spark] class AppSummary( def id: String = classOf[AppSummary].getName() } + +/** + * A cached view of a specific quantile for one stage attempt's metrics. + */ +private[spark] class CachedQuantile( + val stageId: Int, + val stageAttemptId: Int, + val quantile: String, + val taskCount: Long, + + // The following fields are an exploded view of a single entry for TaskMetricDistributions. + val executorDeserializeTime: Double, + val executorDeserializeCpuTime: Double, + val executorRunTime: Double, + val executorCpuTime: Double, + val resultSize: Double, + val jvmGcTime: Double, + val resultSerializationTime: Double, + val gettingResultTime: Double, + val schedulerDelay: Double, + val peakExecutionMemory: Double, + val memoryBytesSpilled: Double, + val diskBytesSpilled: Double, + + val bytesRead: Double, + val recordsRead: Double, + + val bytesWritten: Double, + val recordsWritten: Double, + + val shuffleReadBytes: Double, + val shuffleRecordsRead: Double, + val shuffleRemoteBlocksFetched: Double, + val shuffleLocalBlocksFetched: Double, + val shuffleFetchWaitTime: Double, + val shuffleRemoteBytesRead: Double, + val shuffleRemoteBytesReadToDisk: Double, + val shuffleTotalBlocksFetched: Double, + + val shuffleWriteBytes: Double, + val shuffleWriteRecords: Double, + val shuffleWriteTime: Double) { + + @KVIndex @JsonIgnore + def id: Array[Any] = Array(stageId, stageAttemptId, quantile) + + @KVIndex("stage") @JsonIgnore + def stage: Array[Int] = Array(stageId, stageAttemptId) + +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 41d42b52430a5..95c12b1e73653 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -87,7 +87,9 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { } private def createExecutorTable(stage: StageData) : Seq[Node] = { - stage.executorSummary.getOrElse(Map.empty).toSeq.sortBy(_._1).map { case (k, v) => + val executorSummary = store.executorSummary(stage.stageId, stage.attemptId) + + executorSummary.toSeq.sortBy(_._1).map { case (k, v) => val executor = store.asOption(store.executorSummary(k)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 740f12e7d13d4..bf59152c8c0cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -201,7 +201,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP val stages = jobData.stageIds.map { stageId => // This could be empty if the listener hasn't received information about the // stage or if the stage information has been garbage collected - store.stageData(stageId).lastOption.getOrElse { + store.asOption(store.lastStageAttempt(stageId)).getOrElse { new v1.StageData( v1.StageStatus.PENDING, stageId, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 11a6a34344976..7c6e06cf183ba 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -19,6 +19,7 @@ package org.apache.spark.ui.jobs import java.net.URLEncoder import java.util.Date +import java.util.concurrent.TimeUnit import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{HashMap, HashSet} @@ -29,15 +30,14 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.SparkConf import org.apache.spark.internal.config._ import org.apache.spark.scheduler.TaskLocality -import org.apache.spark.status.AppStatusStore +import org.apache.spark.status._ import org.apache.spark.status.api.v1._ import org.apache.spark.ui._ -import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.util.Utils /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends WebUIPage("stage") { import ApiHelper._ - import StagePage._ private val TIMELINE_LEGEND = {
@@ -67,17 +67,17 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We // if we find that it's okay. private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000) - private def getLocalitySummaryString(stageData: StageData, taskList: Seq[TaskData]): String = { - val localities = taskList.map(_.taskLocality) - val localityCounts = localities.groupBy(identity).mapValues(_.size) + private def getLocalitySummaryString(localitySummary: Map[String, Long]): String = { val names = Map( TaskLocality.PROCESS_LOCAL.toString() -> "Process local", TaskLocality.NODE_LOCAL.toString() -> "Node local", TaskLocality.RACK_LOCAL.toString() -> "Rack local", TaskLocality.ANY.toString() -> "Any") - val localityNamesAndCounts = localityCounts.toSeq.map { case (locality, count) => - s"${names(locality)}: $count" - } + val localityNamesAndCounts = names.flatMap { case (key, name) => + localitySummary.get(key).map { count => + s"$name: $count" + } + }.toSeq localityNamesAndCounts.sorted.mkString("; ") } @@ -108,7 +108,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val stageHeader = s"Details for Stage $stageId (Attempt $stageAttemptId)" val stageData = parent.store - .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = true)) + .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = false)) .getOrElse { val content =
@@ -117,8 +117,11 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We return UIUtils.headerSparkPage(stageHeader, content, parent) } - val tasks = stageData.tasks.getOrElse(Map.empty).values.toSeq - if (tasks.isEmpty) { + val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) + + val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + + stageData.numFailedTasks + stageData.numKilledTasks + if (totalTasks == 0) { val content =

Summary Metrics

No tasks have started yet @@ -127,18 +130,14 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We return UIUtils.headerSparkPage(stageHeader, content, parent) } + val storedTasks = store.taskCount(stageData.stageId, stageData.attemptId) val numCompleted = stageData.numCompleteTasks - val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + - stageData.numFailedTasks + stageData.numKilledTasks - val totalTasksNumStr = if (totalTasks == tasks.size) { + val totalTasksNumStr = if (totalTasks == storedTasks) { s"$totalTasks" } else { - s"$totalTasks, showing ${tasks.size}" + s"$totalTasks, showing ${storedTasks}" } - val externalAccumulables = stageData.accumulatorUpdates - val hasAccumulators = externalAccumulables.size > 0 - val summary =
    @@ -148,7 +147,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
  • Locality Level Summary: - {getLocalitySummaryString(stageData, tasks)} + {getLocalitySummaryString(localitySummary)}
  • {if (hasInput(stageData)) {
  • @@ -266,7 +265,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val accumulableTable = UIUtils.listingTable( accumulableHeaders, accumulableRow, - externalAccumulables.toSeq) + stageData.accumulatorUpdates.toSeq) val page: Int = { // If the user has changed to a larger page size, then go to page 1 in order to avoid @@ -280,16 +279,9 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val currentTime = System.currentTimeMillis() val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( - parent.conf, + stageData, UIUtils.prependBaseUri(parent.basePath) + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", - tasks, - hasAccumulators, - hasInput(stageData), - hasOutput(stageData), - hasShuffleRead(stageData), - hasShuffleWrite(stageData), - hasBytesSpilled(stageData), currentTime, pageSize = taskPageSize, sortColumn = taskSortColumn, @@ -320,217 +312,155 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We | } |}); """.stripMargin - } + } } - val taskIdsInPage = if (taskTable == null) Set.empty[Long] - else taskTable.dataSource.slicedTaskIds + val metricsSummary = store.taskSummary(stageData.stageId, stageData.attemptId, + Array(0, 0.25, 0.5, 0.75, 1.0)) - // Excludes tasks which failed and have incomplete metrics - val validTasks = tasks.filter(t => t.status == "SUCCESS" && t.taskMetrics.isDefined) - - val summaryTable: Option[Seq[Node]] = - if (validTasks.size == 0) { - None - } else { - def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = { - Distribution(data).get.getQuantiles() - } - def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { - getDistributionQuantiles(times).map { millis => - {UIUtils.formatDuration(millis.toLong)} - } - } - def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = { - getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) + val summaryTable = metricsSummary.map { metrics => + def timeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { + data.map { millis => + {UIUtils.formatDuration(millis.toLong)} } + } - val deserializationTimes = validTasks.map { task => - task.taskMetrics.get.executorDeserializeTime.toDouble - } - val deserializationQuantiles = - - - Task Deserialization Time - - +: getFormattedTimeQuantiles(deserializationTimes) - - val serviceTimes = validTasks.map(_.taskMetrics.get.executorRunTime.toDouble) - val serviceQuantiles = Duration +: getFormattedTimeQuantiles(serviceTimes) - - val gcTimes = validTasks.map(_.taskMetrics.get.jvmGcTime.toDouble) - val gcQuantiles = - - GC Time - - +: getFormattedTimeQuantiles(gcTimes) - - val serializationTimes = validTasks.map(_.taskMetrics.get.resultSerializationTime.toDouble) - val serializationQuantiles = - - - Result Serialization Time - - +: getFormattedTimeQuantiles(serializationTimes) - - val gettingResultTimes = validTasks.map(getGettingResultTime(_, currentTime).toDouble) - val gettingResultQuantiles = - - - Getting Result Time - - +: - getFormattedTimeQuantiles(gettingResultTimes) - - val peakExecutionMemory = validTasks.map(_.taskMetrics.get.peakExecutionMemory.toDouble) - val peakExecutionMemoryQuantiles = { - - - Peak Execution Memory - - +: getFormattedSizeQuantiles(peakExecutionMemory) + def sizeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { + data.map { size => + {Utils.bytesToString(size.toLong)} } + } - // The scheduler delay includes the network delay to send the task to the worker - // machine and to send back the result (but not the time to fetch the task result, - // if it needed to be fetched from the block manager on the worker). - val schedulerDelays = validTasks.map { task => - getSchedulerDelay(task, task.taskMetrics.get, currentTime).toDouble - } - val schedulerDelayTitle = Scheduler Delay - val schedulerDelayQuantiles = schedulerDelayTitle +: - getFormattedTimeQuantiles(schedulerDelays) - def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) - : Seq[Elem] = { - val recordDist = getDistributionQuantiles(records).iterator - getDistributionQuantiles(data).map(d => - {s"${Utils.bytesToString(d.toLong)} / ${recordDist.next().toLong}"} - ) + def sizeQuantilesWithRecords( + data: IndexedSeq[Double], + records: IndexedSeq[Double]) : Seq[Node] = { + data.zip(records).map { case (d, r) => + {s"${Utils.bytesToString(d.toLong)} / ${r.toLong}"} } + } - val inputSizes = validTasks.map(_.taskMetrics.get.inputMetrics.bytesRead.toDouble) - val inputRecords = validTasks.map(_.taskMetrics.get.inputMetrics.recordsRead.toDouble) - val inputQuantiles = Input Size / Records +: - getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords) + def titleCell(title: String, tooltip: String): Seq[Node] = { + + + {title} + + + } - val outputSizes = validTasks.map(_.taskMetrics.get.outputMetrics.bytesWritten.toDouble) - val outputRecords = validTasks.map(_.taskMetrics.get.outputMetrics.recordsWritten.toDouble) - val outputQuantiles = Output Size / Records +: - getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) + def simpleTitleCell(title: String): Seq[Node] = {title} - val shuffleReadBlockedTimes = validTasks.map { task => - task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime.toDouble - } - val shuffleReadBlockedQuantiles = - - - Shuffle Read Blocked Time - - +: - getFormattedTimeQuantiles(shuffleReadBlockedTimes) - - val shuffleReadTotalSizes = validTasks.map { task => - totalBytesRead(task.taskMetrics.get.shuffleReadMetrics).toDouble - } - val shuffleReadTotalRecords = validTasks.map { task => - task.taskMetrics.get.shuffleReadMetrics.recordsRead.toDouble - } - val shuffleReadTotalQuantiles = - - - Shuffle Read Size / Records - - +: - getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) - - val shuffleReadRemoteSizes = validTasks.map { task => - task.taskMetrics.get.shuffleReadMetrics.remoteBytesRead.toDouble - } - val shuffleReadRemoteQuantiles = - - - Shuffle Remote Reads - - +: - getFormattedSizeQuantiles(shuffleReadRemoteSizes) - - val shuffleWriteSizes = validTasks.map { task => - task.taskMetrics.get.shuffleWriteMetrics.bytesWritten.toDouble - } + val deserializationQuantiles = titleCell("Task Deserialization Time", + ToolTips.TASK_DESERIALIZATION_TIME) ++ timeQuantiles(metrics.executorDeserializeTime) - val shuffleWriteRecords = validTasks.map { task => - task.taskMetrics.get.shuffleWriteMetrics.recordsWritten.toDouble - } + val serviceQuantiles = simpleTitleCell("Duration") ++ timeQuantiles(metrics.executorRunTime) - val shuffleWriteQuantiles = Shuffle Write Size / Records +: - getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) + val gcQuantiles = titleCell("GC Time", ToolTips.GC_TIME) ++ timeQuantiles(metrics.jvmGcTime) - val memoryBytesSpilledSizes = validTasks.map(_.taskMetrics.get.memoryBytesSpilled.toDouble) - val memoryBytesSpilledQuantiles = Shuffle spill (memory) +: - getFormattedSizeQuantiles(memoryBytesSpilledSizes) + val serializationQuantiles = titleCell("Result Serialization Time", + ToolTips.RESULT_SERIALIZATION_TIME) ++ timeQuantiles(metrics.resultSerializationTime) - val diskBytesSpilledSizes = validTasks.map(_.taskMetrics.get.diskBytesSpilled.toDouble) - val diskBytesSpilledQuantiles = Shuffle spill (disk) +: - getFormattedSizeQuantiles(diskBytesSpilledSizes) + val gettingResultQuantiles = titleCell("Getting Result Time", ToolTips.GETTING_RESULT_TIME) ++ + timeQuantiles(metrics.gettingResultTime) - val listings: Seq[Seq[Node]] = Seq( - {serviceQuantiles}, - {schedulerDelayQuantiles}, - - {deserializationQuantiles} - - {gcQuantiles}, - - {serializationQuantiles} - , - {gettingResultQuantiles}, - - {peakExecutionMemoryQuantiles} - , - if (hasInput(stageData)) {inputQuantiles} else Nil, - if (hasOutput(stageData)) {outputQuantiles} else Nil, - if (hasShuffleRead(stageData)) { - - {shuffleReadBlockedQuantiles} - - {shuffleReadTotalQuantiles} - - {shuffleReadRemoteQuantiles} - - } else { - Nil - }, - if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) - - val quantileHeaders = Seq("Metric", "Min", "25th percentile", - "Median", "75th percentile", "Max") - // The summary table does not use CSS to stripe rows, which doesn't work with hidden - // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). - Some(UIUtils.listingTable( - quantileHeaders, - identity[Seq[Node]], - listings, - fixedWidth = true, - id = Some("task-summary-table"), - stripeRowsWithCss = false)) + val peakExecutionMemoryQuantiles = titleCell("Peak Execution Memory", + ToolTips.PEAK_EXECUTION_MEMORY) ++ sizeQuantiles(metrics.peakExecutionMemory) + + // The scheduler delay includes the network delay to send the task to the worker + // machine and to send back the result (but not the time to fetch the task result, + // if it needed to be fetched from the block manager on the worker). + val schedulerDelayQuantiles = titleCell("Scheduler Delay", ToolTips.SCHEDULER_DELAY) ++ + timeQuantiles(metrics.schedulerDelay) + + def inputQuantiles: Seq[Node] = { + simpleTitleCell("Input Size / Records") ++ + sizeQuantilesWithRecords(metrics.inputMetrics.bytesRead, metrics.inputMetrics.recordsRead) + } + + def outputQuantiles: Seq[Node] = { + simpleTitleCell("Output Size / Records") ++ + sizeQuantilesWithRecords(metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten) } + def shuffleReadBlockedQuantiles: Seq[Node] = { + titleCell("Shuffle Read Blocked Time", ToolTips.SHUFFLE_READ_BLOCKED_TIME) ++ + timeQuantiles(metrics.shuffleReadMetrics.fetchWaitTime) + } + + def shuffleReadTotalQuantiles: Seq[Node] = { + titleCell("Shuffle Read Size / Records", ToolTips.SHUFFLE_READ) ++ + sizeQuantilesWithRecords(metrics.shuffleReadMetrics.readBytes, + metrics.shuffleReadMetrics.readRecords) + } + + def shuffleReadRemoteQuantiles: Seq[Node] = { + titleCell("Shuffle Remote Reads", ToolTips.SHUFFLE_READ_REMOTE_SIZE) ++ + sizeQuantiles(metrics.shuffleReadMetrics.remoteBytesRead) + } + + def shuffleWriteQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle Write Size / Records") ++ + sizeQuantilesWithRecords(metrics.shuffleWriteMetrics.writeBytes, + metrics.shuffleWriteMetrics.writeRecords) + } + + def memoryBytesSpilledQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle spill (memory)") ++ sizeQuantiles(metrics.memoryBytesSpilled) + } + + def diskBytesSpilledQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle spill (disk)") ++ sizeQuantiles(metrics.diskBytesSpilled) + } + + val listings: Seq[Seq[Node]] = Seq( + {serviceQuantiles}, + {schedulerDelayQuantiles}, + + {deserializationQuantiles} + + {gcQuantiles}, + + {serializationQuantiles} + , + {gettingResultQuantiles}, + + {peakExecutionMemoryQuantiles} + , + if (hasInput(stageData)) {inputQuantiles} else Nil, + if (hasOutput(stageData)) {outputQuantiles} else Nil, + if (hasShuffleRead(stageData)) { + + {shuffleReadBlockedQuantiles} + + {shuffleReadTotalQuantiles} + + {shuffleReadRemoteQuantiles} + + } else { + Nil + }, + if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, + if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, + if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) + + val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile", + "Max") + // The summary table does not use CSS to stripe rows, which doesn't work with hidden + // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). + UIUtils.listingTable( + quantileHeaders, + identity[Seq[Node]], + listings, + fixedWidth = true, + id = Some("task-summary-table"), + stripeRowsWithCss = false) + } + val executorTable = new ExecutorTable(stageData, parent.store) val maybeAccumulableTable: Seq[Node] = - if (hasAccumulators) {

    Accumulators

    ++ accumulableTable } else Seq() + if (hasAccumulators(stageData)) {

    Accumulators

    ++ accumulableTable } else Seq() val aggMetrics = taskIdsInPage.contains(t.taskId) }, + Option(taskTable).map(_.dataSource.tasks).getOrElse(Nil), currentTime) ++

    Summary Metrics for {numCompleted} Completed Tasks

    ++
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++ @@ -593,10 +523,9 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val serializationTimeProportion = toProportion(serializationTime) val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L) val deserializationTimeProportion = toProportion(deserializationTime) - val gettingResultTime = getGettingResultTime(taskInfo, currentTime) + val gettingResultTime = AppStatusUtils.gettingResultTime(taskInfo) val gettingResultTimeProportion = toProportion(gettingResultTime) - val schedulerDelay = - metricsOpt.map(getSchedulerDelay(taskInfo, _, currentTime)).getOrElse(0L) + val schedulerDelay = AppStatusUtils.schedulerDelay(taskInfo) val schedulerDelayProportion = toProportion(schedulerDelay) val executorOverhead = serializationTime + deserializationTime @@ -708,7 +637,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We { if (MAX_TIMELINE_TASKS < tasks.size) { - This stage has more than the maximum number of tasks that can be shown in the + This page has more than the maximum number of tasks that can be shown in the visualization! Only the most recent {MAX_TIMELINE_TASKS} tasks (of {tasks.size} total) are shown. @@ -733,402 +662,49 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We } -private[ui] object StagePage { - private[ui] def getGettingResultTime(info: TaskData, currentTime: Long): Long = { - info.resultFetchStart match { - case Some(start) => - info.duration match { - case Some(duration) => - info.launchTime.getTime() + duration - start.getTime() - - case _ => - currentTime - start.getTime() - } - - case _ => - 0L - } - } - - private[ui] def getSchedulerDelay( - info: TaskData, - metrics: TaskMetrics, - currentTime: Long): Long = { - info.duration match { - case Some(duration) => - val executorOverhead = metrics.executorDeserializeTime + metrics.resultSerializationTime - math.max( - 0, - duration - metrics.executorRunTime - executorOverhead - - getGettingResultTime(info, currentTime)) - - case _ => - // The task is still running and the metrics like executorRunTime are not available. - 0L - } - } - -} - -private[ui] case class TaskTableRowInputData(inputSortable: Long, inputReadable: String) - -private[ui] case class TaskTableRowOutputData(outputSortable: Long, outputReadable: String) - -private[ui] case class TaskTableRowShuffleReadData( - shuffleReadBlockedTimeSortable: Long, - shuffleReadBlockedTimeReadable: String, - shuffleReadSortable: Long, - shuffleReadReadable: String, - shuffleReadRemoteSortable: Long, - shuffleReadRemoteReadable: String) - -private[ui] case class TaskTableRowShuffleWriteData( - writeTimeSortable: Long, - writeTimeReadable: String, - shuffleWriteSortable: Long, - shuffleWriteReadable: String) - -private[ui] case class TaskTableRowBytesSpilledData( - memoryBytesSpilledSortable: Long, - memoryBytesSpilledReadable: String, - diskBytesSpilledSortable: Long, - diskBytesSpilledReadable: String) - -/** - * Contains all data that needs for sorting and generating HTML. Using this one rather than - * TaskData to avoid creating duplicate contents during sorting the data. - */ -private[ui] class TaskTableRowData( - val index: Int, - val taskId: Long, - val attempt: Int, - val speculative: Boolean, - val status: String, - val taskLocality: String, - val executorId: String, - val host: String, - val launchTime: Long, - val duration: Long, - val formatDuration: String, - val schedulerDelay: Long, - val taskDeserializationTime: Long, - val gcTime: Long, - val serializationTime: Long, - val gettingResultTime: Long, - val peakExecutionMemoryUsed: Long, - val accumulators: Option[String], // HTML - val input: Option[TaskTableRowInputData], - val output: Option[TaskTableRowOutputData], - val shuffleRead: Option[TaskTableRowShuffleReadData], - val shuffleWrite: Option[TaskTableRowShuffleWriteData], - val bytesSpilled: Option[TaskTableRowBytesSpilledData], - val error: String, - val logs: Map[String, String]) - private[ui] class TaskDataSource( - tasks: Seq[TaskData], - hasAccumulators: Boolean, - hasInput: Boolean, - hasOutput: Boolean, - hasShuffleRead: Boolean, - hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean, + stage: StageData, currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, - store: AppStatusStore) extends PagedDataSource[TaskTableRowData](pageSize) { - import StagePage._ + store: AppStatusStore) extends PagedDataSource[TaskData](pageSize) { + import ApiHelper._ // Keep an internal cache of executor log maps so that long task lists render faster. private val executorIdToLogs = new HashMap[String, Map[String, String]]() - // Convert TaskData to TaskTableRowData which contains the final contents to show in the table - // so that we can avoid creating duplicate contents during sorting the data - private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) - - private var _slicedTaskIds: Set[Long] = _ + private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = data.size + override def dataSize: Int = stage.numCompleteTasks + stage.numFailedTasks + stage.numKilledTasks - override def sliceData(from: Int, to: Int): Seq[TaskTableRowData] = { - val r = data.slice(from, to) - _slicedTaskIds = r.map(_.taskId).toSet - r - } - - def slicedTaskIds: Set[Long] = _slicedTaskIds - - private def taskRow(info: TaskData): TaskTableRowData = { - val metrics = info.taskMetrics - val duration = info.duration.getOrElse(1L) - val formatDuration = info.duration.map(d => UIUtils.formatDuration(d)).getOrElse("") - val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) - val gcTime = metrics.map(_.jvmGcTime).getOrElse(0L) - val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) - val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = getGettingResultTime(info, currentTime) - - val externalAccumulableReadable = info.accumulatorUpdates.map { acc => - StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}") + override def sliceData(from: Int, to: Int): Seq[TaskData] = { + if (_tasksToShow == null) { + _tasksToShow = store.taskList(stage.stageId, stage.attemptId, from, to - from, + indexName(sortColumn), !desc) } - val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L) - - val maybeInput = metrics.map(_.inputMetrics) - val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) - val inputReadable = maybeInput - .map(m => s"${Utils.bytesToString(m.bytesRead)}") - .getOrElse("") - val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") - - val maybeOutput = metrics.map(_.outputMetrics) - val outputSortable = maybeOutput.map(_.bytesWritten).getOrElse(0L) - val outputReadable = maybeOutput - .map(m => s"${Utils.bytesToString(m.bytesWritten)}") - .getOrElse("") - val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") - - val maybeShuffleRead = metrics.map(_.shuffleReadMetrics) - val shuffleReadBlockedTimeSortable = maybeShuffleRead.map(_.fetchWaitTime).getOrElse(0L) - val shuffleReadBlockedTimeReadable = - maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") - - val totalShuffleBytes = maybeShuffleRead.map(ApiHelper.totalBytesRead) - val shuffleReadSortable = totalShuffleBytes.getOrElse(0L) - val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") - val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") - - val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) - val shuffleReadRemoteSortable = remoteShuffleBytes.getOrElse(0L) - val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") - - val maybeShuffleWrite = metrics.map(_.shuffleWriteMetrics) - val shuffleWriteSortable = maybeShuffleWrite.map(_.bytesWritten).getOrElse(0L) - val shuffleWriteReadable = maybeShuffleWrite - .map(m => s"${Utils.bytesToString(m.bytesWritten)}").getOrElse("") - val shuffleWriteRecords = maybeShuffleWrite - .map(_.recordsWritten.toString).getOrElse("") - - val maybeWriteTime = metrics.map(_.shuffleWriteMetrics.writeTime) - val writeTimeSortable = maybeWriteTime.getOrElse(0L) - val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => - if (ms == 0) "" else UIUtils.formatDuration(ms) - }.getOrElse("") - - val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled) - val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.getOrElse(0L) - val memoryBytesSpilledReadable = - maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("") - - val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled) - val diskBytesSpilledSortable = maybeDiskBytesSpilled.getOrElse(0L) - val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") - - val input = - if (hasInput) { - Some(TaskTableRowInputData(inputSortable, s"$inputReadable / $inputRecords")) - } else { - None - } - - val output = - if (hasOutput) { - Some(TaskTableRowOutputData(outputSortable, s"$outputReadable / $outputRecords")) - } else { - None - } - - val shuffleRead = - if (hasShuffleRead) { - Some(TaskTableRowShuffleReadData( - shuffleReadBlockedTimeSortable, - shuffleReadBlockedTimeReadable, - shuffleReadSortable, - s"$shuffleReadReadable / $shuffleReadRecords", - shuffleReadRemoteSortable, - shuffleReadRemoteReadable - )) - } else { - None - } - - val shuffleWrite = - if (hasShuffleWrite) { - Some(TaskTableRowShuffleWriteData( - writeTimeSortable, - writeTimeReadable, - shuffleWriteSortable, - s"$shuffleWriteReadable / $shuffleWriteRecords" - )) - } else { - None - } - - val bytesSpilled = - if (hasBytesSpilled) { - Some(TaskTableRowBytesSpilledData( - memoryBytesSpilledSortable, - memoryBytesSpilledReadable, - diskBytesSpilledSortable, - diskBytesSpilledReadable - )) - } else { - None - } - - new TaskTableRowData( - info.index, - info.taskId, - info.attempt, - info.speculative, - info.status, - info.taskLocality.toString, - info.executorId, - info.host, - info.launchTime.getTime(), - duration, - formatDuration, - schedulerDelay, - taskDeserializationTime, - gcTime, - serializationTime, - gettingResultTime, - peakExecutionMemoryUsed, - if (hasAccumulators) Some(externalAccumulableReadable.mkString("
    ")) else None, - input, - output, - shuffleRead, - shuffleWrite, - bytesSpilled, - info.errorMessage.getOrElse(""), - executorLogs(info.executorId)) + _tasksToShow } - private def executorLogs(id: String): Map[String, String] = { + def tasks: Seq[TaskData] = _tasksToShow + + def executorLogs(id: String): Map[String, String] = { executorIdToLogs.getOrElseUpdate(id, store.asOption(store.executorSummary(id)).map(_.executorLogs).getOrElse(Map.empty)) } - /** - * Return Ordering according to sortColumn and desc - */ - private def ordering(sortColumn: String, desc: Boolean): Ordering[TaskTableRowData] = { - val ordering: Ordering[TaskTableRowData] = sortColumn match { - case "Index" => Ordering.by(_.index) - case "ID" => Ordering.by(_.taskId) - case "Attempt" => Ordering.by(_.attempt) - case "Status" => Ordering.by(_.status) - case "Locality Level" => Ordering.by(_.taskLocality) - case "Executor ID" => Ordering.by(_.executorId) - case "Host" => Ordering.by(_.host) - case "Launch Time" => Ordering.by(_.launchTime) - case "Duration" => Ordering.by(_.duration) - case "Scheduler Delay" => Ordering.by(_.schedulerDelay) - case "Task Deserialization Time" => Ordering.by(_.taskDeserializationTime) - case "GC Time" => Ordering.by(_.gcTime) - case "Result Serialization Time" => Ordering.by(_.serializationTime) - case "Getting Result Time" => Ordering.by(_.gettingResultTime) - case "Peak Execution Memory" => Ordering.by(_.peakExecutionMemoryUsed) - case "Accumulators" => - if (hasAccumulators) { - Ordering.by(_.accumulators.get) - } else { - throw new IllegalArgumentException( - "Cannot sort by Accumulators because of no accumulators") - } - case "Input Size / Records" => - if (hasInput) { - Ordering.by(_.input.get.inputSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Input Size / Records because of no inputs") - } - case "Output Size / Records" => - if (hasOutput) { - Ordering.by(_.output.get.outputSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Output Size / Records because of no outputs") - } - // ShuffleRead - case "Shuffle Read Blocked Time" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadBlockedTimeSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Read Blocked Time because of no shuffle reads") - } - case "Shuffle Read Size / Records" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Read Size / Records because of no shuffle reads") - } - case "Shuffle Remote Reads" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadRemoteSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Remote Reads because of no shuffle reads") - } - // ShuffleWrite - case "Write Time" => - if (hasShuffleWrite) { - Ordering.by(_.shuffleWrite.get.writeTimeSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Write Time because of no shuffle writes") - } - case "Shuffle Write Size / Records" => - if (hasShuffleWrite) { - Ordering.by(_.shuffleWrite.get.shuffleWriteSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Write Size / Records because of no shuffle writes") - } - // BytesSpilled - case "Shuffle Spill (Memory)" => - if (hasBytesSpilled) { - Ordering.by(_.bytesSpilled.get.memoryBytesSpilledSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Spill (Memory) because of no spills") - } - case "Shuffle Spill (Disk)" => - if (hasBytesSpilled) { - Ordering.by(_.bytesSpilled.get.diskBytesSpilledSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Spill (Disk) because of no spills") - } - case "Errors" => Ordering.by(_.error) - case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") - } - if (desc) { - ordering.reverse - } else { - ordering - } - } - } private[ui] class TaskPagedTable( - conf: SparkConf, + stage: StageData, basePath: String, - data: Seq[TaskData], - hasAccumulators: Boolean, - hasInput: Boolean, - hasOutput: Boolean, - hasShuffleRead: Boolean, - hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean, currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, - store: AppStatusStore) extends PagedTable[TaskTableRowData] { + store: AppStatusStore) extends PagedTable[TaskData] { + + import ApiHelper._ override def tableId: String = "task-table" @@ -1142,13 +718,7 @@ private[ui] class TaskPagedTable( override def pageNumberFormField: String = "task.page" override val dataSource: TaskDataSource = new TaskDataSource( - data, - hasAccumulators, - hasInput, - hasOutput, - hasShuffleRead, - hasShuffleWrite, - hasBytesSpilled, + stage, currentTime, pageSize, sortColumn, @@ -1180,22 +750,22 @@ private[ui] class TaskPagedTable( ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME), ("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ - {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ - {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++ - {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ - {if (hasShuffleRead) { + {if (hasAccumulators(stage)) Seq(("Accumulators", "")) else Nil} ++ + {if (hasInput(stage)) Seq(("Input Size / Records", "")) else Nil} ++ + {if (hasOutput(stage)) Seq(("Output Size / Records", "")) else Nil} ++ + {if (hasShuffleRead(stage)) { Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), ("Shuffle Read Size / Records", ""), ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) } else { Nil }} ++ - {if (hasShuffleWrite) { + {if (hasShuffleWrite(stage)) { Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) } else { Nil }} ++ - {if (hasBytesSpilled) { + {if (hasBytesSpilled(stage)) { Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) } else { Nil @@ -1237,7 +807,17 @@ private[ui] class TaskPagedTable( {headerRow} } - def row(task: TaskTableRowData): Seq[Node] = { + def row(task: TaskData): Seq[Node] = { + def formatDuration(value: Option[Long], hideZero: Boolean = false): String = { + value.map { v => + if (v > 0 || !hideZero) UIUtils.formatDuration(v) else "" + }.getOrElse("") + } + + def formatBytes(value: Option[Long]): String = { + Utils.bytesToString(value.getOrElse(0L)) + } + {task.index} {task.taskId} @@ -1249,62 +829,98 @@ private[ui] class TaskPagedTable(
    {task.host}
    { - task.logs.map { + dataSource.executorLogs(task.executorId).map { case (logName, logUrl) => } }
    - {UIUtils.formatDate(new Date(task.launchTime))} - {task.formatDuration} + {UIUtils.formatDate(task.launchTime)} + {formatDuration(task.duration)} - {UIUtils.formatDuration(task.schedulerDelay)} + {UIUtils.formatDuration(AppStatusUtils.schedulerDelay(task))} - {UIUtils.formatDuration(task.taskDeserializationTime)} + {formatDuration(task.taskMetrics.map(_.executorDeserializeTime))} - {if (task.gcTime > 0) UIUtils.formatDuration(task.gcTime) else ""} + {formatDuration(task.taskMetrics.map(_.jvmGcTime), hideZero = true)} - {UIUtils.formatDuration(task.serializationTime)} + {formatDuration(task.taskMetrics.map(_.resultSerializationTime))} - {UIUtils.formatDuration(task.gettingResultTime)} + {UIUtils.formatDuration(AppStatusUtils.gettingResultTime(task))} - {Utils.bytesToString(task.peakExecutionMemoryUsed)} + {formatBytes(task.taskMetrics.map(_.peakExecutionMemory))} - {if (task.accumulators.nonEmpty) { - {Unparsed(task.accumulators.get)} + {if (hasAccumulators(stage)) { + accumulatorsInfo(task) }} - {if (task.input.nonEmpty) { - {task.input.get.inputReadable} + {if (hasInput(stage)) { + metricInfo(task) { m => + val bytesRead = Utils.bytesToString(m.inputMetrics.bytesRead) + val records = m.inputMetrics.recordsRead + {bytesRead} / {records} + } }} - {if (task.output.nonEmpty) { - {task.output.get.outputReadable} + {if (hasOutput(stage)) { + metricInfo(task) { m => + val bytesWritten = Utils.bytesToString(m.outputMetrics.bytesWritten) + val records = m.outputMetrics.recordsWritten + {bytesWritten} / {records} + } }} - {if (task.shuffleRead.nonEmpty) { + {if (hasShuffleRead(stage)) { - {task.shuffleRead.get.shuffleReadBlockedTimeReadable} + {formatDuration(task.taskMetrics.map(_.shuffleReadMetrics.fetchWaitTime))} - {task.shuffleRead.get.shuffleReadReadable} + { + metricInfo(task) { m => + val bytesRead = Utils.bytesToString(totalBytesRead(m.shuffleReadMetrics)) + val records = m.shuffleReadMetrics.recordsRead + Unparsed(s"$bytesRead / $records") + } + } - {task.shuffleRead.get.shuffleReadRemoteReadable} + {formatBytes(task.taskMetrics.map(_.shuffleReadMetrics.remoteBytesRead))} }} - {if (task.shuffleWrite.nonEmpty) { - {task.shuffleWrite.get.writeTimeReadable} - {task.shuffleWrite.get.shuffleWriteReadable} + {if (hasShuffleWrite(stage)) { + { + formatDuration( + task.taskMetrics.map { m => + TimeUnit.NANOSECONDS.toMillis(m.shuffleWriteMetrics.writeTime) + }, + hideZero = true) + } + { + metricInfo(task) { m => + val bytesWritten = Utils.bytesToString(m.shuffleWriteMetrics.bytesWritten) + val records = m.shuffleWriteMetrics.recordsWritten + Unparsed(s"$bytesWritten / $records") + } + } }} - {if (task.bytesSpilled.nonEmpty) { - {task.bytesSpilled.get.memoryBytesSpilledReadable} - {task.bytesSpilled.get.diskBytesSpilledReadable} + {if (hasBytesSpilled(stage)) { + {formatBytes(task.taskMetrics.map(_.memoryBytesSpilled))} + {formatBytes(task.taskMetrics.map(_.diskBytesSpilled))} }} - {errorMessageCell(task.error)} + {errorMessageCell(task.errorMessage.getOrElse(""))} } + private def accumulatorsInfo(task: TaskData): Seq[Node] = { + task.accumulatorUpdates.map { acc => + Unparsed(StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}")) + } + } + + private def metricInfo(task: TaskData)(fn: TaskMetrics => Seq[Node]): Seq[Node] = { + task.taskMetrics.map(fn).getOrElse(Nil) + } + private def errorMessageCell(error: String): Seq[Node] = { val isMultiline = error.indexOf('\n') >= 0 // Display the first line by default @@ -1333,6 +949,36 @@ private[ui] class TaskPagedTable( private object ApiHelper { + + private val COLUMN_TO_INDEX = Map( + "ID" -> null.asInstanceOf[String], + "Index" -> TaskIndexNames.TASK_INDEX, + "Attempt" -> TaskIndexNames.ATTEMPT, + "Status" -> TaskIndexNames.STATUS, + "Locality Level" -> TaskIndexNames.LOCALITY, + "Executor ID / Host" -> TaskIndexNames.EXECUTOR, + "Launch Time" -> TaskIndexNames.LAUNCH_TIME, + "Duration" -> TaskIndexNames.DURATION, + "Scheduler Delay" -> TaskIndexNames.SCHEDULER_DELAY, + "Task Deserialization Time" -> TaskIndexNames.DESER_TIME, + "GC Time" -> TaskIndexNames.GC_TIME, + "Result Serialization Time" -> TaskIndexNames.SER_TIME, + "Getting Result Time" -> TaskIndexNames.GETTING_RESULT_TIME, + "Peak Execution Memory" -> TaskIndexNames.PEAK_MEM, + "Accumulators" -> TaskIndexNames.ACCUMULATORS, + "Input Size / Records" -> TaskIndexNames.INPUT_SIZE, + "Output Size / Records" -> TaskIndexNames.OUTPUT_SIZE, + "Shuffle Read Blocked Time" -> TaskIndexNames.SHUFFLE_READ_TIME, + "Shuffle Read Size / Records" -> TaskIndexNames.SHUFFLE_TOTAL_READS, + "Shuffle Remote Reads" -> TaskIndexNames.SHUFFLE_REMOTE_READS, + "Write Time" -> TaskIndexNames.SHUFFLE_WRITE_TIME, + "Shuffle Write Size / Records" -> TaskIndexNames.SHUFFLE_WRITE_SIZE, + "Shuffle Spill (Memory)" -> TaskIndexNames.MEM_SPILL, + "Shuffle Spill (Disk)" -> TaskIndexNames.DISK_SPILL, + "Errors" -> TaskIndexNames.ERROR) + + def hasAccumulators(stageData: StageData): Boolean = stageData.accumulatorUpdates.size > 0 + def hasInput(stageData: StageData): Boolean = stageData.inputBytes > 0 def hasOutput(stageData: StageData): Boolean = stageData.outputBytes > 0 @@ -1349,4 +995,11 @@ private object ApiHelper { metrics.localBytesRead + metrics.remoteBytesRead } + def indexName(sortColumn: String): Option[String] = { + COLUMN_TO_INDEX.get(sortColumn) match { + case Some(v) => Option(v) + case _ => throw new IllegalArgumentException(s"Invalid sort column: $sortColumn") + } + } + } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json index f8e27703c0def..5c42ac1d87f4c 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 2010.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 7.0 ], "resultSerializationTime" : [ 0.0, 0.0, 2.0 ], + "gettingResultTime" : [ 0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 2.0, 6.0, 53.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json index a28bda16a956e..e6b705989cc97 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 1034.0, 1034.0, 1034.0, 1034.0, 1034.0 ], "jvmGcTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "gettingResultTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 4.0, 4.0, 6.0, 7.0, 9.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json index ede3eaed1d1d2..788f28cf7b365 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 2010.0, 2065.0, 2065.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 0.0, 5.0, 7.0 ], "resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 1.0 ], + "gettingResultTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 2.0, 4.0, 6.0, 13.0, 40.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index b8c84e24c2c3f..ca66b6b9db890 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -213,45 +213,42 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { s1Tasks.foreach { task => check[TaskDataWrapper](task.taskId) { wrapper => - assert(wrapper.info.taskId === task.taskId) + assert(wrapper.taskId === task.taskId) assert(wrapper.stageId === stages.head.stageId) - assert(wrapper.stageAttemptId === stages.head.attemptNumber) - assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptNumber))) - - val runtime = Array[AnyRef](stages.head.stageId: JInteger, - stages.head.attemptNumber: JInteger, - -1L: JLong) - assert(Arrays.equals(wrapper.runtime, runtime)) - - assert(wrapper.info.index === task.index) - assert(wrapper.info.attempt === task.attemptNumber) - assert(wrapper.info.launchTime === new Date(task.launchTime)) - assert(wrapper.info.executorId === task.executorId) - assert(wrapper.info.host === task.host) - assert(wrapper.info.status === task.status) - assert(wrapper.info.taskLocality === task.taskLocality.toString()) - assert(wrapper.info.speculative === task.speculative) + assert(wrapper.stageAttemptId === stages.head.attemptId) + assert(wrapper.index === task.index) + assert(wrapper.attempt === task.attemptNumber) + assert(wrapper.launchTime === task.launchTime) + assert(wrapper.executorId === task.executorId) + assert(wrapper.host === task.host) + assert(wrapper.status === task.status) + assert(wrapper.taskLocality === task.taskLocality.toString()) + assert(wrapper.speculative === task.speculative) } } - // Send executor metrics update. Only update one metric to avoid a lot of boilerplate code. - s1Tasks.foreach { task => - val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED), - Some(1L), None, true, false, None) - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( - task.executorId, - Seq((task.taskId, stages.head.stageId, stages.head.attemptNumber, Seq(accum))))) - } + // Send two executor metrics update. Only update one metric to avoid a lot of boilerplate code. + // The tasks are distributed among the two executors, so the executor-level metrics should + // hold half of the cummulative value of the metric being updated. + Seq(1L, 2L).foreach { value => + s1Tasks.foreach { task => + val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED), + Some(value), None, true, false, None) + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( + task.executorId, + Seq((task.taskId, stages.head.stageId, stages.head.attemptNumber, Seq(accum))))) + } - check[StageDataWrapper](key(stages.head)) { stage => - assert(stage.info.memoryBytesSpilled === s1Tasks.size) - } + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.memoryBytesSpilled === s1Tasks.size * value) + } - val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") - .first(key(stages.head)).last(key(stages.head)).asScala.toSeq - assert(execs.size > 0) - execs.foreach { exec => - assert(exec.info.memoryBytesSpilled === s1Tasks.size / 2) + val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)).last(key(stages.head)).asScala.toSeq + assert(execs.size > 0) + execs.foreach { exec => + assert(exec.info.memoryBytesSpilled === s1Tasks.size * value / 2) + } } // Fail one of the tasks, re-start it. @@ -278,13 +275,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](s1Tasks.head.taskId) { task => - assert(task.info.status === s1Tasks.head.status) - assert(task.info.errorMessage == Some(TaskResultLost.toErrorString)) + assert(task.status === s1Tasks.head.status) + assert(task.errorMessage == Some(TaskResultLost.toErrorString)) } check[TaskDataWrapper](reattempt.taskId) { task => - assert(task.info.index === s1Tasks.head.index) - assert(task.info.attempt === reattempt.attemptNumber) + assert(task.index === s1Tasks.head.index) + assert(task.attempt === reattempt.attemptNumber) } // Kill one task, restart it. @@ -306,8 +303,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](killed.taskId) { task => - assert(task.info.index === killed.index) - assert(task.info.errorMessage === Some("killed")) + assert(task.index === killed.index) + assert(task.errorMessage === Some("killed")) } // Start a new attempt and finish it with TaskCommitDenied, make sure it's handled like a kill. @@ -334,8 +331,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](denied.taskId) { task => - assert(task.info.index === killed.index) - assert(task.info.errorMessage === Some(denyReason.toErrorString)) + assert(task.index === killed.index) + assert(task.errorMessage === Some(denyReason.toErrorString)) } // Start a new attempt. @@ -373,10 +370,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { pending.foreach { task => check[TaskDataWrapper](task.taskId) { wrapper => - assert(wrapper.info.errorMessage === None) - assert(wrapper.info.taskMetrics.get.executorCpuTime === 2L) - assert(wrapper.info.taskMetrics.get.executorRunTime === 4L) - assert(wrapper.info.duration === Some(task.duration)) + assert(wrapper.errorMessage === None) + assert(wrapper.executorCpuTime === 2L) + assert(wrapper.executorRunTime === 4L) + assert(wrapper.duration === task.duration) } } @@ -894,6 +891,23 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(store.count(classOf[StageDataWrapper]) === 3) assert(store.count(classOf[RDDOperationGraphWrapper]) === 3) + val dropped = stages.drop(1).head + + // Cache some quantiles by calling AppStatusStore.taskSummary(). For quantiles to be + // calculcated, we need at least one finished task. + time += 1 + val task = createTasks(1, Array("1")).head + listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptId, task)) + + time += 1 + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptId, + "taskType", Success, task, null)) + + new AppStatusStore(store) + .taskSummary(dropped.stageId, dropped.attemptId, Array(0.25d, 0.50d, 0.75d)) + assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 3) + stages.drop(1).foreach { s => time += 1 s.completionTime = Some(time) @@ -905,6 +919,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { intercept[NoSuchElementException] { store.read(classOf[StageDataWrapper], Array(2, 0)) } + assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 0) val attempt2 = new StageInfo(3, 1, "stage3", 4, Nil, Nil, "details3") time += 1 diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala new file mode 100644 index 0000000000000..92f90f3d96ddf --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import org.apache.spark.SparkFunSuite +import org.apache.spark.status.api.v1.TaskMetricDistributions +import org.apache.spark.util.Distribution +import org.apache.spark.util.kvstore._ + +class AppStatusStoreSuite extends SparkFunSuite { + + private val uiQuantiles = Array(0.0, 0.25, 0.5, 0.75, 1.0) + private val stageId = 1 + private val attemptId = 1 + + test("quantile calculation: 1 task") { + compareQuantiles(1, uiQuantiles) + } + + test("quantile calculation: few tasks") { + compareQuantiles(4, uiQuantiles) + } + + test("quantile calculation: more tasks") { + compareQuantiles(100, uiQuantiles) + } + + test("quantile calculation: lots of tasks") { + compareQuantiles(4096, uiQuantiles) + } + + test("quantile calculation: custom quantiles") { + compareQuantiles(4096, Array(0.01, 0.33, 0.5, 0.42, 0.69, 0.99)) + } + + test("quantile cache") { + val store = new InMemoryStore() + (0 until 4096).foreach { i => store.write(newTaskData(i)) } + + val appStore = new AppStatusStore(store) + + appStore.taskSummary(stageId, attemptId, Array(0.13d)) + intercept[NoSuchElementException] { + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "13")) + } + + appStore.taskSummary(stageId, attemptId, Array(0.25d)) + val d1 = store.read(classOf[CachedQuantile], Array(stageId, attemptId, "25")) + + // Add a new task to force the cached quantile to be evicted, and make sure it's updated. + store.write(newTaskData(4096)) + appStore.taskSummary(stageId, attemptId, Array(0.25d, 0.50d, 0.73d)) + + val d2 = store.read(classOf[CachedQuantile], Array(stageId, attemptId, "25")) + assert(d1.taskCount != d2.taskCount) + + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "50")) + intercept[NoSuchElementException] { + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "73")) + } + + assert(store.count(classOf[CachedQuantile]) === 2) + } + + private def compareQuantiles(count: Int, quantiles: Array[Double]): Unit = { + val store = new InMemoryStore() + val values = (0 until count).map { i => + val task = newTaskData(i) + store.write(task) + i.toDouble + }.toArray + + val summary = new AppStatusStore(store).taskSummary(stageId, attemptId, quantiles).get + val dist = new Distribution(values, 0, values.length).getQuantiles(quantiles.sorted) + + dist.zip(summary.executorRunTime).foreach { case (expected, actual) => + assert(expected === actual) + } + } + + private def newTaskData(i: Int): TaskDataWrapper = { + new TaskDataWrapper( + i, i, i, i, i, i, i.toString, i.toString, i.toString, i.toString, false, Nil, None, + i, i, i, i, i, i, i, i, i, i, + i, i, i, i, i, i, i, i, i, i, + i, i, i, i, stageId, attemptId) + } + +} diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 661d0d48d2f37..0aeddf730cd35 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.config._ import org.apache.spark.ui.jobs.{StagePage, StagesTab} class StagePageSuite extends SparkFunSuite with LocalSparkContext { @@ -35,15 +36,13 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { private val peakExecutionMemory = 10 test("peak execution memory should displayed") { - val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) + val html = renderStagePage().toString().toLowerCase(Locale.ROOT) val targetString = "peak execution memory" assert(html.contains(targetString)) } test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { - val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) + val html = renderStagePage().toString().toLowerCase(Locale.ROOT) // verify min/25/50/75/max show task value not cumulative values assert(html.contains(s"$peakExecutionMemory.0 b" * 5)) } @@ -52,7 +51,8 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { * Render a stage page started with the given conf and return the HTML. * This also runs a dummy stage to populate the page with useful content. */ - private def renderStagePage(conf: SparkConf): Seq[Node] = { + private def renderStagePage(): Seq[Node] = { + val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L) val statusStore = AppStatusStore.createLiveStore(conf) val listener = statusStore.listener.get diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 7bdd3fac773a3..e2fa5754afaee 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -93,7 +93,7 @@ This file is divided into 3 sections: - + From 0552c36e02434c60dad82024334d291f6008b822 Mon Sep 17 00:00:00 2001 From: wuyi5 Date: Thu, 11 Jan 2018 22:17:15 +0900 Subject: [PATCH 065/774] [SPARK-22967][TESTS] Fix VersionSuite's unit tests by change Windows path into URI path ## What changes were proposed in this pull request? Two unit test will fail due to Windows format path: 1.test(s"$version: read avro file containing decimal") ``` org.apache.hadoop.hive.ql.metadata.HiveException: MetaException(message:java.lang.IllegalArgumentException: Can not create a Path from an empty string); ``` 2.test(s"$version: SPARK-17920: Insert into/overwrite avro table") ``` Unable to infer the schema. The schema specification is required to create the table `default`.`tab2`.; org.apache.spark.sql.AnalysisException: Unable to infer the schema. The schema specification is required to create the table `default`.`tab2`.; ``` This pr fix these two unit test by change Windows path into URI path. ## How was this patch tested? Existed. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: wuyi5 Closes #20199 from Ngone51/SPARK-22967. --- .../org/apache/spark/sql/hive/client/VersionsSuite.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index ff90e9dda5f7c..e64389e56b5a1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -811,7 +811,7 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: read avro file containing decimal") { val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") - val location = new File(url.getFile) + val location = new File(url.getFile).toURI.toString val tableName = "tab1" val avroSchema = @@ -851,6 +851,8 @@ class VersionsSuite extends SparkFunSuite with Logging { } test(s"$version: SPARK-17920: Insert into/overwrite avro table") { + // skipped because it's failed in the condition on Windows + assume(!(Utils.isWindows && version == "0.12")) withTempDir { dir => val avroSchema = """ @@ -875,10 +877,10 @@ class VersionsSuite extends SparkFunSuite with Logging { val writer = new PrintWriter(schemaFile) writer.write(avroSchema) writer.close() - val schemaPath = schemaFile.getCanonicalPath + val schemaPath = schemaFile.toURI.toString val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") - val srcLocation = new File(url.getFile).getCanonicalPath + val srcLocation = new File(url.getFile).toURI.toString val destTableName = "tab1" val srcTableName = "tab2" From 76892bcf2c08efd7e9c5b16d377e623d82fe695e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Jan 2018 21:32:36 +0800 Subject: [PATCH 066/774] [SPARK-23000][TEST-HADOOP2.6] Fix Flaky test suite DataSourceWithHiveMetastoreCatalogSuite ## What changes were proposed in this pull request? The Spark 2.3 branch still failed due to the flaky test suite `DataSourceWithHiveMetastoreCatalogSuite `. https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.3-test-sbt-hadoop-2.6/ Although https://github.com/apache/spark/pull/20207 is unable to reproduce it in Spark 2.3, it sounds like the current DB of Spark's Catalog is changed based on the following stacktrace. Thus, we just need to reset it. ``` [info] DataSourceWithHiveMetastoreCatalogSuite: 02:40:39.486 ERROR org.apache.hadoop.hive.ql.parse.CalcitePlanner: org.apache.hadoop.hive.ql.parse.SemanticException: Line 1:14 Table not found 't' at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.getMetaData(SemanticAnalyzer.java:1594) at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.getMetaData(SemanticAnalyzer.java:1545) at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.genResolvedParseTree(SemanticAnalyzer.java:10077) at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.analyzeInternal(SemanticAnalyzer.java:10128) at org.apache.hadoop.hive.ql.parse.CalcitePlanner.analyzeInternal(CalcitePlanner.java:209) at org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer.analyze(BaseSemanticAnalyzer.java:227) at org.apache.hadoop.hive.ql.Driver.compile(Driver.java:424) at org.apache.hadoop.hive.ql.Driver.compile(Driver.java:308) at org.apache.hadoop.hive.ql.Driver.compileInternal(Driver.java:1122) at org.apache.hadoop.hive.ql.Driver.runInternal(Driver.java:1170) at org.apache.hadoop.hive.ql.Driver.run(Driver.java:1059) at org.apache.hadoop.hive.ql.Driver.run(Driver.java:1049) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:694) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$withHiveState$1.apply(HiveClientImpl.scala:272) at org.apache.spark.sql.hive.client.HiveClientImpl.liftedTree1$1(HiveClientImpl.scala:210) at org.apache.spark.sql.hive.client.HiveClientImpl.retryLocked(HiveClientImpl.scala:209) at org.apache.spark.sql.hive.client.HiveClientImpl.withHiveState(HiveClientImpl.scala:255) at org.apache.spark.sql.hive.client.HiveClientImpl.runHive(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl.runSqlHive(HiveClientImpl.scala:673) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1$$anonfun$apply$mcV$sp$3.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:185) at org.apache.spark.sql.test.SQLTestUtilsBase$class.withTable(SQLTestUtils.scala:273) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.withTable(HiveMetastoreCatalogSuite.scala:139) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:163) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:163) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:163) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:186) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:68) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:183) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:196) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:196) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:289) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:196) at org.scalatest.FunSuite.runTest(FunSuite.scala:1560) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:229) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:229) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:396) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:384) at scala.collection.immutable.List.foreach(List.scala:381) at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:384) at org.scalatest.SuperEngine.org$scalatest$SuperEngine$$runTestsInBranch(Engine.scala:379) at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:461) at org.scalatest.FunSuiteLike$class.runTests(FunSuiteLike.scala:229) at org.scalatest.FunSuite.runTests(FunSuite.scala:1560) at org.scalatest.Suite$class.run(Suite.scala:1147) at org.scalatest.FunSuite.org$scalatest$FunSuiteLike$$super$run(FunSuite.scala:1560) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:233) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:233) at org.scalatest.SuperEngine.runImpl(Engine.scala:521) at org.scalatest.FunSuiteLike$class.run(FunSuiteLike.scala:233) at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:31) at org.scalatest.BeforeAndAfterAll$class.liftedTree1$1(BeforeAndAfterAll.scala:213) at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:210) at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:31) at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:314) at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:480) at sbt.ForkMain$Run$2.call(ForkMain.java:296) at sbt.ForkMain$Run$2.call(ForkMain.java:286) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` ## How was this patch tested? N/A Author: gatorsmile Closes #20218 from gatorsmile/testFixAgain. --- .../org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index cf4ce83124d88..ba9b944e4a055 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -148,6 +148,7 @@ class DataSourceWithHiveMetastoreCatalogSuite override def beforeAll(): Unit = { super.beforeAll() + sparkSession.sessionState.catalog.reset() sparkSession.metadataHive.reset() } From b46e58b74c82dac37b7b92284ea3714919c5a886 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Jan 2018 22:33:42 +0900 Subject: [PATCH 067/774] [SPARK-19732][FOLLOW-UP] Document behavior changes made in na.fill and fillna ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/18164 introduces the behavior changes. We need to document it. ## How was this patch tested? N/A Author: gatorsmile Closes #20234 from gatorsmile/docBehaviorChange. --- docs/sql-programming-guide.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 72f79d6909ecc..258c769ff593b 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1788,12 +1788,10 @@ options. Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. - - - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). - - - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - - - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + - In PySpark, `na.fill()` or `fillna` also accepts boolean and replaces nulls with booleans. In prior Spark versions, PySpark just ignores it and returns the original Dataset/DataFrame. + - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). + - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. + - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. ## Upgrading From Spark SQL 2.1 to 2.2 From 6d230dccf65300651f989392159d84bfaf08f18f Mon Sep 17 00:00:00 2001 From: FanDonglai Date: Thu, 11 Jan 2018 09:06:40 -0600 Subject: [PATCH 068/774] Update PageRank.scala ## What changes were proposed in this pull request? Hi, acording to code below, "if (id == src) (0.0, Double.NegativeInfinity) else (0.0, 0.0)" I think the comment can be wrong ## How was this patch tested? Please review http://spark.apache.org/contributing.html before opening a pull request. Author: FanDonglai Closes #20220 from ddna1021/master. --- .../src/main/scala/org/apache/spark/graphx/lib/PageRank.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index fd7b7f7c1c487..ebd65e8320e5c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -303,7 +303,7 @@ object PageRank extends Logging { val src: VertexId = srcId.getOrElse(-1L) // Initialize the pagerankGraph with each edge attribute - // having weight 1/outDegree and each vertex with attribute 1.0. + // having weight 1/outDegree and each vertex with attribute 0. val pagerankGraph: Graph[(Double, Double), Double] = graph // Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { From 0b2eefb674151a0af64806728b38d9410da552ec Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 11 Jan 2018 10:37:35 -0800 Subject: [PATCH 069/774] [SPARK-22994][K8S] Use a single image for all Spark containers. This change allows a user to submit a Spark application on kubernetes having to provide a single image, instead of one image for each type of container. The image's entry point now takes an extra argument that identifies the process that is being started. The configuration still allows the user to provide different images for each container type if they so desire. On top of that, the entry point was simplified a bit to share more code; mainly, the same env variable is used to propagate the user-defined classpath to the different containers. Aside from being modified to match the new behavior, the 'build-push-docker-images.sh' script was renamed to 'docker-image-tool.sh' to more closely match its purpose; the old name was a little awkward and now also not entirely correct, since there is a single image. It was also moved to 'bin' since it's not necessarily an admin tool. Docs have been updated to match the new behavior. Tested locally with minikube. Author: Marcelo Vanzin Closes #20192 from vanzin/SPARK-22994. --- .../docker-image-tool.sh | 68 ++++++------- docs/running-on-kubernetes.md | 58 +++++------ .../org/apache/spark/deploy/k8s/Config.scala | 17 ++-- .../apache/spark/deploy/k8s/Constants.scala | 3 +- .../deploy/k8s/InitContainerBootstrap.scala | 1 + .../steps/BasicDriverConfigurationStep.scala | 3 +- .../cluster/k8s/ExecutorPodFactory.scala | 3 +- .../DriverConfigOrchestratorSuite.scala | 12 +-- .../BasicDriverConfigurationStepSuite.scala | 4 +- ...InitContainerConfigOrchestratorSuite.scala | 4 +- .../cluster/k8s/ExecutorPodFactorySuite.scala | 4 +- .../src/main/dockerfiles/driver/Dockerfile | 35 ------- .../src/main/dockerfiles/executor/Dockerfile | 35 ------- .../dockerfiles/init-container/Dockerfile | 24 ----- .../main/dockerfiles/spark-base/entrypoint.sh | 37 ------- .../{spark-base => spark}/Dockerfile | 10 +- .../src/main/dockerfiles/spark/entrypoint.sh | 97 +++++++++++++++++++ 17 files changed, 189 insertions(+), 226 deletions(-) rename sbin/build-push-docker-images.sh => bin/docker-image-tool.sh (63%) delete mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile delete mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile delete mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile delete mode 100755 resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh rename resource-managers/kubernetes/docker/src/main/dockerfiles/{spark-base => spark}/Dockerfile (87%) create mode 100755 resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh diff --git a/sbin/build-push-docker-images.sh b/bin/docker-image-tool.sh similarity index 63% rename from sbin/build-push-docker-images.sh rename to bin/docker-image-tool.sh index b9532597419a5..071406336d1b1 100755 --- a/sbin/build-push-docker-images.sh +++ b/bin/docker-image-tool.sh @@ -24,29 +24,11 @@ function error { exit 1 } -# Detect whether this is a git clone or a Spark distribution and adjust paths -# accordingly. if [ -z "${SPARK_HOME}" ]; then SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi . "${SPARK_HOME}/bin/load-spark-env.sh" -if [ -f "$SPARK_HOME/RELEASE" ]; then - IMG_PATH="kubernetes/dockerfiles" - SPARK_JARS="jars" -else - IMG_PATH="resource-managers/kubernetes/docker/src/main/dockerfiles" - SPARK_JARS="assembly/target/scala-$SPARK_SCALA_VERSION/jars" -fi - -if [ ! -d "$IMG_PATH" ]; then - error "Cannot find docker images. This script must be run from a runnable distribution of Apache Spark." -fi - -declare -A path=( [spark-driver]="$IMG_PATH/driver/Dockerfile" \ - [spark-executor]="$IMG_PATH/executor/Dockerfile" \ - [spark-init]="$IMG_PATH/init-container/Dockerfile" ) - function image_ref { local image="$1" local add_repo="${2:-1}" @@ -60,35 +42,49 @@ function image_ref { } function build { - docker build \ - --build-arg "spark_jars=$SPARK_JARS" \ - --build-arg "img_path=$IMG_PATH" \ - -t spark-base \ - -f "$IMG_PATH/spark-base/Dockerfile" . - for image in "${!path[@]}"; do - docker build -t "$(image_ref $image)" -f ${path[$image]} . - done + local BUILD_ARGS + local IMG_PATH + + if [ ! -f "$SPARK_HOME/RELEASE" ]; then + # Set image build arguments accordingly if this is a source repo and not a distribution archive. + IMG_PATH=resource-managers/kubernetes/docker/src/main/dockerfiles + BUILD_ARGS=( + --build-arg + img_path=$IMG_PATH + --build-arg + spark_jars=assembly/target/scala-$SPARK_SCALA_VERSION/jars + ) + else + # Not passed as an argument to docker, but used to validate the Spark directory. + IMG_PATH="kubernetes/dockerfiles" + fi + + if [ ! -d "$IMG_PATH" ]; then + error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." + fi + + docker build "${BUILD_ARGS[@]}" \ + -t $(image_ref spark) \ + -f "$IMG_PATH/spark/Dockerfile" . } function push { - for image in "${!path[@]}"; do - docker push "$(image_ref $image)" - done + docker push "$(image_ref spark)" } function usage { cat < -t my-tag build - ./sbin/build-push-docker-images.sh -r -t my-tag push - -Docker files are under the `kubernetes/dockerfiles/` directory and can be customized further before -building using the supplied script, or manually. + ./bin/docker-image-tool.sh -r -t my-tag build + ./bin/docker-image-tool.sh -r -t my-tag push ## Cluster Mode @@ -79,8 +76,7 @@ $ bin/spark-submit \ --name spark-pi \ --class org.apache.spark.examples.SparkPi \ --conf spark.executor.instances=5 \ - --conf spark.kubernetes.driver.container.image= \ - --conf spark.kubernetes.executor.container.image= \ + --conf spark.kubernetes.container.image= \ local:///path/to/examples.jar ``` @@ -126,13 +122,7 @@ Those dependencies can be added to the classpath by referencing them with `local ### Using Remote Dependencies When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods need a Kubernetes [init-container](https://kubernetes.io/docs/concepts/workloads/pods/init-containers/) for downloading -the dependencies so the driver and executor containers can use them locally. This requires users to specify the container -image for the init-container using the configuration property `spark.kubernetes.initContainer.image`. For example, users -simply add the following option to the `spark-submit` command to specify the init-container image: - -``` ---conf spark.kubernetes.initContainer.image= -``` +the dependencies so the driver and executor containers can use them locally. The init-container handles remote dependencies specified in `spark.jars` (or the `--jars` option of `spark-submit`) and `spark.files` (or the `--files` option of `spark-submit`). It also handles remotely hosted main application resources, e.g., @@ -147,9 +137,7 @@ $ bin/spark-submit \ --jars https://path/to/dependency1.jar,https://path/to/dependency2.jar --files hdfs://host:port/path/to/file1,hdfs://host:port/path/to/file2 --conf spark.executor.instances=5 \ - --conf spark.kubernetes.driver.container.image= \ - --conf spark.kubernetes.executor.container.image= \ - --conf spark.kubernetes.initContainer.image= + --conf spark.kubernetes.container.image= \ https://path/to/examples.jar ``` @@ -322,21 +310,27 @@ specific to Spark on Kubernetes. - spark.kubernetes.driver.container.image + spark.kubernetes.container.image (none) - Container image to use for the driver. - This is usually of the form example.com/repo/spark-driver:v1.0.0. - This configuration is required and must be provided by the user. + Container image to use for the Spark application. + This is usually of the form example.com/repo/spark:v1.0.0. + This configuration is required and must be provided by the user, unless explicit + images are provided for each different container type. + + + + spark.kubernetes.driver.container.image + (value of spark.kubernetes.container.image) + + Custom container image to use for the driver. spark.kubernetes.executor.container.image - (none) + (value of spark.kubernetes.container.image) - Container image to use for the executors. - This is usually of the form example.com/repo/spark-executor:v1.0.0. - This configuration is required and must be provided by the user. + Custom container image to use for executors. @@ -643,9 +637,9 @@ specific to Spark on Kubernetes. spark.kubernetes.initContainer.image - (none) + (value of spark.kubernetes.container.image) - Container image for the init-container of the driver and executors for downloading dependencies. This is usually of the form example.com/repo/spark-init:v1.0.0. This configuration is optional and must be provided by the user if any non-container local dependency is used and must be downloaded remotely. + Custom container image for the init container of both driver and executors. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index e5d79d9a9d9da..471196ac0e3f6 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -29,17 +29,23 @@ private[spark] object Config extends Logging { .stringConf .createWithDefault("default") + val CONTAINER_IMAGE = + ConfigBuilder("spark.kubernetes.container.image") + .doc("Container image to use for Spark containers. Individual container types " + + "(e.g. driver or executor) can also be configured to use different images if desired, " + + "by setting the container type-specific image name.") + .stringConf + .createOptional + val DRIVER_CONTAINER_IMAGE = ConfigBuilder("spark.kubernetes.driver.container.image") .doc("Container image to use for the driver.") - .stringConf - .createOptional + .fallbackConf(CONTAINER_IMAGE) val EXECUTOR_CONTAINER_IMAGE = ConfigBuilder("spark.kubernetes.executor.container.image") .doc("Container image to use for the executors.") - .stringConf - .createOptional + .fallbackConf(CONTAINER_IMAGE) val CONTAINER_IMAGE_PULL_POLICY = ConfigBuilder("spark.kubernetes.container.image.pullPolicy") @@ -148,8 +154,7 @@ private[spark] object Config extends Logging { val INIT_CONTAINER_IMAGE = ConfigBuilder("spark.kubernetes.initContainer.image") .doc("Image for the driver and executor's init-container for downloading dependencies.") - .stringConf - .createOptional + .fallbackConf(CONTAINER_IMAGE) val INIT_CONTAINER_MOUNT_TIMEOUT = ConfigBuilder("spark.kubernetes.mountDependencies.timeout") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 111cb2a3b75e5..9411956996843 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -60,10 +60,9 @@ private[spark] object Constants { val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" - val ENV_EXECUTOR_EXTRA_CLASSPATH = "SPARK_EXECUTOR_EXTRA_CLASSPATH" val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" - val ENV_SUBMIT_EXTRA_CLASSPATH = "SPARK_SUBMIT_EXTRA_CLASSPATH" + val ENV_CLASSPATH = "SPARK_CLASSPATH" val ENV_DRIVER_MAIN_CLASS = "SPARK_DRIVER_CLASS" val ENV_DRIVER_ARGS = "SPARK_DRIVER_ARGS" val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala index dfeccf9e2bd1c..f6a57dfe00171 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala @@ -77,6 +77,7 @@ private[spark] class InitContainerBootstrap( .withMountPath(INIT_CONTAINER_PROPERTIES_FILE_DIR) .endVolumeMount() .addToVolumeMounts(sharedVolumeMounts: _*) + .addToArgs("init") .addToArgs(INIT_CONTAINER_PROPERTIES_FILE_PATH) .build() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index eca46b84c6066..164e2e5594778 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -66,7 +66,7 @@ private[spark] class BasicDriverConfigurationStep( override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { val driverExtraClasspathEnv = driverExtraClasspath.map { classPath => new EnvVarBuilder() - .withName(ENV_SUBMIT_EXTRA_CLASSPATH) + .withName(ENV_CLASSPATH) .withValue(classPath) .build() } @@ -133,6 +133,7 @@ private[spark] class BasicDriverConfigurationStep( .addToLimits("memory", driverMemoryLimitQuantity) .addToLimits(maybeCpuLimitQuantity.toMap.asJava) .endResources() + .addToArgs("driver") .build() val baseDriverPod = new PodBuilder(driverSpec.driverPod) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index bcacb3934d36a..141bd2827e7c5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -128,7 +128,7 @@ private[spark] class ExecutorPodFactory( .build() val executorExtraClasspathEnv = executorExtraClasspath.map { cp => new EnvVarBuilder() - .withName(ENV_EXECUTOR_EXTRA_CLASSPATH) + .withName(ENV_CLASSPATH) .withValue(cp) .build() } @@ -181,6 +181,7 @@ private[spark] class ExecutorPodFactory( .endResources() .addAllToEnv(executorEnv.asJava) .withPorts(requiredPorts.asJava) + .addToArgs("executor") .build() val executorPod = new PodBuilder() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index f193b1f4d3664..65274c6f50e01 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -34,8 +34,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { private val SECRET_MOUNT_PATH = "/etc/secrets/driver" test("Base submission steps with a main app resource.") { - val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( APP_ID, @@ -55,8 +54,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { } test("Base submission steps without a main app resource.") { - val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) val orchestrator = new DriverConfigOrchestrator( APP_ID, LAUNCH_TIME, @@ -75,8 +73,8 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { test("Submission steps with an init-container.") { val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) - .set(INIT_CONTAINER_IMAGE, IC_IMAGE) + .set(CONTAINER_IMAGE, DRIVER_IMAGE) + .set(INIT_CONTAINER_IMAGE.key, IC_IMAGE) .set("spark.jars", "hdfs://localhost:9000/var/apps/jars/jar1.jar") val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( @@ -98,7 +96,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { test("Submission steps with driver secrets to mount") { val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + .set(CONTAINER_IMAGE, DRIVER_IMAGE) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index 8ee629ac8ddc1..b136f2c02ffba 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -47,7 +47,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) - .set(DRIVER_CONTAINER_IMAGE, "spark-driver:latest") + .set(CONTAINER_IMAGE, "spark-driver:latest") .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") @@ -79,7 +79,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { .asScala .map(env => (env.getName, env.getValue)) .toMap - assert(envs(ENV_SUBMIT_EXTRA_CLASSPATH) === "/opt/spark/spark-examples.jar") + assert(envs(ENV_CLASSPATH) === "/opt/spark/spark-examples.jar") assert(envs(ENV_DRIVER_MEMORY) === "256M") assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala index 20f2e5bc15df3..09b42e4484d86 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala @@ -40,7 +40,7 @@ class InitContainerConfigOrchestratorSuite extends SparkFunSuite { test("including basic configuration step") { val sparkConf = new SparkConf(true) - .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) + .set(CONTAINER_IMAGE, DOCKER_IMAGE) .set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$CUSTOM_LABEL_KEY", CUSTOM_LABEL_VALUE) val orchestrator = new InitContainerConfigOrchestrator( @@ -59,7 +59,7 @@ class InitContainerConfigOrchestratorSuite extends SparkFunSuite { test("including step to mount user-specified secrets") { val sparkConf = new SparkConf(false) - .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) + .set(CONTAINER_IMAGE, DOCKER_IMAGE) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 7cfbe54c95390..a3c615be031d2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -54,7 +54,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef baseConf = new SparkConf() .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) - .set(EXECUTOR_CONTAINER_IMAGE, executorImage) + .set(CONTAINER_IMAGE, executorImage) } test("basic executor pod has reasonable defaults") { @@ -107,7 +107,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkEnv(executor, Map("SPARK_JAVA_OPT_0" -> "foo=bar", - "SPARK_EXECUTOR_EXTRA_CLASSPATH" -> "bar=baz", + ENV_CLASSPATH -> "bar=baz", "qux" -> "quux")) checkOwnerReferences(executor, driverPodUid) } diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile deleted file mode 100644 index 45fbcd9cd0deb..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -FROM spark-base - -# Before building the docker image, first build and make a Spark distribution following -# the instructions in http://spark.apache.org/docs/latest/building-spark.html. -# If this docker file is being used in the context of building your images from a Spark -# distribution, the docker build command should be invoked from the top level directory -# of the Spark distribution. E.g.: -# docker build -t spark-driver:latest -f kubernetes/dockerfiles/driver/Dockerfile . - -COPY examples /opt/spark/examples - -CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ - env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt && \ - readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt && \ - if ! [ -z ${SPARK_MOUNTED_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_SUBMIT_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_SUBMIT_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." .; fi && \ - ${JAVA_HOME}/bin/java "${SPARK_DRIVER_JAVA_OPTS[@]}" -cp "$SPARK_CLASSPATH" -Xms$SPARK_DRIVER_MEMORY -Xmx$SPARK_DRIVER_MEMORY -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS $SPARK_DRIVER_CLASS $SPARK_DRIVER_ARGS diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile deleted file mode 100644 index 0f806cf7e148e..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -FROM spark-base - -# Before building the docker image, first build and make a Spark distribution following -# the instructions in http://spark.apache.org/docs/latest/building-spark.html. -# If this docker file is being used in the context of building your images from a Spark -# distribution, the docker build command should be invoked from the top level directory -# of the Spark distribution. E.g.: -# docker build -t spark-executor:latest -f kubernetes/dockerfiles/executor/Dockerfile . - -COPY examples /opt/spark/examples - -CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ - env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt && \ - readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt && \ - if ! [ -z ${SPARK_MOUNTED_CLASSPATH}+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_EXECUTOR_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_EXECUTOR_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." .; fi && \ - ${JAVA_HOME}/bin/java "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" org.apache.spark.executor.CoarseGrainedExecutorBackend --driver-url $SPARK_DRIVER_URL --executor-id $SPARK_EXECUTOR_ID --cores $SPARK_EXECUTOR_CORES --app-id $SPARK_APPLICATION_ID --hostname $SPARK_EXECUTOR_POD_IP diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile deleted file mode 100644 index 047056ab2633b..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ /dev/null @@ -1,24 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -FROM spark-base - -# If this docker file is being used in the context of building your images from a Spark distribution, the docker build -# command should be invoked from the top level directory of the Spark distribution. E.g.: -# docker build -t spark-init:latest -f kubernetes/dockerfiles/init-container/Dockerfile . - -ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.k8s.SparkPodInitContainer" ] diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh deleted file mode 100755 index 82559889f4beb..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# echo commands to the terminal output -set -ex - -# Check whether there is a passwd entry for the container UID -myuid=$(id -u) -mygid=$(id -g) -uidentry=$(getent passwd $myuid) - -# If there is no passwd entry for the container UID, attempt to create one -if [ -z "$uidentry" ] ; then - if [ -w /etc/passwd ] ; then - echo "$myuid:x:$myuid:$mygid:anonymous uid:$SPARK_HOME:/bin/false" >> /etc/passwd - else - echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" - fi -fi - -# Execute the container CMD under tini for better hygiene -/sbin/tini -s -- "$@" diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile similarity index 87% rename from resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile rename to resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index da1d6b9e161cc..491b7cf692478 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -17,15 +17,15 @@ FROM openjdk:8-alpine -ARG spark_jars -ARG img_path +ARG spark_jars=jars +ARG img_path=kubernetes/dockerfiles # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: -# docker build -t spark-base:latest -f kubernetes/dockerfiles/spark-base/Dockerfile . +# docker build -t spark:latest -f kubernetes/dockerfiles/spark/Dockerfile . RUN set -ex && \ apk upgrade --no-cache && \ @@ -41,7 +41,9 @@ COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY conf /opt/spark/conf -COPY ${img_path}/spark-base/entrypoint.sh /opt/ +COPY ${img_path}/spark/entrypoint.sh /opt/ +COPY examples /opt/spark/examples +COPY data /opt/spark/data ENV SPARK_HOME /opt/spark diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh new file mode 100755 index 0000000000000..0c28c75857871 --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# echo commands to the terminal output +set -ex + +# Check whether there is a passwd entry for the container UID +myuid=$(id -u) +mygid=$(id -g) +uidentry=$(getent passwd $myuid) + +# If there is no passwd entry for the container UID, attempt to create one +if [ -z "$uidentry" ] ; then + if [ -w /etc/passwd ] ; then + echo "$myuid:x:$myuid:$mygid:anonymous uid:$SPARK_HOME:/bin/false" >> /etc/passwd + else + echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" + fi +fi + +SPARK_K8S_CMD="$1" +if [ -z "$SPARK_K8S_CMD" ]; then + echo "No command to execute has been provided." 1>&2 + exit 1 +fi +shift 1 + +SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" +env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt +readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt +if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then + SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" +fi +if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then + cp -R "$SPARK_MOUNTED_FILES_DIR/." . +fi + +case "$SPARK_K8S_CMD" in + driver) + CMD=( + ${JAVA_HOME}/bin/java + "${SPARK_DRIVER_JAVA_OPTS[@]}" + -cp "$SPARK_CLASSPATH" + -Xms$SPARK_DRIVER_MEMORY + -Xmx$SPARK_DRIVER_MEMORY + -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS + $SPARK_DRIVER_CLASS + $SPARK_DRIVER_ARGS + ) + ;; + + executor) + CMD=( + ${JAVA_HOME}/bin/java + "${SPARK_EXECUTOR_JAVA_OPTS[@]}" + -Xms$SPARK_EXECUTOR_MEMORY + -Xmx$SPARK_EXECUTOR_MEMORY + -cp "$SPARK_CLASSPATH" + org.apache.spark.executor.CoarseGrainedExecutorBackend + --driver-url $SPARK_DRIVER_URL + --executor-id $SPARK_EXECUTOR_ID + --cores $SPARK_EXECUTOR_CORES + --app-id $SPARK_APPLICATION_ID + --hostname $SPARK_EXECUTOR_POD_IP + ) + ;; + + init) + CMD=( + "$SPARK_HOME/bin/spark-class" + "org.apache.spark.deploy.k8s.SparkPodInitContainer" + "$@" + ) + ;; + + *) + echo "Unknown command: $SPARK_K8S_CMD" 1>&2 + exit 1 +esac + +# Execute the container CMD under tini for better hygiene +exec /sbin/tini -s -- "${CMD[@]}" From 6f7aaed805070d29dcba32e04ca7a1f581fa54b9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 11 Jan 2018 10:52:12 -0800 Subject: [PATCH 070/774] [SPARK-22908] Add kafka source and sink for continuous processing. ## What changes were proposed in this pull request? Add kafka source and sink for continuous processing. This involves two small changes to the execution engine: * Bring data reader close() into the normal data reader thread to avoid thread safety issues. * Fix up the semantics of the RECONFIGURING StreamExecution state. State updates are now atomic, and we don't have to deal with swallowing an exception. ## How was this patch tested? new unit tests Author: Jose Torres Closes #20096 from jose-torres/continuous-kafka. --- .../sql/kafka010/KafkaContinuousReader.scala | 232 +++++++++ .../sql/kafka010/KafkaContinuousWriter.scala | 119 +++++ .../sql/kafka010/KafkaOffsetReader.scala | 21 +- .../spark/sql/kafka010/KafkaSource.scala | 17 +- .../sql/kafka010/KafkaSourceOffset.scala | 7 +- .../sql/kafka010/KafkaSourceProvider.scala | 105 +++- .../spark/sql/kafka010/KafkaWriteTask.scala | 71 ++- .../spark/sql/kafka010/KafkaWriter.scala | 5 +- .../kafka010/KafkaContinuousSinkSuite.scala | 474 ++++++++++++++++++ .../kafka010/KafkaContinuousSourceSuite.scala | 96 ++++ .../sql/kafka010/KafkaContinuousTest.scala | 64 +++ .../spark/sql/kafka010/KafkaSourceSuite.scala | 470 +++++++++-------- .../apache/spark/sql/DataFrameReader.scala | 32 +- .../apache/spark/sql/DataFrameWriter.scala | 25 +- .../datasources/v2/WriteToDataSourceV2.scala | 8 +- .../execution/streaming/StreamExecution.scala | 15 +- .../ContinuousDataSourceRDDIter.scala | 3 +- .../continuous/ContinuousExecution.scala | 67 ++- .../continuous/EpochCoordinator.scala | 21 +- .../sql/streaming/DataStreamWriter.scala | 26 +- .../spark/sql/streaming/StreamTest.scala | 36 +- 21 files changed, 1531 insertions(+), 383 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala new file mode 100644 index 0000000000000..928379544758c --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +/** + * A [[ContinuousReader]] for data from kafka. + * + * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be + * read by per-task consumers generated later. + * @param kafkaParams String params for per-task Kafka consumers. + * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which + * are not Kafka consumer params. + * @param metadataPath Path to a directory this reader can use for writing metadata. + * @param initialOffsets The Kafka offsets to start reading data at. + * @param failOnDataLoss Flag indicating whether reading should fail in data loss + * scenarios, where some offsets after the specified initial ones can't be + * properly read. + */ +class KafkaContinuousReader( + offsetReader: KafkaOffsetReader, + kafkaParams: ju.Map[String, Object], + sourceOptions: Map[String, String], + metadataPath: String, + initialOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) + extends ContinuousReader with SupportsScanUnsafeRow with Logging { + + private lazy val session = SparkSession.getActiveSession.get + private lazy val sc = session.sparkContext + + // Initialized when creating read tasks. If this diverges from the partitions at the latest + // offsets, we need to reconfigure. + // Exposed outside this object only for unit tests. + private[sql] var knownPartitions: Set[TopicPartition] = _ + + override def readSchema: StructType = KafkaOffsetReader.kafkaSchema + + private var offset: Offset = _ + override def setOffset(start: ju.Optional[Offset]): Unit = { + offset = start.orElse { + val offsets = initialOffsets match { + case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + logInfo(s"Initial offsets: $offsets") + offsets + } + } + + override def getStartOffset(): Offset = offset + + override def deserializeOffset(json: String): Offset = { + KafkaSourceOffset(JsonUtils.partitionOffsets(json)) + } + + override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { + import scala.collection.JavaConverters._ + + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) + + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) + + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + knownPartitions = startOffsets.keySet + + startOffsets.toSeq.map { + case (topicPartition, start) => + KafkaContinuousReadTask( + topicPartition, start, kafkaParams, failOnDataLoss) + .asInstanceOf[ReadTask[UnsafeRow]] + }.asJava + } + + /** Stop this source and free any resources it has allocated. */ + def stop(): Unit = synchronized { + offsetReader.close() + } + + override def commit(end: Offset): Unit = {} + + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { + val mergedMap = offsets.map { + case KafkaSourcePartitionOffset(p, o) => Map(p -> o) + }.reduce(_ ++ _) + KafkaSourceOffset(mergedMap) + } + + override def needsReconfiguration(): Boolean = { + knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions + } + + override def toString(): String = s"KafkaSource[$offsetReader]" + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") + } else { + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") + } + } +} + +/** + * A read task for continuous Kafka processing. This will be serialized and transformed into a + * full reader on executors. + * + * @param topicPartition The (topic, partition) pair this task is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +case class KafkaContinuousReadTask( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { + override def createDataReader(): KafkaContinuousDataReader = { + new KafkaContinuousDataReader(topicPartition, startOffset, kafkaParams, failOnDataLoss) + } +} + +/** + * A per-task data reader for continuous Kafka processing. + * + * @param topicPartition The (topic, partition) pair this data reader is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +class KafkaContinuousDataReader( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { + private val topic = topicPartition.topic + private val kafkaPartition = topicPartition.partition + private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) + + private val sharedRow = new UnsafeRow(7) + private val bufferHolder = new BufferHolder(sharedRow) + private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + + private var nextKafkaOffset = startOffset + private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ + + override def next(): Boolean = { + var r: ConsumerRecord[Array[Byte], Array[Byte]] = null + while (r == null) { + r = consumer.get( + nextKafkaOffset, + untilOffset = Long.MaxValue, + pollTimeoutMs = Long.MaxValue, + failOnDataLoss) + } + nextKafkaOffset = r.offset + 1 + currentRecord = r + true + } + + override def get(): UnsafeRow = { + bufferHolder.reset() + + if (currentRecord.key == null) { + rowWriter.setNullAt(0) + } else { + rowWriter.write(0, currentRecord.key) + } + rowWriter.write(1, currentRecord.value) + rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) + rowWriter.write(3, currentRecord.partition) + rowWriter.write(4, currentRecord.offset) + rowWriter.write(5, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) + rowWriter.write(6, currentRecord.timestampType.id) + sharedRow.setTotalSize(bufferHolder.totalSize) + sharedRow + } + + override def getOffset(): KafkaSourcePartitionOffset = { + KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset) + } + + override def close(): Unit = { + consumer.close() + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala new file mode 100644 index 0000000000000..9843f469c5b25 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} +import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{BinaryType, StringType, StructType} + +/** + * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we + * don't need to really send one. + */ +case object KafkaWriterCommitMessage extends WriterCommitMessage + +/** + * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. + * @param topic The topic this writer is responsible for. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +class KafkaContinuousWriter( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends ContinuousWriter with SupportsWriteInternalRow { + + validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) + + override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = + KafkaContinuousWriterFactory(topic, producerParams, schema) + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(messages: Array[WriterCommitMessage]): Unit = {} +} + +/** + * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate + * the per-task data writers. + * @param topic The topic that should be written to. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +case class KafkaContinuousWriterFactory( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends DataWriterFactory[InternalRow] { + + override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) + } +} + +/** + * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to + * process incoming rows. + * + * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred + * from a `topic` field in the incoming data. + * @param producerParams Parameters to use for the Kafka producer. + * @param inputSchema The attributes in the input data. + */ +class KafkaContinuousDataWriter( + targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) + extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { + import scala.collection.JavaConverters._ + + private lazy val producer = CachedKafkaProducer.getOrCreate( + new java.util.HashMap[String, Object](producerParams.asJava)) + + def write(row: InternalRow): Unit = { + checkForErrors() + sendRow(row, producer) + } + + def commit(): WriterCommitMessage = { + // Send is asynchronous, but we can't commit until all rows are actually in Kafka. + // This requires flushing and then checking that no callbacks produced errors. + // We also check for errors before to fail as soon as possible - the check is cheap. + checkForErrors() + producer.flush() + checkForErrors() + KafkaWriterCommitMessage + } + + def abort(): Unit = {} + + def close(): Unit = { + checkForErrors() + if (producer != null) { + producer.flush() + checkForErrors() + CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava)) + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 3e65949a6fd1b..551641cfdbca8 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -117,10 +117,14 @@ private[kafka010] class KafkaOffsetReader( * Resolves the specific offsets based on Kafka seek positions. * This method resolves offset value -1 to the latest and -2 to the * earliest Kafka seek position. + * + * @param partitionOffsets the specific offsets to resolve + * @param reportDataLoss callback to either report or log data loss depending on setting */ def fetchSpecificOffsets( - partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = - runUninterruptibly { + partitionOffsets: Map[TopicPartition, Long], + reportDataLoss: String => Unit): KafkaSourceOffset = { + val fetched = runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) @@ -145,6 +149,19 @@ private[kafka010] class KafkaOffsetReader( } } + partitionOffsets.foreach { + case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && + off != KafkaOffsetRangeLimit.EARLIEST => + if (fetched(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + KafkaSourceOffset(fetched) + } + /** * Fetch the earliest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index e9cff04ba5f2e..27da76068a66f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -130,7 +130,7 @@ private[kafka010] class KafkaSource( val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => fetchAndVerify(p) + case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") @@ -138,21 +138,6 @@ private[kafka010] class KafkaSource( }.partitionToOffsets } - private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = { - val result = kafkaReader.fetchSpecificOffsets(specificOffsets) - specificOffsets.foreach { - case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && - off != KafkaOffsetRangeLimit.EARLIEST => - if (result(tp) != off) { - reportDataLoss( - s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") - } - case _ => - // no real way to check that beginning or end is reasonable - } - KafkaSourceOffset(result) - } - private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index b5da415b3097e..c82154cfbad7f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,17 +20,22 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} +import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. */ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } +private[kafka010] +case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long) + extends PartitionOffset + /** Companion object of the [[KafkaSourceOffset]] */ private[kafka010] object KafkaSourceOffset { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3cb4d8cad12cc..3914370a96595 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, UUID} +import java.util.{Locale, Optional, UUID} import scala.collection.JavaConverters._ @@ -27,9 +27,12 @@ import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.streaming.{Sink, Source} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -43,6 +46,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider + with ContinuousWriteSupport + with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -101,6 +106,43 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } + override def createContinuousReader( + schema: Optional[StructType], + metadataPath: String, + options: DataSourceV2Options): KafkaContinuousReader = { + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaContinuousReader( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + parameters, + metadataPath, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + /** * Returns a new base relation with the given parameters. * @@ -181,26 +223,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " - + "are serialized with ByteArraySerializer.") - } + override def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + import scala.collection.JavaConverters._ - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) - { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are serialized with ByteArraySerializer.") - } - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + val spark = SparkSession.getActiveSession.get + val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) + // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. + val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) + + KafkaWriter.validateQuery( + schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) + + Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) } private def strategy(caseInsensitiveParams: Map[String, String]) = @@ -450,4 +488,27 @@ private[kafka010] object KafkaSourceProvider extends Logging { def build(): ju.Map[String, Object] = map } + + private[kafka010] def kafkaParamsForProducer( + parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 6fd333e2f43ba..baa60febf661d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -33,10 +33,8 @@ import org.apache.spark.sql.types.{BinaryType, StringType} private[kafka010] class KafkaWriteTask( producerConfiguration: ju.Map[String, Object], inputSchema: Seq[Attribute], - topic: Option[String]) { + topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) { // used to synchronize with Kafka callbacks - @volatile private var failedWrite: Exception = null - private val projection = createProjection private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ /** @@ -46,23 +44,7 @@ private[kafka010] class KafkaWriteTask( producer = CachedKafkaProducer.getOrCreate(producerConfiguration) while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() - val projectedRow = projection(currentRow) - val topic = projectedRow.getUTF8String(0) - val key = projectedRow.getBinary(1) - val value = projectedRow.getBinary(2) - if (topic == null) { - throw new NullPointerException(s"null topic present in the data. Use the " + - s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") - } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) - val callback = new Callback() { - override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { - if (failedWrite == null && e != null) { - failedWrite = e - } - } - } - producer.send(record, callback) + sendRow(currentRow, producer) } } @@ -74,8 +56,49 @@ private[kafka010] class KafkaWriteTask( producer = null } } +} + +private[kafka010] abstract class KafkaRowWriter( + inputSchema: Seq[Attribute], topic: Option[String]) { + + // used to synchronize with Kafka callbacks + @volatile protected var failedWrite: Exception = _ + protected val projection = createProjection + + private val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } - private def createProjection: UnsafeProjection = { + /** + * Send the specified row to the producer, with a callback that will save any exception + * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before + * assuming the row is in Kafka. + */ + protected def sendRow( + row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = { + val projectedRow = projection(row) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + producer.send(record, callback) + } + + protected def checkForErrors(): Unit = { + if (failedWrite != null) { + throw failedWrite + } + } + + private def createProjection = { val topicExpression = topic.map(Literal(_)).orElse { inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) }.getOrElse { @@ -112,11 +135,5 @@ private[kafka010] class KafkaWriteTask( Seq(topicExpression, Cast(keyExpression, BinaryType), Cast(valueExpression, BinaryType)), inputSchema) } - - private def checkForErrors(): Unit = { - if (failedWrite != null) { - throw failedWrite - } - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 5e9ae35b3f008..15cd44812cb0c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -43,10 +43,9 @@ private[kafka010] object KafkaWriter extends Logging { override def toString: String = "KafkaWriter" def validateQuery( - queryExecution: QueryExecution, + schema: Seq[Attribute], kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic.isEmpty) { throw new AnalysisException(s"topic option required when no " + @@ -84,7 +83,7 @@ private[kafka010] object KafkaWriter extends Logging { kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output - validateQuery(queryExecution, kafkaParameters, topic) + validateQuery(schema, kafkaParameters, topic) queryExecution.toRdd.foreachPartition { iter => val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) Utils.tryWithSafeFinally(block = writeTask.execute(iter))( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala new file mode 100644 index 0000000000000..dfc97b1c38bb5 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -0,0 +1,474 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.scalatest.time.SpanSugar._ +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.types.{BinaryType, DataType} +import org.apache.spark.util.Utils + +/** + * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream. + * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have + * to duplicate all the code. + */ +class KafkaContinuousSinkSuite extends KafkaContinuousTest { + import testImplicits._ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils( + withBrokerProps = Map("auto.create.topics.enable" -> "false")) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + test("streaming - write to kafka with topic field") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = None, + withOutputMode = Some(OutputMode.Append))( + withSelectExpr = s"'$topic' as topic", "value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("streaming - write w/o topic field, with topic option") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))() + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("streaming - topic field and topic option") { + /* The purpose of this test is to ensure that the topic option + * overrides the topic field. We begin by writing some data that + * includes a topic field and value (e.g., 'foo') along with a topic + * option. Then when we read from the topic specified in the option + * we should see the data i.e., the data was written to the topic + * option, and not to the topic in the data e.g., foo + */ + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))( + withSelectExpr = "'foo' as topic", "CAST(value as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("null topic attribute") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "CAST(null as STRING) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getCause.getCause.getMessage + .toLowerCase(Locale.ROOT) + .contains("null topic present in the data.")) + } + + test("streaming - write data with bad schema") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "value as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage + .toLowerCase(Locale.ROOT) + .contains("topic option required when no 'topic' attribute is present")) + + try { + /* No value field */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value as key" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "required attribute 'value' not found")) + } + + test("streaming - write data with valid schema but wrong types") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value as STRING) value") + val topic = newTopic() + testUtils.createTopic(topic) + + var writer: StreamingQuery = null + var ex: Exception = null + try { + /* topic field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"CAST('1' as INT) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) + + try { + /* value field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "value attribute type must be a string or binarytype")) + + try { + ex = intercept[StreamingQueryException] { + /* key field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "key attribute type must be a string or binarytype")) + } + + test("streaming - write to non-existing topic") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + } + throw writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + } + + test("streaming - exception on config serializer") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + testUtils.sendMessages(inputTopic, Array("0")) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .load() + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.key.serializer" -> "foo"))() + writer.processAllAvailable() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'key.serializer' is not supported")) + } finally { + writer.stop() + } + + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.value.serializer" -> "foo"))() + writer.processAllAvailable() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'value.serializer' is not supported")) + } finally { + writer.stop() + } + } + + test("generic - write big data with small producer buffer") { + /* This test ensures that we understand the semantics of Kafka when + * is comes to blocking on a call to send when the send buffer is full. + * This test will configure the smallest possible producer buffer and + * indicate that we should block when it is full. Thus, no exception should + * be thrown in the case of a full buffer. + */ + val topic = newTopic() + testUtils.createTopic(topic, 1) + val options = new java.util.HashMap[String, String] + options.put("bootstrap.servers", testUtils.brokerAddress) + options.put("buffer.memory", "16384") // min buffer size + options.put("block.on.buffer.full", "true") + options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + val inputSchema = Seq(AttributeReference("value", BinaryType)()) + val data = new Array[Byte](15000) // large value + val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) + try { + val fieldTypes: Array[DataType] = Array(BinaryType) + val converter = UnsafeProjection.create(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) + row.update(0, data) + val iter = Seq.fill(1000)(converter.apply(row)).iterator + iter.foreach(writeTask.write(_)) + writeTask.commit() + } finally { + writeTask.close() + } + } + + private def createKafkaReader(topic: String): DataFrame = { + spark.read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("startingOffsets", "earliest") + .option("endingOffsets", "latest") + .option("subscribe", topic) + .load() + } + + private def createKafkaWriter( + input: DataFrame, + withTopic: Option[String] = None, + withOutputMode: Option[OutputMode] = None, + withOptions: Map[String, String] = Map[String, String]()) + (withSelectExpr: String*): StreamingQuery = { + var stream: DataStreamWriter[Row] = null + val checkpointDir = Utils.createTempDir() + var df = input.toDF() + if (withSelectExpr.length > 0) { + df = df.selectExpr(withSelectExpr: _*) + } + stream = df.writeStream + .format("kafka") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + // We need to reduce blocking time to efficiently test non-existent partition behavior. + .option("kafka.max.block.ms", "1000") + .trigger(Trigger.Continuous(1000)) + .queryName("kafkaStream") + withTopic.foreach(stream.option("topic", _)) + withOutputMode.foreach(stream.outputMode(_)) + withOptions.foreach(opt => stream.option(opt._1, opt._2)) + stream.start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala new file mode 100644 index 0000000000000..b3dade414f625 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.time.SpanSugar._ +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} + +// Run tests in KafkaSourceSuiteBase in continuous execution mode. +class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest + +class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { + import testImplicits._ + + override val brokerProps = Map("auto.create.topics.enable" -> "false") + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Execute { query => + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists { r => + // Ensure the new topic is present and the old topic is gone. + r.knownPartitions.exists(_.topic == topic2) + }, + s"query never reconfigured to new topic $topic2") + } + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } +} + +class KafkaContinuousSourceStressForDontFailOnDataLossSuite + extends KafkaSourceStressForDontFailOnDataLossSuite { + override protected def startStream(ds: Dataset[Int]) = { + ds.writeStream + .format("memory") + .queryName("memory") + .start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala new file mode 100644 index 0000000000000..e713e6695d2bd --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.spark.SparkContext +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.test.TestSparkSession + +// Trait to configure StreamTest for kafka continuous execution tests. +trait KafkaContinuousTest extends KafkaSourceTest { + override val defaultTrigger = Trigger.Continuous(1000) + override val defaultUseV2Sink = true + + // We need more than the default local[2] to be able to schedule all partitions simultaneously. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true"))) + + // In addition to setting the partitions in Kafka, we have to wait until the query has + // reconfigured to the new count so the test framework can hook in properly. + override protected def setTopicPartitions( + topic: String, newCount: Int, query: StreamExecution) = { + testUtils.addPartitions(topic, newCount) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists(_.knownPartitions.size == newCount), + s"query never reconfigured to $newCount partitions") + } + } + + test("ensure continuous stream is being used") { + val query = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "1") + .load() + + testStream(query)( + Execute(q => assert(q.isInstanceOf[ContinuousExecution])) + ) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 2034b9be07f24..d66908f86ccc7 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -34,11 +34,14 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils @@ -49,9 +52,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override val streamingTimeout = 30.seconds + protected val brokerProps = Map[String, Object]() + override def beforeAll(): Unit = { super.beforeAll() - testUtils = new KafkaTestUtils + testUtils = new KafkaTestUtils(brokerProps) testUtils.setup() } @@ -59,18 +64,25 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { if (testUtils != null) { testUtils.teardown() testUtils = null - super.afterAll() } + super.afterAll() } protected def makeSureGetOffsetCalled = AssertOnQuery { q => // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure - // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, + // its "getOffset" is called before pushing any data. Otherwise, because of the race condition, // we don't know which data should be fetched when `startingOffsets` is latest. - q.processAllAvailable() + q match { + case c: ContinuousExecution => c.awaitEpoch(0) + case m: MicroBatchExecution => m.processAllAvailable() + } true } + protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = { + testUtils.addPartitions(topic, newCount) + } + /** * Add data to Kafka. * @@ -82,7 +94,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { if (query.get.isActive) { // Make sure no Spark job is running when deleting a topic query.get.processAllAvailable() @@ -97,16 +109,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) } - // Read all topics again in case some topics are delete. - val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys require( query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => - source.asInstanceOf[KafkaSource] - } + case StreamingExecutionRelation(source: KafkaSource, _) => source + } ++ (query.get.lastExecution match { + case null => Seq() + case e => e.logical.collect { + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + } + }) if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -137,14 +151,158 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override def toString: String = s"AddKafkaData(topics = $topics, data = $data, message = $message)" } -} + private val topicId = new AtomicInteger(0) + protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" +} -class KafkaSourceSuite extends KafkaSourceTest { +class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { import testImplicits._ - private val topicId = new AtomicInteger(0) + test("(de)serialization of initial offsets") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + testStream(reader.load)( + makeSureGetOffsetCalled, + StopStream, + StartStream(), + StopStream) + } + + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } + + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = ProcessingTime(1)), + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnQuery { query => + val recordsRead = query.recentProgress.map(_.numInputRows).sum + recordsRead == 3 + } + ) + } + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } testWithUninterruptibleThread( "deserialization of initial offset with Spark 2.1.0") { @@ -237,86 +395,51 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("(de)serialization of initial offsets") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 64) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - testStream(reader.load)( - makeSureGetOffsetCalled, - StopStream, - StartStream(), - StopStream) - } - - test("maxOffsetsPerTrigger") { + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() val topic = newTopic() - testUtils.createTopic(topic, partitions = 3) - testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) - testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) - testUtils.sendMessages(topic, Array("1"), Some(2)) + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) - val reader = spark + val kafka = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("maxOffsetsPerTrigger", 10) + .option("startingOffsets", s"earliest") .option("subscribe", topic) - .option("startingOffsets", "earliest") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) - - val clock = new StreamManualClock + .load() - val waitUntilBatchProcessed = AssertOnQuery { q => - eventually(Timeout(streamingTimeout)) { - if (!q.exception.isDefined) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } - } - if (q.exception.isDefined) { - throw q.exception.get - } - true - } + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") - testStream(mapped)( - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // 1 from smallest, 1 from middle, 8 from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 - ), - StopStream, - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 - ), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, - 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 - ) - ) + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + .queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. So here we just use a low bound to make sure the internal + // conversion works. + assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() } +} + +class KafkaSourceSuiteBase extends KafkaSourceTest { + + import testImplicits._ test("cannot stop Kafka stream") { val topic = newTopic() @@ -328,7 +451,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"topic-.*") + .option("subscribePattern", s"$topic.*") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -422,65 +545,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") - - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Assert { - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - true - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } - - test("starting offset is latest by default") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("0")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - val kafka = reader.load() - .selectExpr("CAST(value AS STRING)") - .as[String] - val mapped = kafka.map(_.toInt) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(1, 2, 3) // should not have 0 - ) - } - test("bad source options") { def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { @@ -540,34 +604,6 @@ class KafkaSourceSuite extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } - test("input row metrics") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val kafka = spark - .readStream - .format("kafka") - .option("subscribe", topic) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - - val mapped = kafka.map(kv => kv._2.toInt + 1) - testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - AssertOnQuery { query => - val recordsRead = query.recentProgress.map(_.numInputRows).sum - recordsRead == 3 - } - ) - } - test("delete a topic when a Spark job is running") { KafkaSourceSuite.collectedData.clear() @@ -629,8 +665,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" - private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } @@ -676,6 +710,10 @@ class KafkaSourceSuite extends KafkaSourceTest { testStream(mapped)( makeSureGetOffsetCalled, + Execute { q => + // wait to reach the last offset in every partition + q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) + }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, StartStream(), @@ -706,6 +744,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("memory") .outputMode("append") .queryName("kafkaColumnTypes") + .trigger(defaultTrigger) .start() query.processAllAvailable() val rows = spark.table("kafkaColumnTypes").collect() @@ -723,47 +762,6 @@ class KafkaSourceSuite extends KafkaSourceTest { query.stop() } - test("KafkaSource with watermark") { - val now = System.currentTimeMillis() - val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 1) - testUtils.sendMessages(topic, Array(1).map(_.toString)) - - val kafka = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("startingOffsets", s"earliest") - .option("subscribe", topic) - .load() - - val windowedAggregation = kafka - .withWatermark("timestamp", "10 seconds") - .groupBy(window($"timestamp", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start") as 'window, $"count") - - val query = windowedAggregation - .writeStream - .format("memory") - .outputMode("complete") - .queryName("kafkaWatermark") - .start() - query.processAllAvailable() - val rows = spark.table("kafkaWatermark").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") - val row = rows(0) - // We cannot check the exact window start time as it depands on the time that messages were - // inserted by the producer. So here we just use a low bound to make sure the internal - // conversion works. - assert( - row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, - s"Unexpected results: $row") - assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") - query.stop() - } - private def testFromLatestOffsets( topic: String, addPartitions: Boolean, @@ -800,9 +798,7 @@ class KafkaSourceSuite extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -843,9 +839,7 @@ class KafkaSourceSuite extends KafkaSourceTest { StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -977,20 +971,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared } } - test("stress test for failOnDataLoss=false") { - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", "failOnDataLoss.*") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - .option("fetchOffset.retryIntervalMs", "3000") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + protected def startStream(ds: Dataset[Int]) = { + ds.writeStream.foreach(new ForeachWriter[Int] { override def open(partitionId: Long, version: Long): Boolean = { true @@ -1004,6 +986,22 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared override def close(errorOrNull: Throwable): Unit = { } }).start() + } + + test("stress test for failOnDataLoss=false") { + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", "failOnDataLoss.*") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + .option("fetchOffset.retryIntervalMs", "3000") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val query = startStream(kafka.map(kv => kv._2.toInt)) val testTime = 1.minutes val startTime = System.currentTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e8d683a578f35..b714a46b5f786 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -191,6 +191,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading + // the dataframe as a v1 source. val reader = (ds, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) @@ -208,23 +211,30 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } reader - case _ => - throw new AnalysisException(s"$cls does not support data reading.") + case _ => null // fall back to v1 } - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + if (reader == null) { + loadV1Source(paths: _*) + } else { + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + } } else { - // Code path for data source v1. - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) + loadV1Source(paths: _*) } } + private def loadV1Source(paths: String*) = { + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) + } + /** * Construct a `DataFrame` representing the database table accessible via JDBC URL * url named table and connection properties. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3304f368e1050..97f12ff625c42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,17 +255,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - case _ => throw new AnalysisException(s"$cls does not support data writing.") + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving + // as though it's a V1 source. + case _ => saveToV1Source() } } else { - // Code path for data source v1. - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) - } + saveToV1Source() + } + } + + private def saveToV1Source(): Unit = { + // Code path for data source v1. + runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index f0bdf84bb7a84..a4a857f2d4d9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -81,9 +81,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") + if (!writer.isInstanceOf[ContinuousWriter]) { + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") + } } catch { case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 24a8b000df0c1..cf27e1a70650a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,7 +142,8 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override val runId: UUID = UUID.randomUUID + override def runId: UUID = currentRunId + protected var currentRunId = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is @@ -418,11 +419,17 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { + private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + if (sources == null) { + // sources might not be initialized yet + false + } else { + val source = sources(sourceIndex) + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + } } while (notDone) { @@ -436,7 +443,7 @@ abstract class StreamExecution( awaitProgressLock.unlock() } } - logDebug(s"Unblocked at $newOffset for $source") + logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}") } /** A flag to indicate that a batch has completed with no new data available. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index d79e4bd65f563..e700aa4f9aea7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -77,7 +77,6 @@ class ContinuousDataSourceRDD( dataReaderThread.start() context.addTaskCompletionListener(_ => { - reader.close() dataReaderThread.interrupt() epochPollExecutor.shutdown() }) @@ -201,6 +200,8 @@ class DataReaderThread( failedFlag.set(true) // Don't rethrow the exception in this thread. It's not needed, and the default Spark // exception handler will kill the executor. + } finally { + reader.close() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 9657b5e26d770..667410ef9f1c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution.streaming.continuous +import java.util.UUID import java.util.concurrent.TimeUnit +import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} @@ -52,7 +54,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty + @volatile protected var continuousSources: Seq[ContinuousReader] = _ override protected def sources: Seq[BaseStreamingSource] = continuousSources override lazy val logicalPlan: LogicalPlan = { @@ -78,15 +80,17 @@ class ContinuousExecution( } override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - do { - try { - runContinuous(sparkSessionForStream) - } catch { - case _: InterruptedException if state.get().equals(RECONFIGURING) => - // swallow exception and run again - state.set(ACTIVE) + val stateUpdate = new UnaryOperator[State] { + override def apply(s: State) = s match { + // If we ended the query to reconfigure, reset the state to active. + case RECONFIGURING => ACTIVE + case _ => s } - } while (state.get() == ACTIVE) + } + + do { + runContinuous(sparkSessionForStream) + } while (state.updateAndGet(stateUpdate) == ACTIVE) } /** @@ -120,12 +124,16 @@ class ContinuousExecution( } committedOffsets = nextOffsets.toStreamProgress(sources) - // Forcibly align commit and offset logs by slicing off any spurious offset logs from - // a previous run. We can't allow commits to an epoch that a previous run reached but - // this run has not. - offsetLog.purgeAfter(latestEpochId) + // Get to an epoch ID that has definitely never been sent to a sink before. Since sink + // commit happens between offset log write and commit log write, this means an epoch ID + // which is not in the offset log. + val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { + throw new IllegalStateException( + s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + + s"an element.") + } + currentBatchId = latestOffsetEpoch + 1 - currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets case None => @@ -141,6 +149,7 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { + currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -225,13 +234,11 @@ class ContinuousExecution( triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration()) { - state.set(RECONFIGURING) + if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { stopSources() if (queryExecutionThread.isAlive) { sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() - // No need to join - this thread is about to end anyway. } false } else if (isActive) { @@ -259,6 +266,7 @@ class ContinuousExecution( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } } finally { + epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() @@ -273,17 +281,22 @@ class ContinuousExecution( epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - if (partitionOffsets.contains(null)) { - // If any offset is null, that means the corresponding partition hasn't seen any data yet, so - // there's nothing meaningful to add to the offset log. - } val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) - synchronized { - if (queryExecutionThread.isAlive) { - offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) - } else { - return - } + val oldOffset = synchronized { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + offsetLog.get(epoch - 1) + } + + // If offset hasn't changed since last epoch, there's been no new data. + if (oldOffset.contains(OffsetSeq.fill(globalOffset))) { + noNewData = true + } + + awaitProgressLock.lock() + try { + awaitProgressLockCondition.signalAll() + } finally { + awaitProgressLock.unlock() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 98017c3ac6a33..40dcbecade814 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -39,6 +39,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable */ private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage +/** + * The RpcEndpoint stop() will wait to clear out the message queue before terminating the + * object. This can lead to a race condition where the query restarts at epoch n, a new + * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1. + * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous + * message to stop any writes to the ContinuousExecution object. + */ +private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage + // Init messages /** * Set the reader and writer partition counts. Tasks may not be started until the coordinator @@ -116,6 +125,8 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + private var queryWritesStopped: Boolean = false + private var numReaderPartitions: Int = _ private var numWriterPartitions: Int = _ @@ -147,12 +158,16 @@ private[continuous] class EpochCoordinator( partitionCommits.remove(k) } for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) + partitionOffsets.remove(k) } } } override def receive: PartialFunction[Any, Unit] = { + // If we just drop these messages, we won't do any writes to the query. The lame duck tasks + // won't shed errors or anything. + case _ if queryWritesStopped => () + case CommitPartitionEpoch(partitionId, epoch, message) => logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message") if (!partitionCommits.isDefinedAt((epoch, partitionId))) { @@ -188,5 +203,9 @@ private[continuous] class EpochCoordinator( case SetWriterPartitions(numPartitions) => numWriterPartitions = numPartitions context.reply(()) + + case StopContinuousExecutionWrites => + queryWritesStopped = true + context.reply(()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index db588ae282f38..b5b4a05ab4973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -279,18 +280,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val dataSource = - DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) + val sink = trigger match { + case _: ContinuousTrigger => + val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + ds.newInstance() match { + case w: ContinuousWriteSupport => w + case _ => throw new AnalysisException( + s"Data source $source does not support continuous writing") + } + case _ => + val ds = DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + ds.createSink(outputMode) + } + df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, extraOptions.toMap, - dataSource.createSink(outputMode), + sink, outputMode, useTempCheckpointLocation = source == "console", recoverFromCheckpointLocation = true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d46461fa9bf6d..0762895fdc620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,8 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -80,6 +81,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be StateStore.stop() // stop the state store maintenance thread and unload store providers } + protected val defaultTrigger = Trigger.ProcessingTime(0) + protected val defaultUseV2Sink = false + /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -189,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( - trigger: Trigger = Trigger.ProcessingTime(0), + trigger: Trigger = defaultTrigger, triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty, checkpointLocation: String = null) @@ -276,7 +280,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { + useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -403,18 +407,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { verify(currentStream != null, "stream not running") - // Get the map of source index to the current source objects - val indexToSource = currentStream - .logicalPlan - .collect { case StreamingExecutionRelation(s, _) => s } - .zipWithIndex - .map(_.swap) - .toMap // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(indexToSource(sourceIndex), offset) + currentStream.awaitOffset(sourceIndex, offset) } } @@ -473,6 +470,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be // after starting the query. try { currentStream.awaitInitialization(streamingTimeout.toMillis) + currentStream match { + case s: ContinuousExecution => eventually("IncrementalExecution was not created") { + s.lastExecution.executedPlan // will fail if lastExecution is null + } + case _ => + } } catch { case _: StreamingQueryException => // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. @@ -600,7 +603,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan - .collect { case StreamingExecutionRelation(s, _) => s } + .collect { + case StreamingExecutionRelation(s, _) => s + case DataSourceV2Relation(_, r) => r + } .zipWithIndex .find(_._1 == source) .map(_._2) @@ -613,9 +619,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be findSourceIndex(query.logicalPlan) }.orElse { findSourceIndex(stream.logicalPlan) + }.orElse { + queryToUse.flatMap { q => + findSourceIndex(q.lastExecution.logical) + } }.getOrElse { throw new IllegalArgumentException( - "Could find index of the source to which data was added") + "Could not find index of the source to which data was added") } // Store the expected offset of added data to wait for it later From 186bf8fb2e9ff8a80f3f6bcb5f2a0327fa79a1c9 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 11 Jan 2018 13:57:15 -0800 Subject: [PATCH 071/774] [SPARK-23046][ML][SPARKR] Have RFormula include VectorSizeHint in pipeline ## What changes were proposed in this pull request? Including VectorSizeHint in RFormula piplelines will allow them to be applied to streaming dataframes. ## How was this patch tested? Unit tests. Author: Bago Amirbekian Closes #20238 from MrBago/rFormulaVectorSize. --- R/pkg/R/mllib_utils.R | 1 + .../apache/spark/ml/feature/RFormula.scala | 18 +++++++-- .../spark/ml/feature/RFormulaSuite.scala | 37 ++++++++++++++++--- 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index a53c92c2c4815..23dda42c325be 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -130,3 +130,4 @@ read.ml <- function(path) { stop("Unsupported model: ", jobj) } } + diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 7da3339f8b487..f384ffbf578bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol} import org.apache.spark.ml.util._ @@ -210,8 +210,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) // First we index each string column referenced by the input terms. val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term => - dataset.schema(term) match { - case column if column.dataType == StringType => + dataset.schema(term).dataType match { + case _: StringType => val indexCol = tmpColumn("stridx") encoderStages += new StringIndexer() .setInputCol(term) @@ -220,6 +220,18 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) .setHandleInvalid($(handleInvalid)) prefixesToRewrite(indexCol + "_") = term + "_" (term, indexCol) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(dataset.schema(term)) + val size = if (group.size < 0) { + dataset.select(term).first().getAs[Vector](0).size + } else { + group.size + } + encoderStages += new VectorSizeHint(uid) + .setHandleInvalid("optimistic") + .setInputCol(term) + .setSize(size) + (term, term) case _ => (term, term) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 5d09c90ec6dfa..f3f4b5a3d0233 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.sql.{DataFrame, Encoder, Row} import org.apache.spark.sql.types.DoubleType -class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RFormulaSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -548,4 +548,31 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(result3.collect() === expected3.collect()) assert(result4.collect() === expected4.collect()) } + + test("Use Vectors as inputs to formula.") { + val original = Seq( + (1, 4, Vectors.dense(0.0, 0.0, 4.0)), + (2, 4, Vectors.dense(1.0, 0.0, 4.0)), + (3, 5, Vectors.dense(1.0, 0.0, 5.0)), + (4, 5, Vectors.dense(0.0, 1.0, 5.0)) + ).toDF("id", "a", "b") + val formula = new RFormula().setFormula("id ~ a + b") + val (first +: rest) = Seq("id", "a", "b", "features", "label") + testTransformer[(Int, Int, Vector)](original, formula.fit(original), first, rest: _*) { + case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) => + assert(label === id) + assert(features.toArray === a +: b.toArray) + } + + val group = new AttributeGroup("b", 3) + val vectorColWithMetadata = original("b").as("b", group.toMetadata()) + val dfWithMetadata = original.withColumn("b", vectorColWithMetadata) + val model = formula.fit(dfWithMetadata) + // model should work even when applied to dataframe without metadata. + testTransformer[(Int, Int, Vector)](original, model, first, rest: _*) { + case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) => + assert(label === id) + assert(features.toArray === a +: b.toArray) + } + } } From b5042d75c2faa5f15bc1e160d75f06dfdd6eea37 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 11 Jan 2018 16:20:30 -0800 Subject: [PATCH 072/774] [SPARK-23008][ML] OnehotEncoderEstimator python API ## What changes were proposed in this pull request? OnehotEncoderEstimator python API. ## How was this patch tested? doctest Author: WeichenXu Closes #20209 from WeichenXu123/ohe_py. --- python/pyspark/ml/feature.py | 113 ++++++++++++++++++ .../ml/param/_shared_params_code_gen.py | 1 + python/pyspark/ml/param/shared.py | 23 ++++ 3 files changed, 137 insertions(+) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 13bf95cce40be..b963e45dd7cff 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -45,6 +45,7 @@ 'NGram', 'Normalizer', 'OneHotEncoder', + 'OneHotEncoderEstimator', 'OneHotEncoderModel', 'PCA', 'PCAModel', 'PolynomialExpansion', 'QuantileDiscretizer', @@ -1641,6 +1642,118 @@ def getDropLast(self): return self.getOrDefault(self.dropLast) +@inherit_doc +class OneHotEncoderEstimator(JavaEstimator, HasInputCols, HasOutputCols, HasHandleInvalid, + JavaMLReadable, JavaMLWritable): + """ + A one-hot encoder that maps a column of category indices to a column of binary vectors, with + at most a single one-value per row that indicates the input category index. + For example with 5 categories, an input value of 2.0 would map to an output vector of + `[0.0, 0.0, 1.0, 0.0]`. + The last category is not included by default (configurable via `dropLast`), + because it makes the vector entries sum up to one, and hence linearly dependent. + So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + + Note: This is different from scikit-learn's OneHotEncoder, which keeps all categories. + The output vectors are sparse. + + When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is + added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros + vector. + + Note: When encoding multi-column by using `inputCols` and `outputCols` params, input/output + cols come in pairs, specified by the order in the arrays, and each pair is treated + independently. + + See `StringIndexer` for converting categorical values into category indices + + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"]) + >>> ohe = OneHotEncoderEstimator(inputCols=["input"], outputCols=["output"]) + >>> model = ohe.fit(df) + >>> model.transform(df).head().output + SparseVector(2, {0: 1.0}) + >>> ohePath = temp_path + "/oheEstimator" + >>> ohe.save(ohePath) + >>> loadedOHE = OneHotEncoderEstimator.load(ohePath) + >>> loadedOHE.getInputCols() == ohe.getInputCols() + True + >>> modelPath = temp_path + "/ohe-model" + >>> model.save(modelPath) + >>> loadedModel = OneHotEncoderModel.load(modelPath) + >>> loadedModel.categorySizes == model.categorySizes + True + + .. versionadded:: 2.3.0 + """ + + handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data during " + + "transform(). Options are 'keep' (invalid data presented as an extra " + + "categorical feature) or error (throw an error). Note that this Param " + + "is only used during transform; during fitting, invalid data will " + + "result in an error.", + typeConverter=TypeConverters.toString) + + dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category", + typeConverter=TypeConverters.toBoolean) + + @keyword_only + def __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True): + """ + __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True) + """ + super(OneHotEncoderEstimator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.feature.OneHotEncoderEstimator", self.uid) + self._setDefault(handleInvalid="error", dropLast=True) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.3.0") + def setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True): + """ + setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True) + Sets params for this OneHotEncoderEstimator. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.3.0") + def setDropLast(self, value): + """ + Sets the value of :py:attr:`dropLast`. + """ + return self._set(dropLast=value) + + @since("2.3.0") + def getDropLast(self): + """ + Gets the value of dropLast or its default value. + """ + return self.getOrDefault(self.dropLast) + + def _create_model(self, java_model): + return OneHotEncoderModel(java_model) + + +class OneHotEncoderModel(JavaModel, JavaMLReadable, JavaMLWritable): + """ + Model fitted by :py:class:`OneHotEncoderEstimator`. + + .. versionadded:: 2.3.0 + """ + + @property + @since("2.3.0") + def categorySizes(self): + """ + Original number of categories for each feature being encoded. + The array contains one value for each input column, in order. + """ + return self._call_java("categorySizes") + + @inherit_doc class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 1d0f60acc6983..db951d81de1e7 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -119,6 +119,7 @@ def get$Name(self): ("inputCol", "input column name.", None, "TypeConverters.toString"), ("inputCols", "input column names.", None, "TypeConverters.toListString"), ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"), + ("outputCols", "output column names.", None, "TypeConverters.toListString"), ("numFeatures", "number of features.", None, "TypeConverters.toInt"), ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + "E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: " + diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 813f7a59f3fd1..474c38764e5a1 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -256,6 +256,29 @@ def getOutputCol(self): return self.getOrDefault(self.outputCol) +class HasOutputCols(Params): + """ + Mixin for param outputCols: output column names. + """ + + outputCols = Param(Params._dummy(), "outputCols", "output column names.", typeConverter=TypeConverters.toListString) + + def __init__(self): + super(HasOutputCols, self).__init__() + + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + + def getOutputCols(self): + """ + Gets the value of outputCols or its default value. + """ + return self.getOrDefault(self.outputCols) + + class HasNumFeatures(Params): """ Mixin for param numFeatures: number of features. From cbe7c6fbf9dc2fc422b93b3644c40d449a869eea Mon Sep 17 00:00:00 2001 From: ho3rexqj Date: Fri, 12 Jan 2018 15:27:00 +0800 Subject: [PATCH 073/774] [SPARK-22986][CORE] Use a cache to avoid instantiating multiple instances of broadcast variable values When resources happen to be constrained on an executor the first time a broadcast variable is instantiated it is persisted to disk by the BlockManager. Consequently, every subsequent call to TorrentBroadcast::readBroadcastBlock from other instances of that broadcast variable spawns another instance of the underlying value. That is, broadcast variables are spawned once per executor **unless** memory is constrained, in which case every instance of a broadcast variable is provided with a unique copy of the underlying value. This patch fixes the above by explicitly caching the underlying values using weak references in a ReferenceMap. Author: ho3rexqj Closes #20183 from ho3rexqj/fix/cache-broadcast-values. --- .../spark/broadcast/BroadcastManager.scala | 6 ++ .../spark/broadcast/TorrentBroadcast.scala | 72 +++++++++++-------- .../spark/broadcast/BroadcastSuite.scala | 34 +++++++++ 3 files changed, 83 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index e88988fe03b2e..8d7a4a353a792 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -21,6 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag +import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap} + import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging @@ -52,6 +54,10 @@ private[spark] class BroadcastManager( private val nextBroadcastId = new AtomicLong(0) + private[broadcast] val cachedValues = { + new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK) + } + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 7aecd3c9668ea..e125095cf4777 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -206,36 +206,50 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) private def readBroadcastBlock(): T = Utils.tryOrIOException { TorrentBroadcast.synchronized { - setConf(SparkEnv.get.conf) - val blockManager = SparkEnv.get.blockManager - blockManager.getLocalValues(broadcastId) match { - case Some(blockResult) => - if (blockResult.data.hasNext) { - val x = blockResult.data.next().asInstanceOf[T] - releaseLock(broadcastId) - x - } else { - throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") - } - case None => - logInfo("Started reading broadcast variable " + id) - val startTimeMs = System.currentTimeMillis() - val blocks = readBlocks() - logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) - - try { - val obj = TorrentBroadcast.unBlockifyObject[T]( - blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) - // Store the merged copy in BlockManager so other tasks on this executor don't - // need to re-fetch it. - val storageLevel = StorageLevel.MEMORY_AND_DISK - if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { - throw new SparkException(s"Failed to store $broadcastId in BlockManager") + val broadcastCache = SparkEnv.get.broadcastManager.cachedValues + + Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse { + setConf(SparkEnv.get.conf) + val blockManager = SparkEnv.get.blockManager + blockManager.getLocalValues(broadcastId) match { + case Some(blockResult) => + if (blockResult.data.hasNext) { + val x = blockResult.data.next().asInstanceOf[T] + releaseLock(broadcastId) + + if (x != null) { + broadcastCache.put(broadcastId, x) + } + + x + } else { + throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") } - obj - } finally { - blocks.foreach(_.dispose()) - } + case None => + logInfo("Started reading broadcast variable " + id) + val startTimeMs = System.currentTimeMillis() + val blocks = readBlocks() + logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) + + try { + val obj = TorrentBroadcast.unBlockifyObject[T]( + blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. + val storageLevel = StorageLevel.MEMORY_AND_DISK + if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } + + if (obj != null) { + broadcastCache.put(broadcastId, obj) + } + + obj + } finally { + blocks.foreach(_.dispose()) + } + } } } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 159629825c677..9ad2e9a5e74ac 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -153,6 +153,40 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio assert(broadcast.value.sum === 10) } + test("One broadcast value instance per executor") { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("test") + + sc = new SparkContext(conf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val instances = sc.parallelize(1 to 10) + .map(x => System.identityHashCode(broadcast.value)) + .collect() + .toSet + + assert(instances.size === 1) + } + + test("One broadcast value instance per executor when memory is constrained") { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("test") + .set("spark.memory.useLegacyMode", "true") + .set("spark.storage.memoryFraction", "0.0") + + sc = new SparkContext(conf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val instances = sc.parallelize(1 to 10) + .map(x => System.identityHashCode(broadcast.value)) + .collect() + .toSet + + assert(instances.size === 1) + } + /** * Verify the persistence of state associated with a TorrentBroadcast in a local-cluster. * From a7d98d53ceaf69cabaecc6c9113f17438c4e61f6 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 12 Jan 2018 11:27:02 +0200 Subject: [PATCH 074/774] [SPARK-23008][ML][FOLLOW-UP] mark OneHotEncoder python API deprecated ## What changes were proposed in this pull request? mark OneHotEncoder python API deprecated ## How was this patch tested? N/A Author: WeichenXu Closes #20241 from WeichenXu123/mark_ohe_deprecated. --- python/pyspark/ml/feature.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index b963e45dd7cff..eb79b193103e2 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1578,6 +1578,9 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, .. note:: This is different from scikit-learn's OneHotEncoder, which keeps all categories. The output vectors are sparse. + .. note:: Deprecated in 2.3.0. :py:class:`OneHotEncoderEstimator` will be renamed to + :py:class:`OneHotEncoder` and this :py:class:`OneHotEncoder` will be removed in 3.0.0. + .. seealso:: :py:class:`StringIndexer` for converting categorical values into From 505086806997b4331d4a8c2fc5e08345d869a23c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 12 Jan 2018 18:04:44 +0800 Subject: [PATCH 075/774] [SPARK-23025][SQL] Support Null type in scala reflection ## What changes were proposed in this pull request? Add support for `Null` type in the `schemaFor` method for Scala reflection. ## How was this patch tested? Added UT Author: Marco Gaido Closes #20219 from mgaido91/SPARK-23025. --- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 4 ++++ .../apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 9 +++++++++ .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 5 +++++ 3 files changed, 18 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 65040f1af4b04..9a4bf0075a178 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -63,6 +63,7 @@ object ScalaReflection extends ScalaReflection { private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects { tpe.dealias match { + case t if t <:< definitions.NullTpe => NullType case t if t <:< definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => LongType case t if t <:< definitions.DoubleTpe => DoubleType @@ -712,6 +713,9 @@ object ScalaReflection extends ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects { tpe.dealias match { + // this must be the first case, since all objects in scala are instances of Null, therefore + // Null type would wrongly match the first of them, which is Option as of now + case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() Schema(udt, nullable = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 23e866cdf4917..8c3db48a01f12 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -356,4 +356,13 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(deserializerFor[Int].isInstanceOf[AssertNotNull]) assert(!deserializerFor[String].isInstanceOf[AssertNotNull]) } + + test("SPARK-23025: schemaFor should support Null type") { + val schema = schemaFor[(Int, Null)] + assert(schema === Schema( + StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", NullType, nullable = true))), + nullable = true)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d535896723bd5..54893c184642b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1441,6 +1441,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(e.getCause.isInstanceOf[NullPointerException]) } } + + test("SPARK-23025: Add support for null type in scala reflection") { + val data = Seq(("a", null)) + checkDataset(data.toDS(), data: _*) + } } case class SingleData(id: Int) From f5300fbbe370af3741560f67bfb5ae6f0b0f7bb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20Beaup=C3=A8re?= Date: Fri, 12 Jan 2018 08:29:46 -0600 Subject: [PATCH 076/774] Update rdd-programming-guide.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Small typing correction - double word ## How was this patch tested? Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Matthias Beaupère Closes #20212 from matthiasbe/patch-1. --- docs/rdd-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index 29af159510e46..2e29aef7f21a2 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -91,7 +91,7 @@ so C libraries like NumPy can be used. It also works with PyPy 2.3+. Python 2.6 support was removed in Spark 2.2.0. -Spark applications in Python can either be run with the `bin/spark-submit` script which includes Spark at runtime, or by including including it in your setup.py as: +Spark applications in Python can either be run with the `bin/spark-submit` script which includes Spark at runtime, or by including it in your setup.py as: {% highlight python %} install_requires=[ From 651f76153f5e9b185aaf593161d40cabe7994fea Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 13 Jan 2018 00:37:59 +0800 Subject: [PATCH 077/774] [SPARK-23028] Bump master branch version to 2.4.0-SNAPSHOT ## What changes were proposed in this pull request? This patch bumps the master branch version to `2.4.0-SNAPSHOT`. ## How was this patch tested? N/A Author: gatorsmile Closes #20222 from gatorsmile/bump24. --- R/pkg/DESCRIPTION | 2 +- assembly/pom.xml | 2 +- common/kvstore/pom.xml | 2 +- common/network-common/pom.xml | 2 +- common/network-shuffle/pom.xml | 2 +- common/network-yarn/pom.xml | 2 +- common/sketch/pom.xml | 2 +- common/tags/pom.xml | 2 +- common/unsafe/pom.xml | 2 +- core/pom.xml | 2 +- dev/run-tests-jenkins.py | 4 ++-- docs/_config.yml | 4 ++-- examples/pom.xml | 2 +- external/docker-integration-tests/pom.xml | 2 +- external/flume-assembly/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-0-10-assembly/pom.xml | 2 +- external/kafka-0-10-sql/pom.xml | 2 +- external/kafka-0-10/pom.xml | 2 +- external/kafka-0-8-assembly/pom.xml | 2 +- external/kafka-0-8/pom.xml | 2 +- external/kinesis-asl-assembly/pom.xml | 2 +- external/kinesis-asl/pom.xml | 2 +- external/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- hadoop-cloud/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib-local/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 2 +- project/MimaExcludes.scala | 5 +++++ python/pyspark/version.py | 2 +- repl/pom.xml | 2 +- resource-managers/kubernetes/core/pom.xml | 2 +- resource-managers/mesos/pom.xml | 2 +- resource-managers/yarn/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- 43 files changed, 49 insertions(+), 44 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 6d46c31906260..855eb5bf77f16 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,6 +1,6 @@ Package: SparkR Type: Package -Version: 2.3.0 +Version: 2.4.0 Title: R Frontend for Apache Spark Description: Provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), diff --git a/assembly/pom.xml b/assembly/pom.xml index b3b4239771bc3..a207dae5a74ff 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index cf93d41cd77cf..8c148359c3029 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 18cbdadd224ab..8ca7733507f1b 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 9968480ab7658..05335df61a664 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index ec2db6e5bb88c..564e6583c909e 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 2d59c71cc3757..2f04abe8c7e88 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index f7e586ee777e1..ba127408e1c59 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index a3772a2620088..1527854730394 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 0a5bd958fc9c5..9258a856028a0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 914eb93622d51..3960a0de62530 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -181,8 +181,8 @@ def main(): short_commit_hash = ghprb_actual_commit[0:7] # format: http://linux.die.net/man/1/timeout - # must be less than the timeout configured on Jenkins (currently 300m) - tests_timeout = "250m" + # must be less than the timeout configured on Jenkins (currently 350m) + tests_timeout = "300m" # Array to capture all test names to run on the pull request. These tests are represented # by their file equivalents in the dev/tests/ directory. diff --git a/docs/_config.yml b/docs/_config.yml index dcc211204d766..095fadb93fe5d 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 2.3.0-SNAPSHOT -SPARK_VERSION_SHORT: 2.3.0 +SPARK_VERSION: 2.4.0-SNAPSHOT +SPARK_VERSION_SHORT: 2.4.0 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.8" MESOS_VERSION: 1.0.0 diff --git a/examples/pom.xml b/examples/pom.xml index 1791dbaad775e..868110b8e35ef 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 485b562dce990..431339d412194 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index 71016bc645ca7..7cd1ec4c9c09a 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 12630840e79dc..f810aa80e8780 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 87a09642405a7..498e88f665eb5 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index d6f97316b326a..a742b8d6dbddb 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 0c9f0aa765a39..16bbc6db641ca 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 6eb7ba5f0092d..3b124b2a69d50 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 786349474389b..41bc8b3e3ee1f 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 849c8b465f99e..6d1c4789f382d 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 48783d65826aa..37c7d1e604ec5 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index 40a751a652fa9..4915893965595 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 36d555066b181..027157e53d511 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index cb30e4a4af4bc..fbe77fcb958d5 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index aa36dd4774d86..8e424b1c50236 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index e9b46c4cf0ffa..912eb6b6d2a08 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 043d13609fd26..53286fe93478d 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index a906c9e02cd4c..f07d7f24fd312 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index 1b37164376460..d14594aa4ccb0 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3b452f35c5ec1..32eb31f495979 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,6 +34,10 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { + // Exclude rules for 2.4.x + lazy val v24excludes = v23excludes ++ Seq( + ) + // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( // [SPARK-22897] Expose stageAttemptId in TaskContext @@ -1082,6 +1086,7 @@ object MimaExcludes { } def excludes(version: String) = version match { + case v if v.startsWith("2.4") => v24excludes case v if v.startsWith("2.3") => v23excludes case v if v.startsWith("2.2") => v22excludes case v if v.startsWith("2.1") => v21excludes diff --git a/python/pyspark/version.py b/python/pyspark/version.py index 12dd53b9d2902..b9c2c4ced71d5 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.3.0.dev0" +__version__ = "2.4.0.dev0" diff --git a/repl/pom.xml b/repl/pom.xml index 1cb0098d0eca3..6f4a863c48bc7 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 7d35aea8a4142..a62f271273465 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 70d0c1750b14e..3995d0afeb5f4 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 43a7ce95bd3de..37e25ceecb883 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 9e2ced30407d4..839b929abd3cb 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 93010c606cf45..744daa6079779 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 3135a8a275dae..9f247f9224c75 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 66fad85ea0263..c55ba32fa458c 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index fea882ad11230..4497e53b65984 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 37427e8da62d8..242219e29f50f 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml From 7bd14cfd40500a0b6462cda647bdbb686a430328 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 12 Jan 2018 10:18:42 -0800 Subject: [PATCH 078/774] [MINOR][BUILD] Fix Java linter errors ## What changes were proposed in this pull request? This PR cleans up the java-lint errors (for v2.3.0-rc1 tag). Hopefully, this will be the final one. ``` $ dev/lint-java Using `mvn` from path: /usr/local/bin/mvn Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java:[85] (sizes) LineLength: Line is longer than 100 characters (found 101). [ERROR] src/main/java/org/apache/spark/launcher/InProcessAppHandle.java:[20,8] (imports) UnusedImports: Unused import - java.io.IOException. [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java:[41,9] (modifier) ModifierOrder: 'private' modifier out of order with the JLS suggestions. [ERROR] src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java:[464] (sizes) LineLength: Line is longer than 100 characters (found 102). ``` ## How was this patch tested? Manual. ``` $ dev/lint-java Using `mvn` from path: /usr/local/bin/mvn Checkstyle checks passed. ``` Author: Dongjoon Hyun Closes #20242 from dongjoon-hyun/fix_lint_java_2.3_rc1. --- .../org/apache/spark/unsafe/memory/HeapMemoryAllocator.java | 3 ++- .../java/org/apache/spark/launcher/InProcessAppHandle.java | 1 - .../spark/sql/execution/datasources/orc/OrcColumnVector.java | 2 +- .../java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 3acfe3696cb1e..a9603c1aba051 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -82,7 +82,8 @@ public void free(MemoryBlock memory) { "page has already been freed"; assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : - "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator free()"; + "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " + + "free()"; final long size = memory.size(); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index 0d6a73a3da3ed..acd64c962604f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -17,7 +17,6 @@ package org.apache.spark.launcher; -import java.io.IOException; import java.lang.reflect.Method; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index f94c55d860304..b6e792274da11 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -38,7 +38,7 @@ public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVecto private BytesColumnVector bytesData; private DecimalColumnVector decimalData; private TimestampColumnVector timestampData; - final private boolean isTimestamp; + private final boolean isTimestamp; private int batchSize; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 4f8a31f185724..69a2904f5f3fe 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -461,7 +461,8 @@ public void testCircularReferenceBean() { public void testUDF() { UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType); Dataset df = spark.table("testData").select(foo.apply(col("key"), col("value"))); - String[] result = df.collectAsList().stream().map(row -> row.getString(0)).toArray(String[]::new); + String[] result = df.collectAsList().stream().map(row -> row.getString(0)) + .toArray(String[]::new); String[] expected = spark.table("testData").collectAsList().stream() .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new); Assert.assertArrayEquals(expected, result); From 54277398afbde92a38ba2802f4a7a3e5910533de Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 12 Jan 2018 11:25:37 -0800 Subject: [PATCH 079/774] [SPARK-22975][SS] MetricsReporter should not throw exception when there was no progress reported ## What changes were proposed in this pull request? `MetricsReporter ` assumes that there has been some progress for the query, ie. `lastProgress` is not null. If this is not true, as it might happen in particular conditions, a `NullPointerException` can be thrown. The PR checks whether there is a `lastProgress` and if this is not true, it returns a default value for the metrics. ## How was this patch tested? added UT Author: Marco Gaido Closes #20189 from mgaido91/SPARK-22975. --- .../execution/streaming/MetricsReporter.scala | 21 ++++++++--------- .../sql/streaming/StreamingQuerySuite.scala | 23 +++++++++++++++++++ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala index b84e6ce64c611..66b11ecddf233 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala @@ -17,15 +17,11 @@ package org.apache.spark.sql.execution.streaming -import java.{util => ju} - -import scala.collection.mutable - import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.{Source => CodahaleSource} -import org.apache.spark.util.Clock +import org.apache.spark.sql.streaming.StreamingQueryProgress /** * Serves metrics from a [[org.apache.spark.sql.streaming.StreamingQuery]] to @@ -39,14 +35,17 @@ class MetricsReporter( // Metric names should not have . in them, so that all the metrics of a query are identified // together in Ganglia as a single metric group - registerGauge("inputRate-total", () => stream.lastProgress.inputRowsPerSecond) - registerGauge("processingRate-total", () => stream.lastProgress.processedRowsPerSecond) - registerGauge("latency", () => stream.lastProgress.durationMs.get("triggerExecution").longValue()) - - private def registerGauge[T](name: String, f: () => T)(implicit num: Numeric[T]): Unit = { + registerGauge("inputRate-total", _.inputRowsPerSecond, 0.0) + registerGauge("processingRate-total", _.processedRowsPerSecond, 0.0) + registerGauge("latency", _.durationMs.get("triggerExecution").longValue(), 0L) + + private def registerGauge[T]( + name: String, + f: StreamingQueryProgress => T, + default: T): Unit = { synchronized { metricRegistry.register(name, new Gauge[T] { - override def getValue: T = f() + override def getValue: T = Option(stream.lastProgress).map(f).getOrElse(default) }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 2fa4595dab376..76201c63a2701 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -424,6 +424,29 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("SPARK-22975: MetricsReporter defaults when there was no progress reported") { + withSQLConf("spark.sql.streaming.metricsEnabled" -> "true") { + BlockingSource.latch = new CountDownLatch(1) + withTempDir { tempDir => + val sq = spark.readStream + .format("org.apache.spark.sql.streaming.util.BlockingSource") + .load() + .writeStream + .format("org.apache.spark.sql.streaming.util.BlockingSource") + .option("checkpointLocation", tempDir.toString) + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + + val gauges = sq.streamMetrics.metricRegistry.getGauges + assert(gauges.get("latency").getValue.asInstanceOf[Long] == 0) + assert(gauges.get("processingRate-total").getValue.asInstanceOf[Double] == 0.0) + assert(gauges.get("inputRate-total").getValue.asInstanceOf[Double] == 0.0) + sq.stop() + } + } + } + test("input row calculation with mixed batch and streaming sources") { val streamingTriggerDF = spark.createDataset(1 to 10).toDF val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") From 55dbfbca37ce4c05f83180777ba3d4fe2d96a02e Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Fri, 12 Jan 2018 15:00:00 -0800 Subject: [PATCH 080/774] Revert "[SPARK-22908] Add kafka source and sink for continuous processing." This reverts commit 6f7aaed805070d29dcba32e04ca7a1f581fa54b9. --- .../sql/kafka010/KafkaContinuousReader.scala | 232 --------- .../sql/kafka010/KafkaContinuousWriter.scala | 119 ----- .../sql/kafka010/KafkaOffsetReader.scala | 21 +- .../spark/sql/kafka010/KafkaSource.scala | 17 +- .../sql/kafka010/KafkaSourceOffset.scala | 7 +- .../sql/kafka010/KafkaSourceProvider.scala | 105 +--- .../spark/sql/kafka010/KafkaWriteTask.scala | 71 +-- .../spark/sql/kafka010/KafkaWriter.scala | 5 +- .../kafka010/KafkaContinuousSinkSuite.scala | 474 ------------------ .../kafka010/KafkaContinuousSourceSuite.scala | 96 ---- .../sql/kafka010/KafkaContinuousTest.scala | 64 --- .../spark/sql/kafka010/KafkaSourceSuite.scala | 470 ++++++++--------- .../apache/spark/sql/DataFrameReader.scala | 32 +- .../apache/spark/sql/DataFrameWriter.scala | 25 +- .../datasources/v2/WriteToDataSourceV2.scala | 8 +- .../execution/streaming/StreamExecution.scala | 15 +- .../ContinuousDataSourceRDDIter.scala | 3 +- .../continuous/ContinuousExecution.scala | 67 +-- .../continuous/EpochCoordinator.scala | 21 +- .../sql/streaming/DataStreamWriter.scala | 26 +- .../spark/sql/streaming/StreamTest.scala | 36 +- 21 files changed, 383 insertions(+), 1531 deletions(-) delete mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala delete mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala deleted file mode 100644 index 928379544758c..0000000000000 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import java.{util => ju} - -import org.apache.kafka.clients.consumer.ConsumerRecord -import org.apache.kafka.common.TopicPartition - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String - -/** - * A [[ContinuousReader]] for data from kafka. - * - * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be - * read by per-task consumers generated later. - * @param kafkaParams String params for per-task Kafka consumers. - * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which - * are not Kafka consumer params. - * @param metadataPath Path to a directory this reader can use for writing metadata. - * @param initialOffsets The Kafka offsets to start reading data at. - * @param failOnDataLoss Flag indicating whether reading should fail in data loss - * scenarios, where some offsets after the specified initial ones can't be - * properly read. - */ -class KafkaContinuousReader( - offsetReader: KafkaOffsetReader, - kafkaParams: ju.Map[String, Object], - sourceOptions: Map[String, String], - metadataPath: String, - initialOffsets: KafkaOffsetRangeLimit, - failOnDataLoss: Boolean) - extends ContinuousReader with SupportsScanUnsafeRow with Logging { - - private lazy val session = SparkSession.getActiveSession.get - private lazy val sc = session.sparkContext - - // Initialized when creating read tasks. If this diverges from the partitions at the latest - // offsets, we need to reconfigure. - // Exposed outside this object only for unit tests. - private[sql] var knownPartitions: Set[TopicPartition] = _ - - override def readSchema: StructType = KafkaOffsetReader.kafkaSchema - - private var offset: Offset = _ - override def setOffset(start: ju.Optional[Offset]): Unit = { - offset = start.orElse { - val offsets = initialOffsets match { - case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) - case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) - } - logInfo(s"Initial offsets: $offsets") - offsets - } - } - - override def getStartOffset(): Offset = offset - - override def deserializeOffset(json: String): Offset = { - KafkaSourceOffset(JsonUtils.partitionOffsets(json)) - } - - override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { - import scala.collection.JavaConverters._ - - val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) - - val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet - val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) - val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) - - val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) - if (deletedPartitions.nonEmpty) { - reportDataLoss(s"Some partitions were deleted: $deletedPartitions") - } - - val startOffsets = newPartitionOffsets ++ - oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) - knownPartitions = startOffsets.keySet - - startOffsets.toSeq.map { - case (topicPartition, start) => - KafkaContinuousReadTask( - topicPartition, start, kafkaParams, failOnDataLoss) - .asInstanceOf[ReadTask[UnsafeRow]] - }.asJava - } - - /** Stop this source and free any resources it has allocated. */ - def stop(): Unit = synchronized { - offsetReader.close() - } - - override def commit(end: Offset): Unit = {} - - override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { - val mergedMap = offsets.map { - case KafkaSourcePartitionOffset(p, o) => Map(p -> o) - }.reduce(_ ++ _) - KafkaSourceOffset(mergedMap) - } - - override def needsReconfiguration(): Boolean = { - knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions - } - - override def toString(): String = s"KafkaSource[$offsetReader]" - - /** - * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. - * Otherwise, just log a warning. - */ - private def reportDataLoss(message: String): Unit = { - if (failOnDataLoss) { - throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") - } else { - logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") - } - } -} - -/** - * A read task for continuous Kafka processing. This will be serialized and transformed into a - * full reader on executors. - * - * @param topicPartition The (topic, partition) pair this task is responsible for. - * @param startOffset The offset to start reading from within the partition. - * @param kafkaParams Kafka consumer params to use. - * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets - * are skipped. - */ -case class KafkaContinuousReadTask( - topicPartition: TopicPartition, - startOffset: Long, - kafkaParams: ju.Map[String, Object], - failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { - override def createDataReader(): KafkaContinuousDataReader = { - new KafkaContinuousDataReader(topicPartition, startOffset, kafkaParams, failOnDataLoss) - } -} - -/** - * A per-task data reader for continuous Kafka processing. - * - * @param topicPartition The (topic, partition) pair this data reader is responsible for. - * @param startOffset The offset to start reading from within the partition. - * @param kafkaParams Kafka consumer params to use. - * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets - * are skipped. - */ -class KafkaContinuousDataReader( - topicPartition: TopicPartition, - startOffset: Long, - kafkaParams: ju.Map[String, Object], - failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { - private val topic = topicPartition.topic - private val kafkaPartition = topicPartition.partition - private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) - - private val sharedRow = new UnsafeRow(7) - private val bufferHolder = new BufferHolder(sharedRow) - private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) - - private var nextKafkaOffset = startOffset - private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ - - override def next(): Boolean = { - var r: ConsumerRecord[Array[Byte], Array[Byte]] = null - while (r == null) { - r = consumer.get( - nextKafkaOffset, - untilOffset = Long.MaxValue, - pollTimeoutMs = Long.MaxValue, - failOnDataLoss) - } - nextKafkaOffset = r.offset + 1 - currentRecord = r - true - } - - override def get(): UnsafeRow = { - bufferHolder.reset() - - if (currentRecord.key == null) { - rowWriter.setNullAt(0) - } else { - rowWriter.write(0, currentRecord.key) - } - rowWriter.write(1, currentRecord.value) - rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) - rowWriter.write(3, currentRecord.partition) - rowWriter.write(4, currentRecord.offset) - rowWriter.write(5, - DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) - rowWriter.write(6, currentRecord.timestampType.id) - sharedRow.setTotalSize(bufferHolder.totalSize) - sharedRow - } - - override def getOffset(): KafkaSourcePartitionOffset = { - KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset) - } - - override def close(): Unit = { - consumer.close() - } -} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala deleted file mode 100644 index 9843f469c5b25..0000000000000 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} -import scala.collection.JavaConverters._ - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} -import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} -import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter -import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{BinaryType, StringType, StructType} - -/** - * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we - * don't need to really send one. - */ -case object KafkaWriterCommitMessage extends WriterCommitMessage - -/** - * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. - * @param topic The topic this writer is responsible for. If None, topic will be inferred from - * a `topic` field in the incoming data. - * @param producerParams Parameters for Kafka producers in each task. - * @param schema The schema of the input data. - */ -class KafkaContinuousWriter( - topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends ContinuousWriter with SupportsWriteInternalRow { - - validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) - - override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = - KafkaContinuousWriterFactory(topic, producerParams, schema) - - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(messages: Array[WriterCommitMessage]): Unit = {} -} - -/** - * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate - * the per-task data writers. - * @param topic The topic that should be written to. If None, topic will be inferred from - * a `topic` field in the incoming data. - * @param producerParams Parameters for Kafka producers in each task. - * @param schema The schema of the input data. - */ -case class KafkaContinuousWriterFactory( - topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends DataWriterFactory[InternalRow] { - - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { - new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) - } -} - -/** - * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to - * process incoming rows. - * - * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred - * from a `topic` field in the incoming data. - * @param producerParams Parameters to use for the Kafka producer. - * @param inputSchema The attributes in the input data. - */ -class KafkaContinuousDataWriter( - targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) - extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { - import scala.collection.JavaConverters._ - - private lazy val producer = CachedKafkaProducer.getOrCreate( - new java.util.HashMap[String, Object](producerParams.asJava)) - - def write(row: InternalRow): Unit = { - checkForErrors() - sendRow(row, producer) - } - - def commit(): WriterCommitMessage = { - // Send is asynchronous, but we can't commit until all rows are actually in Kafka. - // This requires flushing and then checking that no callbacks produced errors. - // We also check for errors before to fail as soon as possible - the check is cheap. - checkForErrors() - producer.flush() - checkForErrors() - KafkaWriterCommitMessage - } - - def abort(): Unit = {} - - def close(): Unit = { - checkForErrors() - if (producer != null) { - producer.flush() - checkForErrors() - CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava)) - } - } -} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 551641cfdbca8..3e65949a6fd1b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -117,14 +117,10 @@ private[kafka010] class KafkaOffsetReader( * Resolves the specific offsets based on Kafka seek positions. * This method resolves offset value -1 to the latest and -2 to the * earliest Kafka seek position. - * - * @param partitionOffsets the specific offsets to resolve - * @param reportDataLoss callback to either report or log data loss depending on setting */ def fetchSpecificOffsets( - partitionOffsets: Map[TopicPartition, Long], - reportDataLoss: String => Unit): KafkaSourceOffset = { - val fetched = runUninterruptibly { + partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = + runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) @@ -149,19 +145,6 @@ private[kafka010] class KafkaOffsetReader( } } - partitionOffsets.foreach { - case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && - off != KafkaOffsetRangeLimit.EARLIEST => - if (fetched(tp) != off) { - reportDataLoss( - s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") - } - case _ => - // no real way to check that beginning or end is reasonable - } - KafkaSourceOffset(fetched) - } - /** * Fetch the earliest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 27da76068a66f..e9cff04ba5f2e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -130,7 +130,7 @@ private[kafka010] class KafkaSource( val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) + case SpecificOffsetRangeLimit(p) => fetchAndVerify(p) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") @@ -138,6 +138,21 @@ private[kafka010] class KafkaSource( }.partitionToOffsets } + private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = { + val result = kafkaReader.fetchSpecificOffsets(specificOffsets) + specificOffsets.foreach { + case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && + off != KafkaOffsetRangeLimit.EARLIEST => + if (result(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + KafkaSourceOffset(result) + } + private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index c82154cfbad7f..b5da415b3097e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,22 +20,17 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} -import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. */ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } -private[kafka010] -case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long) - extends PartitionOffset - /** Companion object of the [[KafkaSourceOffset]] */ private[kafka010] object KafkaSourceOffset { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3914370a96595..3cb4d8cad12cc 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, Optional, UUID} +import java.util.{Locale, UUID} import scala.collection.JavaConverters._ @@ -27,12 +27,9 @@ import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} -import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -46,8 +43,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with ContinuousWriteSupport - with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -106,43 +101,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } - override def createContinuousReader( - schema: Optional[StructType], - metadataPath: String, - options: DataSourceV2Options): KafkaContinuousReader = { - val parameters = options.asMap().asScala.toMap - validateStreamOptions(parameters) - // Each running query should use its own group id. Otherwise, the query may be only assigned - // partial data since Kafka will assign partitions to multiple consumers having the same group - // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" - - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - val specifiedKafkaParams = - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap - - val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, - STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) - - val kafkaOffsetReader = new KafkaOffsetReader( - strategy(caseInsensitiveParams), - kafkaParamsForDriver(specifiedKafkaParams), - parameters, - driverGroupIdPrefix = s"$uniqueGroupId-driver") - - new KafkaContinuousReader( - kafkaOffsetReader, - kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), - parameters, - metadataPath, - startingStreamOffsets, - failOnDataLoss(caseInsensitiveParams)) - } - /** * Returns a new base relation with the given parameters. * @@ -223,22 +181,26 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createContinuousWriter( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceV2Options): Optional[ContinuousWriter] = { - import scala.collection.JavaConverters._ - - val spark = SparkSession.getActiveSession.get - val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) - // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. - val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) - - KafkaWriter.validateQuery( - schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) + private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } - Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) } private def strategy(caseInsensitiveParams: Map[String, String]) = @@ -488,27 +450,4 @@ private[kafka010] object KafkaSourceProvider extends Logging { def build(): ju.Map[String, Object] = map } - - private[kafka010] def kafkaParamsForProducer( - parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " - + "are serialized with ByteArraySerializer.") - } - - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) - { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are serialized with ByteArraySerializer.") - } - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index baa60febf661d..6fd333e2f43ba 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -33,8 +33,10 @@ import org.apache.spark.sql.types.{BinaryType, StringType} private[kafka010] class KafkaWriteTask( producerConfiguration: ju.Map[String, Object], inputSchema: Seq[Attribute], - topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) { + topic: Option[String]) { // used to synchronize with Kafka callbacks + @volatile private var failedWrite: Exception = null + private val projection = createProjection private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ /** @@ -44,7 +46,23 @@ private[kafka010] class KafkaWriteTask( producer = CachedKafkaProducer.getOrCreate(producerConfiguration) while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() - sendRow(currentRow, producer) + val projectedRow = projection(currentRow) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } + producer.send(record, callback) } } @@ -56,49 +74,8 @@ private[kafka010] class KafkaWriteTask( producer = null } } -} - -private[kafka010] abstract class KafkaRowWriter( - inputSchema: Seq[Attribute], topic: Option[String]) { - - // used to synchronize with Kafka callbacks - @volatile protected var failedWrite: Exception = _ - protected val projection = createProjection - - private val callback = new Callback() { - override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { - if (failedWrite == null && e != null) { - failedWrite = e - } - } - } - /** - * Send the specified row to the producer, with a callback that will save any exception - * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before - * assuming the row is in Kafka. - */ - protected def sendRow( - row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = { - val projectedRow = projection(row) - val topic = projectedRow.getUTF8String(0) - val key = projectedRow.getBinary(1) - val value = projectedRow.getBinary(2) - if (topic == null) { - throw new NullPointerException(s"null topic present in the data. Use the " + - s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") - } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) - producer.send(record, callback) - } - - protected def checkForErrors(): Unit = { - if (failedWrite != null) { - throw failedWrite - } - } - - private def createProjection = { + private def createProjection: UnsafeProjection = { val topicExpression = topic.map(Literal(_)).orElse { inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) }.getOrElse { @@ -135,5 +112,11 @@ private[kafka010] abstract class KafkaRowWriter( Seq(topicExpression, Cast(keyExpression, BinaryType), Cast(valueExpression, BinaryType)), inputSchema) } + + private def checkForErrors(): Unit = { + if (failedWrite != null) { + throw failedWrite + } + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 15cd44812cb0c..5e9ae35b3f008 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -43,9 +43,10 @@ private[kafka010] object KafkaWriter extends Logging { override def toString: String = "KafkaWriter" def validateQuery( - schema: Seq[Attribute], + queryExecution: QueryExecution, kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { + val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic.isEmpty) { throw new AnalysisException(s"topic option required when no " + @@ -83,7 +84,7 @@ private[kafka010] object KafkaWriter extends Logging { kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output - validateQuery(schema, kafkaParameters, topic) + validateQuery(queryExecution, kafkaParameters, topic) queryExecution.toRdd.foreachPartition { iter => val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) Utils.tryWithSafeFinally(block = writeTask.execute(iter))( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala deleted file mode 100644 index dfc97b1c38bb5..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ /dev/null @@ -1,474 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import java.util.Locale -import java.util.concurrent.atomic.AtomicInteger - -import org.apache.kafka.clients.producer.ProducerConfig -import org.apache.kafka.common.serialization.ByteArraySerializer -import org.scalatest.time.SpanSugar._ -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} -import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.streaming._ -import org.apache.spark.sql.types.{BinaryType, DataType} -import org.apache.spark.util.Utils - -/** - * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream. - * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have - * to duplicate all the code. - */ -class KafkaContinuousSinkSuite extends KafkaContinuousTest { - import testImplicits._ - - override val streamingTimeout = 30.seconds - - override def beforeAll(): Unit = { - super.beforeAll() - testUtils = new KafkaTestUtils( - withBrokerProps = Map("auto.create.topics.enable" -> "false")) - testUtils.setup() - } - - override def afterAll(): Unit = { - if (testUtils != null) { - testUtils.teardown() - testUtils = null - } - super.afterAll() - } - - test("streaming - write to kafka with topic field") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - - val topic = newTopic() - testUtils.createTopic(topic) - - val writer = createKafkaWriter( - input.toDF(), - withTopic = None, - withOutputMode = Some(OutputMode.Append))( - withSelectExpr = s"'$topic' as topic", "value") - - val reader = createKafkaReader(topic) - .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") - .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] - .map(_._2) - - try { - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) - testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - } finally { - writer.stop() - } - } - - test("streaming - write w/o topic field, with topic option") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - - val topic = newTopic() - testUtils.createTopic(topic) - - val writer = createKafkaWriter( - input.toDF(), - withTopic = Some(topic), - withOutputMode = Some(OutputMode.Append()))() - - val reader = createKafkaReader(topic) - .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") - .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] - .map(_._2) - - try { - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) - testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - } finally { - writer.stop() - } - } - - test("streaming - topic field and topic option") { - /* The purpose of this test is to ensure that the topic option - * overrides the topic field. We begin by writing some data that - * includes a topic field and value (e.g., 'foo') along with a topic - * option. Then when we read from the topic specified in the option - * we should see the data i.e., the data was written to the topic - * option, and not to the topic in the data e.g., foo - */ - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - - val topic = newTopic() - testUtils.createTopic(topic) - - val writer = createKafkaWriter( - input.toDF(), - withTopic = Some(topic), - withOutputMode = Some(OutputMode.Append()))( - withSelectExpr = "'foo' as topic", "CAST(value as STRING) value") - - val reader = createKafkaReader(topic) - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") - .as[(Int, Int)] - .map(_._2) - - try { - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) - testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - } finally { - writer.stop() - } - } - - test("null topic attribute") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - val topic = newTopic() - testUtils.createTopic(topic) - - /* No topic field or topic option */ - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = "CAST(null as STRING) as topic", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getCause.getCause.getMessage - .toLowerCase(Locale.ROOT) - .contains("null topic present in the data.")) - } - - test("streaming - write data with bad schema") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - val topic = newTopic() - testUtils.createTopic(topic) - - /* No topic field or topic option */ - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = "value as key", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage - .toLowerCase(Locale.ROOT) - .contains("topic option required when no 'topic' attribute is present")) - - try { - /* No value field */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "value as key" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "required attribute 'value' not found")) - } - - test("streaming - write data with valid schema but wrong types") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - .selectExpr("CAST(value as STRING) value") - val topic = newTopic() - testUtils.createTopic(topic) - - var writer: StreamingQuery = null - var ex: Exception = null - try { - /* topic field wrong type */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"CAST('1' as INT) as topic", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) - - try { - /* value field wrong type */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binarytype")) - - try { - ex = intercept[StreamingQueryException] { - /* key field wrong type */ - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binarytype")) - } - - test("streaming - write to non-existing topic") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - val topic = newTopic() - - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - } - throw writer.exception.get - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) - } - - test("streaming - exception on config serializer") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - testUtils.sendMessages(inputTopic, Array("0")) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .load() - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter( - input.toDF(), - withOptions = Map("kafka.key.serializer" -> "foo"))() - writer.processAllAvailable() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'key.serializer' is not supported")) - } finally { - writer.stop() - } - - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter( - input.toDF(), - withOptions = Map("kafka.value.serializer" -> "foo"))() - writer.processAllAvailable() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'value.serializer' is not supported")) - } finally { - writer.stop() - } - } - - test("generic - write big data with small producer buffer") { - /* This test ensures that we understand the semantics of Kafka when - * is comes to blocking on a call to send when the send buffer is full. - * This test will configure the smallest possible producer buffer and - * indicate that we should block when it is full. Thus, no exception should - * be thrown in the case of a full buffer. - */ - val topic = newTopic() - testUtils.createTopic(topic, 1) - val options = new java.util.HashMap[String, String] - options.put("bootstrap.servers", testUtils.brokerAddress) - options.put("buffer.memory", "16384") // min buffer size - options.put("block.on.buffer.full", "true") - options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) - options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) - val inputSchema = Seq(AttributeReference("value", BinaryType)()) - val data = new Array[Byte](15000) // large value - val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) - try { - val fieldTypes: Array[DataType] = Array(BinaryType) - val converter = UnsafeProjection.create(fieldTypes) - val row = new SpecificInternalRow(fieldTypes) - row.update(0, data) - val iter = Seq.fill(1000)(converter.apply(row)).iterator - iter.foreach(writeTask.write(_)) - writeTask.commit() - } finally { - writeTask.close() - } - } - - private def createKafkaReader(topic: String): DataFrame = { - spark.read - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("startingOffsets", "earliest") - .option("endingOffsets", "latest") - .option("subscribe", topic) - .load() - } - - private def createKafkaWriter( - input: DataFrame, - withTopic: Option[String] = None, - withOutputMode: Option[OutputMode] = None, - withOptions: Map[String, String] = Map[String, String]()) - (withSelectExpr: String*): StreamingQuery = { - var stream: DataStreamWriter[Row] = null - val checkpointDir = Utils.createTempDir() - var df = input.toDF() - if (withSelectExpr.length > 0) { - df = df.selectExpr(withSelectExpr: _*) - } - stream = df.writeStream - .format("kafka") - .option("checkpointLocation", checkpointDir.getCanonicalPath) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - // We need to reduce blocking time to efficiently test non-existent partition behavior. - .option("kafka.max.block.ms", "1000") - .trigger(Trigger.Continuous(1000)) - .queryName("kafkaStream") - withTopic.foreach(stream.option("topic", _)) - withOutputMode.foreach(stream.outputMode(_)) - withOptions.foreach(opt => stream.option(opt._1, opt._2)) - stream.start() - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala deleted file mode 100644 index b3dade414f625..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.time.SpanSugar._ -import scala.collection.mutable -import scala.util.Random - -import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} - -// Run tests in KafkaSourceSuiteBase in continuous execution mode. -class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest - -class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { - import testImplicits._ - - override val brokerProps = Map("auto.create.topics.enable" -> "false") - - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") - - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Execute { query => - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - eventually(timeout(streamingTimeout)) { - assert( - query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r - }.exists { r => - // Ensure the new topic is present and the old topic is gone. - r.knownPartitions.exists(_.topic == topic2) - }, - s"query never reconfigured to new topic $topic2") - } - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } -} - -class KafkaContinuousSourceStressForDontFailOnDataLossSuite - extends KafkaSourceStressForDontFailOnDataLossSuite { - override protected def startStream(ds: Dataset[Int]) = { - ds.writeStream - .format("memory") - .queryName("memory") - .start() - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala deleted file mode 100644 index e713e6695d2bd..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import org.apache.spark.SparkContext -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.Trigger -import org.apache.spark.sql.test.TestSparkSession - -// Trait to configure StreamTest for kafka continuous execution tests. -trait KafkaContinuousTest extends KafkaSourceTest { - override val defaultTrigger = Trigger.Continuous(1000) - override val defaultUseV2Sink = true - - // We need more than the default local[2] to be able to schedule all partitions simultaneously. - override protected def createSparkSession = new TestSparkSession( - new SparkContext( - "local[10]", - "continuous-stream-test-sql-context", - sparkConf.set("spark.sql.testkey", "true"))) - - // In addition to setting the partitions in Kafka, we have to wait until the query has - // reconfigured to the new count so the test framework can hook in properly. - override protected def setTopicPartitions( - topic: String, newCount: Int, query: StreamExecution) = { - testUtils.addPartitions(topic, newCount) - eventually(timeout(streamingTimeout)) { - assert( - query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r - }.exists(_.knownPartitions.size == newCount), - s"query never reconfigured to $newCount partitions") - } - } - - test("ensure continuous stream is being used") { - val query = spark.readStream - .format("rate") - .option("numPartitions", "1") - .option("rowsPerSecond", "1") - .load() - - testStream(query)( - Execute(q => assert(q.isInstanceOf[ContinuousExecution])) - ) - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index d66908f86ccc7..2034b9be07f24 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -34,14 +34,11 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} +import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils @@ -52,11 +49,9 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override val streamingTimeout = 30.seconds - protected val brokerProps = Map[String, Object]() - override def beforeAll(): Unit = { super.beforeAll() - testUtils = new KafkaTestUtils(brokerProps) + testUtils = new KafkaTestUtils testUtils.setup() } @@ -64,25 +59,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { if (testUtils != null) { testUtils.teardown() testUtils = null + super.afterAll() } - super.afterAll() } protected def makeSureGetOffsetCalled = AssertOnQuery { q => // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure - // its "getOffset" is called before pushing any data. Otherwise, because of the race condition, + // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, // we don't know which data should be fetched when `startingOffsets` is latest. - q match { - case c: ContinuousExecution => c.awaitEpoch(0) - case m: MicroBatchExecution => m.processAllAvailable() - } + q.processAllAvailable() true } - protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = { - testUtils.addPartitions(topic, newCount) - } - /** * Add data to Kafka. * @@ -94,7 +82,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { if (query.get.isActive) { // Make sure no Spark job is running when deleting a topic query.get.processAllAvailable() @@ -109,18 +97,16 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) } + // Read all topics again in case some topics are delete. + val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys require( query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: KafkaSource, _) => source - } ++ (query.get.lastExecution match { - case null => Seq() - case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader - } - }) + case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => + source.asInstanceOf[KafkaSource] + } if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -151,158 +137,14 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override def toString: String = s"AddKafkaData(topics = $topics, data = $data, message = $message)" } - - private val topicId = new AtomicInteger(0) - protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" } -class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { - - import testImplicits._ - - test("(de)serialization of initial offsets") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - testStream(reader.load)( - makeSureGetOffsetCalled, - StopStream, - StartStream(), - StopStream) - } - - test("maxOffsetsPerTrigger") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 3) - testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) - testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) - testUtils.sendMessages(topic, Array("1"), Some(2)) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("maxOffsetsPerTrigger", 10) - .option("subscribe", topic) - .option("startingOffsets", "earliest") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) - - val clock = new StreamManualClock - - val waitUntilBatchProcessed = AssertOnQuery { q => - eventually(Timeout(streamingTimeout)) { - if (!q.exception.isDefined) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } - } - if (q.exception.isDefined) { - throw q.exception.get - } - true - } - - testStream(mapped)( - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // 1 from smallest, 1 from middle, 8 from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 - ), - StopStream, - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 - ), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, - 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 - ) - ) - } - - test("input row metrics") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val kafka = spark - .readStream - .format("kafka") - .option("subscribe", topic) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - - val mapped = kafka.map(kv => kv._2.toInt + 1) - testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - AssertOnQuery { query => - val recordsRead = query.recentProgress.map(_.numInputRows).sum - recordsRead == 3 - } - ) - } - - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") +class KafkaSourceSuite extends KafkaSourceTest { - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) + import testImplicits._ - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Assert { - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - true - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } + private val topicId = new AtomicInteger(0) testWithUninterruptibleThread( "deserialization of initial offset with Spark 2.1.0") { @@ -395,51 +237,86 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { } } - test("KafkaSource with watermark") { - val now = System.currentTimeMillis() + test("(de)serialization of initial offsets") { val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 1) - testUtils.sendMessages(topic, Array(1).map(_.toString)) + testUtils.createTopic(topic, partitions = 64) - val kafka = spark + val reader = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("startingOffsets", s"earliest") .option("subscribe", topic) - .load() - - val windowedAggregation = kafka - .withWatermark("timestamp", "10 seconds") - .groupBy(window($"timestamp", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start") as 'window, $"count") - val query = windowedAggregation - .writeStream - .format("memory") - .outputMode("complete") - .queryName("kafkaWatermark") - .start() - query.processAllAvailable() - val rows = spark.table("kafkaWatermark").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") - val row = rows(0) - // We cannot check the exact window start time as it depands on the time that messages were - // inserted by the producer. So here we just use a low bound to make sure the internal - // conversion works. - assert( - row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, - s"Unexpected results: $row") - assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") - query.stop() + testStream(reader.load)( + makeSureGetOffsetCalled, + StopStream, + StartStream(), + StopStream) } -} -class KafkaSourceSuiteBase extends KafkaSourceTest { + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) - import testImplicits._ + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } test("cannot stop Kafka stream") { val topic = newTopic() @@ -451,7 +328,7 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topic.*") + .option("subscribePattern", s"topic-.*") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -545,6 +422,65 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { } } + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } + + test("starting offset is latest by default") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("0")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + val kafka = reader.load() + .selectExpr("CAST(value AS STRING)") + .as[String] + val mapped = kafka.map(_.toInt) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(1, 2, 3) // should not have 0 + ) + } + test("bad source options") { def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { @@ -604,6 +540,34 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = ProcessingTime(1)), + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnQuery { query => + val recordsRead = query.recentProgress.map(_.numInputRows).sum + recordsRead == 3 + } + ) + } + test("delete a topic when a Spark job is running") { KafkaSourceSuite.collectedData.clear() @@ -665,6 +629,8 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { } } + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" + private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } @@ -710,10 +676,6 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { testStream(mapped)( makeSureGetOffsetCalled, - Execute { q => - // wait to reach the last offset in every partition - q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) - }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, StartStream(), @@ -744,7 +706,6 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { .format("memory") .outputMode("append") .queryName("kafkaColumnTypes") - .trigger(defaultTrigger) .start() query.processAllAvailable() val rows = spark.table("kafkaColumnTypes").collect() @@ -762,6 +723,47 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { query.stop() } + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() + val topic = newTopic() + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("startingOffsets", s"earliest") + .option("subscribe", topic) + .load() + + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") + + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + .queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. So here we just use a low bound to make sure the internal + // conversion works. + assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() + } + private def testFromLatestOffsets( topic: String, addPartitions: Boolean, @@ -798,7 +800,9 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) setTopicPartitions(topic, 10, query) + if (addPartitions) { + testUtils.addPartitions(topic, 10) + } true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -839,7 +843,9 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) setTopicPartitions(topic, 10, query) + if (addPartitions) { + testUtils.addPartitions(topic, 10) + } true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -971,23 +977,6 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared } } - protected def startStream(ds: Dataset[Int]) = { - ds.writeStream.foreach(new ForeachWriter[Int] { - - override def open(partitionId: Long, version: Long): Boolean = { - true - } - - override def process(value: Int): Unit = { - // Slow down the processing speed so that messages may be aged out. - Thread.sleep(Random.nextInt(500)) - } - - override def close(errorOrNull: Throwable): Unit = { - } - }).start() - } - test("stress test for failOnDataLoss=false") { val reader = spark .readStream @@ -1001,7 +990,20 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] - val query = startStream(kafka.map(kv => kv._2.toInt)) + val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + + override def open(partitionId: Long, version: Long): Boolean = { + true + } + + override def process(value: Int): Unit = { + // Slow down the processing speed so that messages may be aged out. + Thread.sleep(Random.nextInt(500)) + } + + override def close(errorOrNull: Throwable): Unit = { + } + }).start() val testTime = 1.minutes val startTime = System.currentTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b714a46b5f786..e8d683a578f35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -191,9 +191,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading - // the dataframe as a v1 source. val reader = (ds, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) @@ -211,30 +208,23 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } reader - case _ => null // fall back to v1 + case _ => + throw new AnalysisException(s"$cls does not support data reading.") } - if (reader == null) { - loadV1Source(paths: _*) - } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) - } + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) } else { - loadV1Source(paths: _*) + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) } } - private def loadV1Source(paths: String*) = { - // Code path for data source v1. - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) - } - /** * Construct a `DataFrame` representing the database table accessible via JDBC URL * url named table and connection properties. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 97f12ff625c42..3304f368e1050 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,24 +255,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving - // as though it's a V1 source. - case _ => saveToV1Source() + case _ => throw new AnalysisException(s"$cls does not support data writing.") } } else { - saveToV1Source() - } - } - - private def saveToV1Source(): Unit = { - // Code path for data source v1. - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) + // Code path for data source v1. + runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index a4a857f2d4d9b..f0bdf84bb7a84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -81,11 +81,9 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - if (!writer.isInstanceOf[ContinuousWriter]) { - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") - } + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") } catch { case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index cf27e1a70650a..24a8b000df0c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,8 +142,7 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override def runId: UUID = currentRunId - protected var currentRunId = UUID.randomUUID + override val runId: UUID = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is @@ -419,17 +418,11 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { + private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets - if (sources == null) { - // sources might not be initialized yet - false - } else { - val source = sources(sourceIndex) - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset - } + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset } while (notDone) { @@ -443,7 +436,7 @@ abstract class StreamExecution( awaitProgressLock.unlock() } } - logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}") + logDebug(s"Unblocked at $newOffset for $source") } /** A flag to indicate that a batch has completed with no new data available. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index e700aa4f9aea7..d79e4bd65f563 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -77,6 +77,7 @@ class ContinuousDataSourceRDD( dataReaderThread.start() context.addTaskCompletionListener(_ => { + reader.close() dataReaderThread.interrupt() epochPollExecutor.shutdown() }) @@ -200,8 +201,6 @@ class DataReaderThread( failedFlag.set(true) // Don't rethrow the exception in this thread. It's not needed, and the default Spark // exception handler will kill the executor. - } finally { - reader.close() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 667410ef9f1c6..9657b5e26d770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.execution.streaming.continuous -import java.util.UUID import java.util.concurrent.TimeUnit -import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} @@ -54,7 +52,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = _ + @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty override protected def sources: Seq[BaseStreamingSource] = continuousSources override lazy val logicalPlan: LogicalPlan = { @@ -80,17 +78,15 @@ class ContinuousExecution( } override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - val stateUpdate = new UnaryOperator[State] { - override def apply(s: State) = s match { - // If we ended the query to reconfigure, reset the state to active. - case RECONFIGURING => ACTIVE - case _ => s - } - } - do { - runContinuous(sparkSessionForStream) - } while (state.updateAndGet(stateUpdate) == ACTIVE) + try { + runContinuous(sparkSessionForStream) + } catch { + case _: InterruptedException if state.get().equals(RECONFIGURING) => + // swallow exception and run again + state.set(ACTIVE) + } + } while (state.get() == ACTIVE) } /** @@ -124,16 +120,12 @@ class ContinuousExecution( } committedOffsets = nextOffsets.toStreamProgress(sources) - // Get to an epoch ID that has definitely never been sent to a sink before. Since sink - // commit happens between offset log write and commit log write, this means an epoch ID - // which is not in the offset log. - val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { - throw new IllegalStateException( - s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + - s"an element.") - } - currentBatchId = latestOffsetEpoch + 1 + // Forcibly align commit and offset logs by slicing off any spurious offset logs from + // a previous run. We can't allow commits to an epoch that a previous run reached but + // this run has not. + offsetLog.purgeAfter(latestEpochId) + currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets case None => @@ -149,7 +141,6 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { - currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -234,11 +225,13 @@ class ContinuousExecution( triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { + if (reader.needsReconfiguration()) { + state.set(RECONFIGURING) stopSources() if (queryExecutionThread.isAlive) { sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() + // No need to join - this thread is about to end anyway. } false } else if (isActive) { @@ -266,7 +259,6 @@ class ContinuousExecution( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } } finally { - epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() @@ -281,22 +273,17 @@ class ContinuousExecution( epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) - val oldOffset = synchronized { - offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) - offsetLog.get(epoch - 1) + if (partitionOffsets.contains(null)) { + // If any offset is null, that means the corresponding partition hasn't seen any data yet, so + // there's nothing meaningful to add to the offset log. } - - // If offset hasn't changed since last epoch, there's been no new data. - if (oldOffset.contains(OffsetSeq.fill(globalOffset))) { - noNewData = true - } - - awaitProgressLock.lock() - try { - awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() + val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) + synchronized { + if (queryExecutionThread.isAlive) { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + } else { + return + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 40dcbecade814..98017c3ac6a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -39,15 +39,6 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable */ private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage -/** - * The RpcEndpoint stop() will wait to clear out the message queue before terminating the - * object. This can lead to a race condition where the query restarts at epoch n, a new - * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1. - * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous - * message to stop any writes to the ContinuousExecution object. - */ -private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage - // Init messages /** * Set the reader and writer partition counts. Tasks may not be started until the coordinator @@ -125,8 +116,6 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { - private var queryWritesStopped: Boolean = false - private var numReaderPartitions: Int = _ private var numWriterPartitions: Int = _ @@ -158,16 +147,12 @@ private[continuous] class EpochCoordinator( partitionCommits.remove(k) } for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionOffsets.remove(k) + partitionCommits.remove(k) } } } override def receive: PartialFunction[Any, Unit] = { - // If we just drop these messages, we won't do any writes to the query. The lame duck tasks - // won't shed errors or anything. - case _ if queryWritesStopped => () - case CommitPartitionEpoch(partitionId, epoch, message) => logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message") if (!partitionCommits.isDefinedAt((epoch, partitionId))) { @@ -203,9 +188,5 @@ private[continuous] class EpochCoordinator( case SetWriterPartitions(numPartitions) => numWriterPartitions = numPartitions context.reply(()) - - case StopContinuousExecutionWrites => - queryWritesStopped = true - context.reply(()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index b5b4a05ab4973..db588ae282f38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -280,29 +279,18 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val sink = trigger match { - case _: ContinuousTrigger => - val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) - ds.newInstance() match { - case w: ContinuousWriteSupport => w - case _ => throw new AnalysisException( - s"Data source $source does not support continuous writing") - } - case _ => - val ds = DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) - ds.createSink(outputMode) - } - + val dataSource = + DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, extraOptions.toMap, - sink, + dataSource.createSink(outputMode), outputMode, useTempCheckpointLocation = source == "console", recoverFromCheckpointLocation = true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 0762895fdc620..d46461fa9bf6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,9 +38,8 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -81,9 +80,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be StateStore.stop() // stop the state store maintenance thread and unload store providers } - protected val defaultTrigger = Trigger.ProcessingTime(0) - protected val defaultUseV2Sink = false - /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -193,7 +189,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( - trigger: Trigger = defaultTrigger, + trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty, checkpointLocation: String = null) @@ -280,7 +276,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { + useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -407,11 +403,18 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { verify(currentStream != null, "stream not running") + // Get the map of source index to the current source objects + val indexToSource = currentStream + .logicalPlan + .collect { case StreamingExecutionRelation(s, _) => s } + .zipWithIndex + .map(_.swap) + .toMap // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(sourceIndex, offset) + currentStream.awaitOffset(indexToSource(sourceIndex), offset) } } @@ -470,12 +473,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be // after starting the query. try { currentStream.awaitInitialization(streamingTimeout.toMillis) - currentStream match { - case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - s.lastExecution.executedPlan // will fail if lastExecution is null - } - case _ => - } } catch { case _: StreamingQueryException => // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. @@ -603,10 +600,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan - .collect { - case StreamingExecutionRelation(s, _) => s - case DataSourceV2Relation(_, r) => r - } + .collect { case StreamingExecutionRelation(s, _) => s } .zipWithIndex .find(_._1 == source) .map(_._2) @@ -619,13 +613,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be findSourceIndex(query.logicalPlan) }.orElse { findSourceIndex(stream.logicalPlan) - }.orElse { - queryToUse.flatMap { q => - findSourceIndex(q.lastExecution.logical) - } }.getOrElse { throw new IllegalArgumentException( - "Could not find index of the source to which data was added") + "Could find index of the source to which data was added") } // Store the expected offset of added data to wait for it later From cd9f49a2aed3799964976ead06080a0f7044a0c3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 13 Jan 2018 16:13:44 +0900 Subject: [PATCH 081/774] [SPARK-22980][PYTHON][SQL] Clarify the length of each series is of each batch within scalar Pandas UDF ## What changes were proposed in this pull request? This PR proposes to add a note that saying the length of a scalar Pandas UDF's `Series` is not of the whole input column but of the batch. We are fine for a group map UDF because the usage is different from our typical UDF but scalar UDFs might cause confusion with the normal UDF. For example, please consider this example: ```python from pyspark.sql.functions import pandas_udf, col, lit df = spark.range(1) f = pandas_udf(lambda x, y: len(x) + y, LongType()) df.select(f(lit('text'), col('id'))).show() ``` ``` +------------------+ |(text, id)| +------------------+ | 1| +------------------+ ``` ```python from pyspark.sql.functions import udf, col, lit df = spark.range(1) f = udf(lambda x, y: len(x) + y, "long") df.select(f(lit('text'), col('id'))).show() ``` ``` +------------------+ |(text, id)| +------------------+ | 4| +------------------+ ``` ## How was this patch tested? Manually built the doc and checked the output. Author: hyukjinkwon Closes #20237 from HyukjinKwon/SPARK-22980. --- python/pyspark/sql/functions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 733e32bd825b0..e1ad6590554cf 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2184,6 +2184,11 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 8| JOHN DOE| 22| +----------+--------------+------------+ + .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input + column, but is the length of an internal batch used for each call to the function. + Therefore, this can be used, for example, to ensure the length of each returned + `pandas.Series`, and can not be used as the column length. + 2. GROUP_MAP A group map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` From 628a1ca5a4d14397a90e9e96a7e03e8f63531b20 Mon Sep 17 00:00:00 2001 From: shimamoto Date: Sat, 13 Jan 2018 09:40:00 -0600 Subject: [PATCH 082/774] [SPARK-23043][BUILD] Upgrade json4s to 3.5.3 ## What changes were proposed in this pull request? Spark still use a few years old version 3.2.11. This change is to upgrade json4s to 3.5.3. Note that this change does not include the Jackson update because the Jackson version referenced in json4s 3.5.3 is 2.8.4, which has a security vulnerability ([see](https://issues.apache.org/jira/browse/SPARK-20433)). ## How was this patch tested? Existing unit tests and build. Author: shimamoto Closes #20233 from shimamoto/upgrade-json4s. --- .../deploy/history/HistoryServerSuite.scala | 2 +- .../org/apache/spark/ui/UISeleniumSuite.scala | 19 ++++++++++--------- dev/deps/spark-deps-hadoop-2.6 | 8 ++++---- dev/deps/spark-deps-hadoop-2.7 | 8 ++++---- pom.xml | 13 +++++++------ 5 files changed, 26 insertions(+), 24 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 3738f85da5831..87778dda0e2c8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -486,7 +486,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers json match { case JNothing => Seq() case apps: JArray => - apps.filter(app => { + apps.children.filter(app => { (app \ "attempts") match { case attempts: JArray => val state = (attempts.children.head \ "completed").asInstanceOf[JBool] diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 326546787ab6c..ed51fc445fdfb 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -131,7 +131,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val storageJson = getJson(ui, "storage/rdd") storageJson.children.length should be (1) - (storageJson \ "storageLevel").extract[String] should be (StorageLevels.DISK_ONLY.description) + (storageJson.children.head \ "storageLevel").extract[String] should be ( + StorageLevels.DISK_ONLY.description) val rddJson = getJson(ui, "storage/rdd/0") (rddJson \ "storageLevel").extract[String] should be (StorageLevels.DISK_ONLY.description) @@ -150,7 +151,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val updatedStorageJson = getJson(ui, "storage/rdd") updatedStorageJson.children.length should be (1) - (updatedStorageJson \ "storageLevel").extract[String] should be ( + (updatedStorageJson.children.head \ "storageLevel").extract[String] should be ( StorageLevels.MEMORY_ONLY.description) val updatedRddJson = getJson(ui, "storage/rdd/0") (updatedRddJson \ "storageLevel").extract[String] should be ( @@ -204,7 +205,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } val stageJson = getJson(sc.ui.get, "stages") stageJson.children.length should be (1) - (stageJson \ "status").extract[String] should be (StageStatus.FAILED.name()) + (stageJson.children.head \ "status").extract[String] should be (StageStatus.FAILED.name()) // Regression test for SPARK-2105 class NotSerializable @@ -325,11 +326,11 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B find(cssSelector(".progress-cell .progress")).get.text should be ("2/2 (1 failed)") } val jobJson = getJson(sc.ui.get, "jobs") - (jobJson \ "numTasks").extract[Int]should be (2) - (jobJson \ "numCompletedTasks").extract[Int] should be (3) - (jobJson \ "numFailedTasks").extract[Int] should be (1) - (jobJson \ "numCompletedStages").extract[Int] should be (2) - (jobJson \ "numFailedStages").extract[Int] should be (1) + (jobJson \\ "numTasks").extract[Int]should be (2) + (jobJson \\ "numCompletedTasks").extract[Int] should be (3) + (jobJson \\ "numFailedTasks").extract[Int] should be (1) + (jobJson \\ "numCompletedStages").extract[Int] should be (2) + (jobJson \\ "numFailedStages").extract[Int] should be (1) val stageJson = getJson(sc.ui.get, "stages") for { @@ -656,7 +657,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.ui.get.webUrl + "/api/v1/applications")) val appListJsonAst = JsonMethods.parse(appListRawJson) appListJsonAst.children.length should be (1) - val attempts = (appListJsonAst \ "attempts").children + val attempts = (appListJsonAst.children.head \ "attempts").children attempts.size should be (1) (attempts(0) \ "completed").extract[Boolean] should be (false) parseDate(attempts(0) \ "startTime") should be (sc.startTime) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index a7fce2ede0ea5..2a298769be44c 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -122,9 +122,10 @@ jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar -json4s-ast_2.11-3.2.11.jar -json4s-core_2.11-3.2.11.jar -json4s-jackson_2.11-3.2.11.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -167,7 +168,6 @@ scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.5.jar -scalap-2.11.8.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 94b2e98d85e74..abee326f283ab 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -122,9 +122,10 @@ jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar -json4s-ast_2.11-3.2.11.jar -json4s-core_2.11-3.2.11.jar -json4s-jackson_2.11-3.2.11.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar jsp-api-2.1.jar jsr305-1.3.9.jar jta-1.1.jar @@ -168,7 +169,6 @@ scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.5.jar -scalap-2.11.8.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar diff --git a/pom.xml b/pom.xml index d14594aa4ccb0..666d5d7169a15 100644 --- a/pom.xml +++ b/pom.xml @@ -705,7 +705,13 @@ org.json4s json4s-jackson_${scala.binary.version} - 3.2.11 + 3.5.3 + + + com.fasterxml.jackson.core + * + + org.scala-lang @@ -732,11 +738,6 @@ scala-parser-combinators_${scala.binary.version} 1.0.4 - - org.scala-lang - scalap - ${scala.version} - jline From fc6fe8a1d0f161c4788f3db94de49a8669ba3bcc Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 13 Jan 2018 10:01:44 -0600 Subject: [PATCH 083/774] [SPARK-22870][CORE] Dynamic allocation should allow 0 idle time ## What changes were proposed in this pull request? This pr to make `0` as a valid value for `spark.dynamicAllocation.executorIdleTimeout`. For details, see the jira description: https://issues.apache.org/jira/browse/SPARK-22870. ## How was this patch tested? N/A Author: Yuming Wang Author: Yuming Wang Closes #20080 from wangyum/SPARK-22870. --- .../scala/org/apache/spark/ExecutorAllocationManager.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 2e00dc8b49dd5..6c59038f2a6c1 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -195,8 +195,11 @@ private[spark] class ExecutorAllocationManager( throw new SparkException( "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") } - if (executorIdleTimeoutS <= 0) { - throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") + if (executorIdleTimeoutS < 0) { + throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be >= 0!") + } + if (cachedExecutorIdleTimeoutS < 0) { + throw new SparkException("spark.dynamicAllocation.cachedExecutorIdleTimeout must be >= 0!") } // Require external shuffle service for dynamic allocation // Otherwise, we may lose shuffle files when killing executors From bd4a21b4820c4ebaf750131574a6b2eeea36907e Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Sun, 14 Jan 2018 02:28:57 +0800 Subject: [PATCH 084/774] [SPARK-23036][SQL][TEST] Add withGlobalTempView for testing ## What changes were proposed in this pull request? Add withGlobalTempView when create global temp view, like withTempView and withView. And correct some improper usage. Please see jira. There are other similar place like that. I will fix it if community need. Please confirm it. ## How was this patch tested? no new test. Author: xubo245 <601450868@qq.com> Closes #20228 from xubo245/DropTempView. --- .../sql/execution/GlobalTempViewSuite.scala | 55 ++++++++----------- .../spark/sql/execution/SQLViewSuite.scala | 34 +++++++----- .../apache/spark/sql/test/SQLTestUtils.scala | 21 +++++-- 3 files changed, 59 insertions(+), 51 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index cc943e0356f2a..dcc6fa6403f31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -36,7 +36,7 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { test("basic semantic") { val expectedErrorMsg = "not found" - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") // If there is no database in table name, we should try local temp view first, if not found, @@ -79,19 +79,15 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { // We can also use Dataset API to replace global temp view Seq(2 -> "b").toDF("i", "j").createOrReplaceGlobalTempView("src") checkAnswer(spark.table(s"$globalTempDB.src"), Row(2, "b")) - } finally { - spark.catalog.dropGlobalTempView("src") } } test("global temp view is shared among all sessions") { - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, 2)) val newSession = spark.newSession() checkAnswer(newSession.table(s"$globalTempDB.src"), Row(1, 2)) - } finally { - spark.catalog.dropGlobalTempView("src") } } @@ -105,27 +101,25 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { test("CREATE GLOBAL TEMP VIEW USING") { withTempPath { path => - try { + withGlobalTempView("src") { Seq(1 -> "a").toDF("i", "j").write.parquet(path.getAbsolutePath) sql(s"CREATE GLOBAL TEMP VIEW src USING parquet OPTIONS (PATH '${path.toURI}')") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) sql(s"INSERT INTO $globalTempDB.src SELECT 2, 'b'") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a") :: Row(2, "b") :: Nil) - } finally { - spark.catalog.dropGlobalTempView("src") } } } test("CREATE TABLE LIKE should work for global temp view") { - try { - sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") - sql(s"CREATE TABLE cloned LIKE $globalTempDB.src") - val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) - assert(tableMeta.schema == new StructType().add("a", "int", false).add("b", "string", false)) - } finally { - spark.catalog.dropGlobalTempView("src") - sql("DROP TABLE default.cloned") + withTable("cloned") { + withGlobalTempView("src") { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") + sql(s"CREATE TABLE cloned LIKE $globalTempDB.src") + val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) + assert(tableMeta.schema == new StructType() + .add("a", "int", false).add("b", "string", false)) + } } } @@ -146,26 +140,25 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { } test("should lookup global temp view if and only if global temp db is specified") { - try { - sql("CREATE GLOBAL TEMP VIEW same_name AS SELECT 3, 4") - sql("CREATE TEMP VIEW same_name AS SELECT 1, 2") + withTempView("same_name") { + withGlobalTempView("same_name") { + sql("CREATE GLOBAL TEMP VIEW same_name AS SELECT 3, 4") + sql("CREATE TEMP VIEW same_name AS SELECT 1, 2") - checkAnswer(sql("SELECT * FROM same_name"), Row(1, 2)) + checkAnswer(sql("SELECT * FROM same_name"), Row(1, 2)) - // we never lookup global temp views if database is not specified in table name - spark.catalog.dropTempView("same_name") - intercept[AnalysisException](sql("SELECT * FROM same_name")) + // we never lookup global temp views if database is not specified in table name + spark.catalog.dropTempView("same_name") + intercept[AnalysisException](sql("SELECT * FROM same_name")) - // Use qualified name to lookup a global temp view. - checkAnswer(sql(s"SELECT * FROM $globalTempDB.same_name"), Row(3, 4)) - } finally { - spark.catalog.dropTempView("same_name") - spark.catalog.dropGlobalTempView("same_name") + // Use qualified name to lookup a global temp view. + checkAnswer(sql(s"SELECT * FROM $globalTempDB.same_name"), Row(3, 4)) + } } } test("public Catalog should recognize global temp view") { - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") assert(spark.catalog.tableExists(globalTempDB, "src")) @@ -175,8 +168,6 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { description = null, tableType = "TEMPORARY", isTemporary = true).toString) - } finally { - spark.catalog.dropGlobalTempView("src") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 08a4a21b20f61..8c55758cfe38d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -69,21 +69,25 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("create a permanent view on a temp view") { - withView("jtv1", "temp_jtv1", "global_temp_jtv1") { - sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") - var e = intercept[AnalysisException] { - sql("CREATE VIEW jtv1 AS SELECT * FROM temp_jtv1 WHERE id < 6") - }.getMessage - assert(e.contains("Not allowed to create a permanent view `jtv1` by " + - "referencing a temporary view `temp_jtv1`")) - - val globalTempDB = spark.sharedState.globalTempViewManager.database - sql("CREATE GLOBAL TEMP VIEW global_temp_jtv1 AS SELECT * FROM jt WHERE id > 0") - e = intercept[AnalysisException] { - sql(s"CREATE VIEW jtv1 AS SELECT * FROM $globalTempDB.global_temp_jtv1 WHERE id < 6") - }.getMessage - assert(e.contains(s"Not allowed to create a permanent view `jtv1` by referencing " + - s"a temporary view `global_temp`.`global_temp_jtv1`")) + withView("jtv1") { + withTempView("temp_jtv1") { + withGlobalTempView("global_temp_jtv1") { + sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") + var e = intercept[AnalysisException] { + sql("CREATE VIEW jtv1 AS SELECT * FROM temp_jtv1 WHERE id < 6") + }.getMessage + assert(e.contains("Not allowed to create a permanent view `jtv1` by " + + "referencing a temporary view `temp_jtv1`")) + + val globalTempDB = spark.sharedState.globalTempViewManager.database + sql("CREATE GLOBAL TEMP VIEW global_temp_jtv1 AS SELECT * FROM jt WHERE id > 0") + e = intercept[AnalysisException] { + sql(s"CREATE VIEW jtv1 AS SELECT * FROM $globalTempDB.global_temp_jtv1 WHERE id < 6") + }.getMessage + assert(e.contains(s"Not allowed to create a permanent view `jtv1` by referencing " + + s"a temporary view `global_temp`.`global_temp_jtv1`")) + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 904f9f2ad0b22..bc4a120f7042f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -254,13 +254,26 @@ private[sql] trait SQLTestUtilsBase } /** - * Drops temporary table `tableName` after calling `f`. + * Drops temporary view `viewNames` after calling `f`. */ - protected def withTempView(tableNames: String*)(f: => Unit): Unit = { + protected def withTempView(viewNames: String*)(f: => Unit): Unit = { try f finally { // If the test failed part way, we don't want to mask the failure by failing to remove - // temp tables that never got created. - try tableNames.foreach(spark.catalog.dropTempView) catch { + // temp views that never got created. + try viewNames.foreach(spark.catalog.dropTempView) catch { + case _: NoSuchTableException => + } + } + } + + /** + * Drops global temporary view `viewNames` after calling `f`. + */ + protected def withGlobalTempView(viewNames: String*)(f: => Unit): Unit = { + try f finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // global temp views that never got created. + try viewNames.foreach(spark.catalog.dropGlobalTempView) catch { case _: NoSuchTableException => } } From ba891ec993c616dc4249fc786c56ea82ed04a827 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 14 Jan 2018 02:36:32 +0800 Subject: [PATCH 085/774] [SPARK-22790][SQL] add a configurable factor to describe HadoopFsRelation's size ## What changes were proposed in this pull request? as per discussion in https://github.com/apache/spark/pull/19864#discussion_r156847927 the current HadoopFsRelation is purely based on the underlying file size which is not accurate and makes the execution vulnerable to errors like OOM Users can enable CBO with the functionalities in https://github.com/apache/spark/pull/19864 to avoid this issue This JIRA proposes to add a configurable factor to sizeInBytes method in HadoopFsRelation class so that users can mitigate this problem without CBO ## How was this patch tested? Existing tests Author: CodingCat Author: Nan Zhu Closes #20072 from CodingCat/SPARK-22790. --- .../apache/spark/sql/internal/SQLConf.scala | 13 +++++- .../datasources/HadoopFsRelation.scala | 6 ++- .../datasources/HadoopFsRelationSuite.scala | 41 +++++++++++++++++++ 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 36e802a9faa6f..6746fbcaf2483 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -249,7 +249,7 @@ object SQLConf { val CONSTRAINT_PROPAGATION_ENABLED = buildConf("spark.sql.constraintPropagation.enabled") .internal() .doc("When true, the query optimizer will infer and propagate data constraints in the query " + - "plan to optimize them. Constraint propagation can sometimes be computationally expensive" + + "plan to optimize them. Constraint propagation can sometimes be computationally expensive " + "for certain kinds of query plans (such as those with a large number of predicates and " + "aliases) which might negatively impact overall runtime.") .booleanConf @@ -263,6 +263,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val FILE_COMRESSION_FACTOR = buildConf("spark.sql.sources.fileCompressionFactor") + .internal() + .doc("When estimating the output data size of a table scan, multiply the file size with this " + + "factor as the estimated data size, in case the data is compressed in the file and lead to" + + " a heavily underestimated result.") + .doubleConf + .checkValue(_ > 0, "the value of fileDataSizeFactor must be larger than 0") + .createWithDefault(1.0) + val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema") .doc("When true, the Parquet data source merges schemas collected from all data files, " + "otherwise the schema is picked from the summary file or a random data file " + @@ -1255,6 +1264,8 @@ class SQLConf extends Serializable with Logging { def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) + def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR) + def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 89d8a85a9cbd2..6b34638529770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -82,7 +82,11 @@ case class HadoopFsRelation( } } - override def sizeInBytes: Long = location.sizeInBytes + override def sizeInBytes: Long = { + val compressionFactor = sqlContext.conf.fileCompressionFactor + (location.sizeInBytes * compressionFactor).toLong + } + override def inputFiles: Array[String] = location.inputFiles } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index caf03885e3873..c1f2c18d1417d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources import java.io.{File, FilenameFilter} import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.test.SharedSQLContext class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { @@ -39,4 +40,44 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize)) } } + + test("SPARK-22790: spark.sql.sources.compressionFactor takes effect") { + import testImplicits._ + Seq(1.0, 0.5).foreach { compressionFactor => + withSQLConf("spark.sql.sources.fileCompressionFactor" -> compressionFactor.toString, + "spark.sql.autoBroadcastJoinThreshold" -> "400") { + withTempPath { workDir => + // the file size is 740 bytes + val workDirPath = workDir.getAbsolutePath + val data1 = Seq(100, 200, 300, 400).toDF("count") + data1.write.parquet(workDirPath + "/data1") + val df1FromFile = spark.read.parquet(workDirPath + "/data1") + val data2 = Seq(100, 200, 300, 400).toDF("count") + data2.write.parquet(workDirPath + "/data2") + val df2FromFile = spark.read.parquet(workDirPath + "/data2") + val joinedDF = df1FromFile.join(df2FromFile, Seq("count")) + if (compressionFactor == 0.5) { + val bJoinExec = joinedDF.queryExecution.executedPlan.collect { + case bJoin: BroadcastHashJoinExec => bJoin + } + assert(bJoinExec.nonEmpty) + val smJoinExec = joinedDF.queryExecution.executedPlan.collect { + case smJoin: SortMergeJoinExec => smJoin + } + assert(smJoinExec.isEmpty) + } else { + // compressionFactor is 1.0 + val bJoinExec = joinedDF.queryExecution.executedPlan.collect { + case bJoin: BroadcastHashJoinExec => bJoin + } + assert(bJoinExec.isEmpty) + val smJoinExec = joinedDF.queryExecution.executedPlan.collect { + case smJoin: SortMergeJoinExec => smJoin + } + assert(smJoinExec.nonEmpty) + } + } + } + } + } } From 0066d6f6fa604817468471832968d4339f71c5cb Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 14 Jan 2018 05:39:38 +0800 Subject: [PATCH 086/774] [SPARK-21213][SQL][FOLLOWUP] Use compatible types for comparisons in compareAndGetNewStats ## What changes were proposed in this pull request? This pr fixed code to compare values in `compareAndGetNewStats`. The test below fails in the current master; ``` val oldStats2 = CatalogStatistics(sizeInBytes = BigInt(Long.MaxValue) * 2) val newStats5 = CommandUtils.compareAndGetNewStats( Some(oldStats2), newTotalSize = BigInt(Long.MaxValue) * 2, None) assert(newStats5.isEmpty) ``` ## How was this patch tested? Added some tests in `CommandUtilsSuite`. Author: Takeshi Yamamuro Closes #20245 from maropu/SPARK-21213-FOLLOWUP. --- .../sql/execution/command/CommandUtils.scala | 4 +- .../execution/command/CommandUtilsSuite.scala | 56 +++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 1a0d67fc71fbc..c27048626c8eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -116,8 +116,8 @@ object CommandUtils extends Logging { oldStats: Option[CatalogStatistics], newTotalSize: BigInt, newRowCount: Option[BigInt]): Option[CatalogStatistics] = { - val oldTotalSize = oldStats.map(_.sizeInBytes.toLong).getOrElse(-1L) - val oldRowCount = oldStats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) + val oldTotalSize = oldStats.map(_.sizeInBytes).getOrElse(BigInt(-1)) + val oldRowCount = oldStats.flatMap(_.rowCount).getOrElse(BigInt(-1)) var newStats: Option[CatalogStatistics] = None if (newTotalSize >= 0 && newTotalSize != oldTotalSize) { newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala new file mode 100644 index 0000000000000..f3e15189a6418 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.catalog.CatalogStatistics + +class CommandUtilsSuite extends SparkFunSuite { + + test("Check if compareAndGetNewStats returns correct results") { + val oldStats1 = CatalogStatistics(sizeInBytes = 10, rowCount = Some(100)) + val newStats1 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = 10, newRowCount = Some(100)) + assert(newStats1.isEmpty) + val newStats2 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = -1, newRowCount = None) + assert(newStats2.isEmpty) + val newStats3 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = 20, newRowCount = Some(-1)) + assert(newStats3.isDefined) + newStats3.foreach { stat => + assert(stat.sizeInBytes === 20) + assert(stat.rowCount.isEmpty) + } + val newStats4 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = -1, newRowCount = Some(200)) + assert(newStats4.isDefined) + newStats4.foreach { stat => + assert(stat.sizeInBytes === 10) + assert(stat.rowCount.isDefined && stat.rowCount.get === 200) + } + } + + test("Check if compareAndGetNewStats can handle large values") { + // Tests for large values + val oldStats2 = CatalogStatistics(sizeInBytes = BigInt(Long.MaxValue) * 2) + val newStats5 = CommandUtils.compareAndGetNewStats( + Some(oldStats2), newTotalSize = BigInt(Long.MaxValue) * 2, None) + assert(newStats5.isEmpty) + } +} From afae8f2bc82597593595af68d1aa2d802210ea8b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 14 Jan 2018 11:26:49 +0900 Subject: [PATCH 087/774] [SPARK-22959][PYTHON] Configuration to select the modules for daemon and worker in PySpark ## What changes were proposed in this pull request? We are now forced to use `pyspark/daemon.py` and `pyspark/worker.py` in PySpark. This doesn't allow a custom modification for it (well, maybe we can still do this in a super hacky way though, for example, setting Python executable that has the custom modification). Because of this, for example, it's sometimes hard to debug what happens inside Python worker processes. This is actually related with [SPARK-7721](https://issues.apache.org/jira/browse/SPARK-7721) too as somehow Coverage is unable to detect the coverage from `os.fork`. If we have some custom fixes to force the coverage, it works fine. This is also related with [SPARK-20368](https://issues.apache.org/jira/browse/SPARK-20368). This JIRA describes Sentry support which (roughly) needs some changes within worker side. With this configuration advanced users will be able to do a lot of pluggable workarounds and we can meet such potential needs in the future. As an example, let's say if I configure the module `coverage_daemon` and had `coverage_daemon.py` in the python path: ```python import os from pyspark import daemon if "COVERAGE_PROCESS_START" in os.environ: from pyspark.worker import main def _cov_wrapped(*args, **kwargs): import coverage cov = coverage.coverage( config_file=os.environ["COVERAGE_PROCESS_START"]) cov.start() try: main(*args, **kwargs) finally: cov.stop() cov.save() daemon.worker_main = _cov_wrapped if __name__ == '__main__': daemon.manager() ``` I can track the coverages in worker side too. More importantly, we can leave the main code intact but allow some workarounds. ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #20151 from HyukjinKwon/configuration-daemon-worker. --- .../api/python/PythonWorkerFactory.scala | 41 +++++++++++++++---- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index f53c6178047f5..30976ac752a8a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -34,10 +34,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String import PythonWorkerFactory._ - // Because forking processes from Java is expensive, we prefer to launch a single Python daemon - // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently - // only works on UNIX-based systems now because it uses signals for child management, so we can - // also fall back to launching workers (pyspark/worker.py) directly. + // Because forking processes from Java is expensive, we prefer to launch a single Python daemon, + // pyspark/daemon.py (by default) and tell it to fork new workers for our tasks. This daemon + // currently only works on UNIX-based systems now because it uses signals for child management, + // so we can also fall back to launching workers, pyspark/worker.py (by default) directly. val useDaemon = { val useDaemonEnabled = SparkEnv.get.conf.getBoolean("spark.python.use.daemon", true) @@ -45,6 +45,28 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String !System.getProperty("os.name").startsWith("Windows") && useDaemonEnabled } + // WARN: Both configurations, 'spark.python.daemon.module' and 'spark.python.worker.module' are + // for very advanced users and they are experimental. This should be considered + // as expert-only option, and shouldn't be used before knowing what it means exactly. + + // This configuration indicates the module to run the daemon to execute its Python workers. + val daemonModule = SparkEnv.get.conf.getOption("spark.python.daemon.module").map { value => + logInfo( + s"Python daemon module in PySpark is set to [$value] in 'spark.python.daemon.module', " + + "using this to start the daemon up. Note that this configuration only has an effect when " + + "'spark.python.use.daemon' is enabled and the platform is not Windows.") + value + }.getOrElse("pyspark.daemon") + + // This configuration indicates the module to run each Python worker. + val workerModule = SparkEnv.get.conf.getOption("spark.python.worker.module").map { value => + logInfo( + s"Python worker module in PySpark is set to [$value] in 'spark.python.worker.module', " + + "using this to start the worker up. Note that this configuration only has an effect when " + + "'spark.python.use.daemon' is disabled or the platform is Windows.") + value + }.getOrElse("pyspark.worker") + var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 @@ -74,8 +96,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } /** - * Connect to a worker launched through pyspark/daemon.py, which forks python processes itself - * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems. + * Connect to a worker launched through pyspark/daemon.py (by default), which forks python + * processes itself to avoid the high cost of forking from Java. This currently only works + * on UNIX-based systems. */ private def createThroughDaemon(): Socket = { @@ -108,7 +131,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } /** - * Launch a worker by executing worker.py directly and telling it to connect to us. + * Launch a worker by executing worker.py (by default) directly and telling it to connect to us. */ private def createSimpleWorker(): Socket = { var serverSocket: ServerSocket = null @@ -116,7 +139,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule)) val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) @@ -159,7 +182,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.daemon")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", daemonModule)) val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) From c3548d11c3c57e8f2c6ebd9d2d6a3924ddcd3cba Mon Sep 17 00:00:00 2001 From: foxish Date: Sat, 13 Jan 2018 21:34:28 -0800 Subject: [PATCH 088/774] [SPARK-23063][K8S] K8s changes for publishing scripts (and a couple of other misses) ## What changes were proposed in this pull request? Including the `-Pkubernetes` flag in a few places it was missed. ## How was this patch tested? checkstyle, mima through manual tests. Author: foxish Closes #20256 from foxish/SPARK-23063. --- dev/create-release/release-build.sh | 4 ++-- dev/create-release/releaseutils.py | 2 ++ dev/deps/spark-deps-hadoop-2.6 | 11 +++++++++++ dev/deps/spark-deps-hadoop-2.7 | 11 +++++++++++ dev/lint-java | 2 +- dev/mima | 2 +- dev/scalastyle | 1 + dev/test-dependencies.sh | 2 +- 8 files changed, 30 insertions(+), 5 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index c71137468054f..a3579f21fc539 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -92,9 +92,9 @@ MVN="build/mvn --force" # Hive-specific profiles for some builds HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +PUBLISH_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" # Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pflume -Psparkr" +BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr" # Scala 2.11 only profiles for some builds SCALA_2_11_PROFILES="-Pkafka-0-8" # Scala 2.12 only profiles for some builds diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 730138195e5fe..32f6cbb29f0be 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -185,6 +185,8 @@ def get_commits(tag): "graphx": "GraphX", "input/output": CORE_COMPONENT, "java api": "Java API", + "k8s": "Kubernetes", + "kubernetes": "Kubernetes", "mesos": "Mesos", "ml": "MLlib", "mllib": "MLlib", diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 2a298769be44c..48e54568e6fc6 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -17,6 +17,7 @@ arpack_combined_all-0.1.jar arrow-format-0.8.0.jar arrow-memory-0.8.0.jar arrow-vector-0.8.0.jar +automaton-1.11-8.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -60,6 +61,7 @@ datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -91,8 +93,10 @@ jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar jackson-core-asl-1.9.13.jar jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar @@ -131,10 +135,13 @@ jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar +logging-interceptor-3.8.1.jar lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar @@ -147,6 +154,8 @@ minlog-1.3.0.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar objenesis-2.1.jar +okhttp-3.8.1.jar +okio-1.13.0.jar opencsv-2.3.jar orc-core-1.4.1-nohive.jar orc-mapreduce-1.4.1-nohive.jar @@ -171,6 +180,7 @@ scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar snappy-0.2.jar snappy-java-1.1.2.6.jar spire-macros_2.11-0.13.0.jar @@ -186,5 +196,6 @@ xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar +zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index abee326f283ab..1807a77900e52 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -17,6 +17,7 @@ arpack_combined_all-0.1.jar arrow-format-0.8.0.jar arrow-memory-0.8.0.jar arrow-vector-0.8.0.jar +automaton-1.11-8.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -60,6 +61,7 @@ datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -91,8 +93,10 @@ jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar jackson-core-asl-1.9.13.jar jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar @@ -132,10 +136,13 @@ jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar +logging-interceptor-3.8.1.jar lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar @@ -148,6 +155,8 @@ minlog-1.3.0.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar objenesis-2.1.jar +okhttp-3.8.1.jar +okio-1.13.0.jar opencsv-2.3.jar orc-core-1.4.1-nohive.jar orc-mapreduce-1.4.1-nohive.jar @@ -172,6 +181,7 @@ scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar snappy-0.2.jar snappy-java-1.1.2.6.jar spire-macros_2.11-0.13.0.jar @@ -187,5 +197,6 @@ xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar +zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/lint-java b/dev/lint-java index c2e80538ef2a5..1f0b0c8379ed0 100755 --- a/dev/lint-java +++ b/dev/lint-java @@ -20,7 +20,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" -ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pmesos -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) +ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pmesos -Pkubernetes -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) if test ! -z "$ERRORS"; then echo -e "Checkstyle checks failed at following occurrences:\n$ERRORS" diff --git a/dev/mima b/dev/mima index 1e3ca9700bc07..cd2694ff4d3de 100755 --- a/dev/mima +++ b/dev/mima @@ -24,7 +24,7 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" +SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" diff --git a/dev/scalastyle b/dev/scalastyle index 89ecc8abd6f8c..b8053df05fa2b 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -24,6 +24,7 @@ ERRORS=$(echo -e "q\n" \ -Pkinesis-asl \ -Pmesos \ -Pkafka-0-8 \ + -Pkubernetes \ -Pyarn \ -Pflume \ -Phive \ diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 58b295d4f6e00..3bf7618e1ea96 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -29,7 +29,7 @@ export LC_ALL=C # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script -HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Pflume -Phive" +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Pflume -Phive" MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 From 7a3d0aad2b89aef54f7dd580397302e9ff984d9d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 13 Jan 2018 23:26:12 -0800 Subject: [PATCH 089/774] [SPARK-23038][TEST] Update docker/spark-test (JDK/OS) ## What changes were proposed in this pull request? This PR aims to update the followings in `docker/spark-test`. - JDK7 -> JDK8 Spark 2.2+ supports JDK8 only. - Ubuntu 12.04.5 LTS(precise) -> Ubuntu 16.04.3 LTS(xeniel) The end of life of `precise` was April 28, 2017. ## How was this patch tested? Manual. * Master ``` $ cd external/docker $ ./build $ export SPARK_HOME=... $ docker run -v $SPARK_HOME:/opt/spark spark-test-master CONTAINER_IP=172.17.0.3 ... 18/01/11 06:50:25 INFO MasterWebUI: Bound MasterWebUI to 172.17.0.3, and started at http://172.17.0.3:8080 18/01/11 06:50:25 INFO Utils: Successfully started service on port 6066. 18/01/11 06:50:25 INFO StandaloneRestServer: Started REST server for submitting applications on port 6066 18/01/11 06:50:25 INFO Master: I have been elected leader! New state: ALIVE ``` * Slave ``` $ docker run -v $SPARK_HOME:/opt/spark spark-test-worker spark://172.17.0.3:7077 CONTAINER_IP=172.17.0.4 ... 18/01/11 06:51:54 INFO Worker: Successfully registered with master spark://172.17.0.3:7077 ``` After slave starts, master will show ``` 18/01/11 06:51:54 INFO Master: Registering worker 172.17.0.4:8888 with 4 cores, 1024.0 MB RAM ``` Author: Dongjoon Hyun Closes #20230 from dongjoon-hyun/SPARK-23038. --- external/docker/spark-test/base/Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/external/docker/spark-test/base/Dockerfile b/external/docker/spark-test/base/Dockerfile index 5a95a9387c310..c70cd71367679 100644 --- a/external/docker/spark-test/base/Dockerfile +++ b/external/docker/spark-test/base/Dockerfile @@ -15,14 +15,14 @@ # limitations under the License. # -FROM ubuntu:precise +FROM ubuntu:xenial # Upgrade package index -# install a few other useful packages plus Open Jdk 7 +# install a few other useful packages plus Open Jdk 8 # Remove unneeded /var/lib/apt/lists/* after install to reduce the # docker image size (by ~30MB) RUN apt-get update && \ - apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \ + apt-get install -y less openjdk-8-jre-headless iproute2 vim-tiny sudo openssh-server && \ rm -rf /var/lib/apt/lists/* ENV SCALA_VERSION 2.11.8 From 66738d29c59871b29d26fc3756772b95ef536248 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 14 Jan 2018 19:43:10 +0900 Subject: [PATCH 090/774] [SPARK-23069][DOCS][SPARKR] fix R doc for describe missing text ## What changes were proposed in this pull request? fix doc truncated ## How was this patch tested? manually Author: Felix Cheung Closes #20263 from felixcheung/r23docfix. --- R/pkg/R/DataFrame.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 9956f7eda91e6..6caa125e1e14a 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3054,10 +3054,10 @@ setMethod("describe", #' \item stddev #' \item min #' \item max -#' \item arbitrary approximate percentiles specified as a percentage (eg, "75%") +#' \item arbitrary approximate percentiles specified as a percentage (eg, "75\%") #' } #' If no statistics are given, this function computes count, mean, stddev, min, -#' approximate quartiles (percentiles at 25%, 50%, and 75%), and max. +#' approximate quartiles (percentiles at 25\%, 50\%, and 75\%), and max. #' This function is meant for exploratory data analysis, as we make no guarantee about the #' backward compatibility of the schema of the resulting Dataset. If you want to #' programmatically compute summary statistics, use the \code{agg} function instead. @@ -4019,9 +4019,9 @@ setMethod("broadcast", #' #' Spark will use this watermark for several purposes: #' \itemize{ -#' \item{-} To know when a given time window aggregation can be finalized and thus can be emitted +#' \item To know when a given time window aggregation can be finalized and thus can be emitted #' when using output modes that do not allow updates. -#' \item{-} To minimize the amount of state that we need to keep for on-going aggregations. +#' \item To minimize the amount of state that we need to keep for on-going aggregations. #' } #' The current watermark is computed by looking at the \code{MAX(eventTime)} seen across #' all of the partitions in the query minus a user specified \code{delayThreshold}. Due to the cost From 990f05c80347c6eec2ee06823cff587c9ea90b49 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 14 Jan 2018 22:26:21 +0800 Subject: [PATCH 091/774] [SPARK-23021][SQL] AnalysisBarrier should override innerChildren to print correct explain output ## What changes were proposed in this pull request? `AnalysisBarrier` in the current master cuts off explain results for parsed logical plans; ``` scala> Seq((1, 1)).toDF("a", "b").groupBy("a").count().sample(0.1).explain(true) == Parsed Logical Plan == Sample 0.0, 0.1, false, -7661439431999668039 +- AnalysisBarrier Aggregate [a#5], [a#5, count(1) AS count#14L] ``` To fix this, `AnalysisBarrier` needs to override `innerChildren` and this pr changed the output to; ``` == Parsed Logical Plan == Sample 0.0, 0.1, false, -5086223488015741426 +- AnalysisBarrier +- Aggregate [a#5], [a#5, count(1) AS count#14L] +- Project [_1#2 AS a#5, _2#3 AS b#6] +- LocalRelation [_1#2, _2#3] ``` ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #20247 from maropu/SPARK-23021-2. --- .../plans/logical/basicLogicalOperators.scala | 1 + .../sql/hive/execution/HiveExplainSuite.scala | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 95e099c340af1..a4fca790dd086 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -903,6 +903,7 @@ case class Deduplicate( * This analysis barrier will be removed at the end of analysis stage. */ case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { + override protected def innerChildren: Seq[LogicalPlan] = Seq(child) override def output: Seq[Attribute] = child.output override def isStreaming: Boolean = child.isStreaming override def doCanonicalize(): LogicalPlan = child.canonicalized diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index dfabf1ec2a22a..a4273de5fe260 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -171,4 +171,21 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("EXPLAIN EXTENDED CODEGEN SELECT 1") } } + + test("SPARK-23021 AnalysisBarrier should not cut off explain output for parsed logical plans") { + val df = Seq((1, 1)).toDF("a", "b").groupBy("a").count().limit(1) + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + df.explain(true) + } + assert(outputStream.toString.replaceAll("""#\d+""", "#0").contains( + s"""== Parsed Logical Plan == + |GlobalLimit 1 + |+- LocalLimit 1 + | +- AnalysisBarrier + | +- Aggregate [a#0], [a#0, count(1) AS count#0L] + | +- Project [_1#0 AS a#0, _2#0 AS b#0] + | +- LocalRelation [_1#0, _2#0] + |""".stripMargin)) + } } From 60eeecd7760aee6ce2fd207c83ae40054eadaf83 Mon Sep 17 00:00:00 2001 From: Sandor Murakozi Date: Sun, 14 Jan 2018 08:32:35 -0600 Subject: [PATCH 092/774] [SPARK-23051][CORE] Fix for broken job description in Spark UI ## What changes were proposed in this pull request? In 2.2, Spark UI displayed the stage description if the job description was not set. This functionality was broken, the GUI has shown no description in this case. In addition, the code uses jobName and jobDescription instead of stageName and stageDescription when JobTableRowData is created. In this PR the logic producing values for the job rows was modified to find the latest stage attempt for the job and use that as a fallback if job description was missing. StageName and stageDescription are also set using values from stage and jobName/description is used only as a fallback. ## How was this patch tested? Manual testing of the UI, using the code in the bug report. Author: Sandor Murakozi Closes #20251 from smurakozi/SPARK-23051. --- .../apache/spark/ui/jobs/AllJobsPage.scala | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 37e3b3b304a63..ff916bb6a5759 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -65,12 +65,10 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We }.map { job => val jobId = job.jobId val status = job.status - val displayJobDescription = - if (job.description.isEmpty) { - job.name - } else { - UIUtils.makeDescription(job.description.get, "", plainText = true).text - } + val jobDescription = store.lastStageAttempt(job.stageIds.max).description + val displayJobDescription = jobDescription + .map(UIUtils.makeDescription(_, "", plainText = true).text) + .getOrElse("") val submissionTime = job.submissionTime.get.getTime() val completionTime = job.completionTime.map(_.getTime()).getOrElse(System.currentTimeMillis()) val classNameByStatus = status match { @@ -429,20 +427,23 @@ private[ui] class JobDataSource( val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val submissionTime = jobData.submissionTime val formattedSubmissionTime = submissionTime.map(UIUtils.formatDate).getOrElse("Unknown") - val jobDescription = UIUtils.makeDescription(jobData.description.getOrElse(""), - basePath, plainText = false) + val lastStageAttempt = store.lastStageAttempt(jobData.stageIds.max) + val lastStageDescription = lastStageAttempt.description.getOrElse("") + + val formattedJobDescription = + UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) val detailUrl = "%s/jobs/job?id=%s".format(basePath, jobData.jobId) new JobTableRowData( jobData, - jobData.name, - jobData.description.getOrElse(jobData.name), + lastStageAttempt.name, + lastStageDescription, duration.getOrElse(-1), formattedDuration, submissionTime.map(_.getTime()).getOrElse(-1L), formattedSubmissionTime, - jobDescription, + formattedJobDescription, detailUrl ) } From 42a1a15d739890bdfbb367ef94198b19e98ffcb7 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Mon, 15 Jan 2018 02:02:49 +0800 Subject: [PATCH 093/774] [SPARK-22999][SQL] show databases like command' can remove the like keyword ## What changes were proposed in this pull request? SHOW DATABASES (LIKE pattern = STRING)? Can be like the back increase? When using this command, LIKE keyword can be removed. You can refer to the SHOW TABLES command, SHOW TABLES 'test *' and SHOW TABELS like 'test *' can be used. Similarly SHOW DATABASES 'test *' and SHOW DATABASES like 'test *' can be used. ## How was this patch tested? unit tests manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20194 from guoxiaolongzte/SPARK-22999. --- .../antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../org/apache/spark/sql/execution/command/DDLSuite.scala | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6daf01d98426c..39d5e4ed56628 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -141,7 +141,7 @@ statement (LIKE? pattern=STRING)? #showTables | SHOW TABLE EXTENDED ((FROM | IN) db=identifier)? LIKE pattern=STRING partitionSpec? #showTable - | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases + | SHOW DATABASES (LIKE? pattern=STRING)? #showDatabases | SHOW TBLPROPERTIES table=tableIdentifier ('(' key=tablePropertyKey ')')? #showTblProperties | SHOW COLUMNS (FROM | IN) tableIdentifier diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 591510c1d8283..2b4b7c137428a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -991,6 +991,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("SHOW DATABASES LIKE '*db1A'"), Row("showdb1a") :: Nil) + checkAnswer( + sql("SHOW DATABASES '*db1A'"), + Row("showdb1a") :: Nil) + checkAnswer( sql("SHOW DATABASES LIKE 'showdb1A'"), Row("showdb1a") :: Nil) From b98ffa4d6dabaf787177d3f14b200fc4b118c7ce Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 15 Jan 2018 10:55:21 +0800 Subject: [PATCH 094/774] [SPARK-23054][SQL] Fix incorrect results of casting UserDefinedType to String ## What changes were proposed in this pull request? This pr fixed the issue when casting `UserDefinedType`s into strings; ``` >>> from pyspark.ml.classification import MultilayerPerceptronClassifier >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(0.0, Vectors.dense([0.0, 0.0])), (1.0, Vectors.dense([0.0, 1.0]))], ["label", "features"]) >>> df.selectExpr("CAST(features AS STRING)").show(truncate = False) +-------------------------------------------+ |features | +-------------------------------------------+ |[6,1,0,0,2800000020,2,0,0,0] | |[6,1,0,0,2800000020,2,0,0,3ff0000000000000]| +-------------------------------------------+ ``` The root cause is that `Cast` handles input data as `UserDefinedType.sqlType`(this is underlying storage type), so we should pass data into `UserDefinedType.deserialize` then `toString`. This pr modified the result into; ``` +---------+ |features | +---------+ |[0.0,0.0]| |[0.0,1.0]| +---------+ ``` ## How was this patch tested? Added tests in `UserDefinedTypeSuite `. Author: Takeshi Yamamuro Closes #20246 from maropu/SPARK-23054. --- .../spark/sql/catalyst/expressions/Cast.scala | 7 +++++++ .../apache/spark/sql/UserDefinedTypeSuite.scala | 15 +++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f21aa1e9e3135..a95ebe301b9d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -282,6 +282,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case udt: UserDefinedType[_] => + buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString)) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -836,6 +838,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case udt: UserDefinedType[_] => + val udtRef = ctx.addReferenceObj("udt", udt) + (c, evPrim, evNull) => { + s"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index a08433ba794d9..cc8b600efa46a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -21,7 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.{Cast, ExpressionEvalHelper, GenericInternalRow, Literal} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ @@ -44,6 +44,8 @@ object UDT { case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data) case _ => false } + + override def toString: String = data.mkString("(", ", ", ")") } private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { @@ -143,7 +145,8 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] override def userClass: Class[IExampleSubType] = classOf[IExampleSubType] } -class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest + with ExpressionEvalHelper { import testImplicits._ private lazy val pointsRDD = Seq( @@ -304,4 +307,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT pointsRDD.except(pointsRDD2), Seq(Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) } + + test("SPARK-23054 Cast UserDefinedType to string") { + val udt = new UDT.MyDenseVectorUDT() + val vector = new UDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0)) + val data = udt.serialize(vector) + val ret = Cast(Literal(data, udt), StringType, None) + checkEvaluation(ret, "(1.0, 3.0, 5.0, 7.0, 9.0)") + } } From 9a96bfc8bf021cb4b6c62fac6ce1bcf87affcd43 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 15 Jan 2018 12:06:56 +0800 Subject: [PATCH 095/774] [SPARK-23049][SQL] `spark.sql.files.ignoreCorruptFiles` should work for ORC files ## What changes were proposed in this pull request? When `spark.sql.files.ignoreCorruptFiles=true`, we should ignore corrupted ORC files. ## How was this patch tested? Pass the Jenkins with a newly added test case. Author: Dongjoon Hyun Closes #20240 from dongjoon-hyun/SPARK-23049. --- .../execution/datasources/orc/OrcUtils.scala | 29 ++++++++---- .../datasources/orc/OrcQuerySuite.scala | 47 +++++++++++++++++++ .../parquet/ParquetQuerySuite.scala | 23 +++++++-- .../spark/sql/hive/orc/OrcFileFormat.scala | 8 +++- .../spark/sql/hive/orc/OrcFileOperator.scala | 28 +++++++++-- 5 files changed, 117 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 13a23996f4ade..460194ba61c8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.orc.{OrcFile, Reader, TypeDescription} +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession @@ -50,23 +51,35 @@ object OrcUtils extends Logging { paths } - def readSchema(file: Path, conf: Configuration): Option[TypeDescription] = { + def readSchema(file: Path, conf: Configuration, ignoreCorruptFiles: Boolean) + : Option[TypeDescription] = { val fs = file.getFileSystem(conf) val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val reader = OrcFile.createReader(file, readerOptions) - val schema = reader.getSchema - if (schema.getFieldNames.size == 0) { - None - } else { - Some(schema) + try { + val reader = OrcFile.createReader(file, readerOptions) + val schema = reader.getSchema + if (schema.getFieldNames.size == 0) { + None + } else { + Some(schema) + } + } catch { + case e: org.apache.orc.FileFormatException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $file", e) + None + } else { + throw new SparkException(s"Could not read footer for file: $file", e) + } } } def readSchema(sparkSession: SparkSession, files: Seq[FileStatus]) : Option[StructType] = { + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles val conf = sparkSession.sessionState.newHadoopConf() // TODO: We need to support merge schema. Please see SPARK-11412. - files.map(_.getPath).flatMap(readSchema(_, conf)).headOption.map { schema => + files.map(_.getPath).flatMap(readSchema(_, conf, ignoreCorruptFiles)).headOption.map { schema => logDebug(s"Reading schema from file $files, got Hive schema string: $schema") CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index e00e057a18cc6..f58c331f33ca8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -31,6 +31,7 @@ import org.apache.orc.OrcConf.COMPRESS import org.apache.orc.mapred.OrcStruct import org.apache.orc.mapreduce.OrcInputFormat +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator} @@ -531,6 +532,52 @@ abstract class OrcQueryTest extends OrcTest { val df = spark.read.orc(path1.getCanonicalPath, path2.getCanonicalPath) assert(df.count() == 20) } + + test("Enabling/disabling ignoreCorruptFiles") { + def testIgnoreCorruptFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.orc(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.orc(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer(df, Seq(Row(0), Row(1))) + } + } + + def testIgnoreCorruptFilesWithoutSchemaInfer(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.orc(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.orc(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.schema("a long").orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer(df, Seq(Row(0), Row(1))) + } + } + + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { + testIgnoreCorruptFiles() + testIgnoreCorruptFilesWithoutSchemaInfer() + } + + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { + val m1 = intercept[SparkException] { + testIgnoreCorruptFiles() + }.getMessage + assert(m1.contains("Could not read footer for file")) + val m2 = intercept[SparkException] { + testIgnoreCorruptFilesWithoutSchemaInfer() + }.getMessage + assert(m2.contains("Malformed ORC file")) + } + } } class OrcQuerySuite extends OrcQueryTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 4c8c9ef6e0432..6ad88ed997ce7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -320,14 +320,27 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext new Path(basePath, "first").toString, new Path(basePath, "second").toString, new Path(basePath, "third").toString) - checkAnswer( - df, - Seq(Row(0), Row(1))) + checkAnswer(df, Seq(Row(0), Row(1))) + } + } + + def testIgnoreCorruptFilesWithoutSchemaInfer(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.schema("a long").parquet( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer(df, Seq(Row(0), Row(1))) } } withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { testIgnoreCorruptFiles() + testIgnoreCorruptFilesWithoutSchemaInfer() } withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { @@ -335,6 +348,10 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext testIgnoreCorruptFiles() } assert(exception.getMessage().contains("is not a Parquet file")) + val exception2 = intercept[SparkException] { + testIgnoreCorruptFilesWithoutSchemaInfer() + } + assert(exception2.getMessage().contains("is not a Parquet file")) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 95741c7b30289..237ed9bc05988 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -59,9 +59,11 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles OrcFileOperator.readSchema( files.map(_.getPath.toString), - Some(sparkSession.sessionState.newHadoopConf()) + Some(sparkSession.sessionState.newHadoopConf()), + ignoreCorruptFiles ) } @@ -129,6 +131,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -138,7 +141,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file // using the given physical schema. Instead, we simply return an empty iterator. - val isEmptyFile = OrcFileOperator.readSchema(Seq(filePath.toString), Some(conf)).isEmpty + val isEmptyFile = + OrcFileOperator.readSchema(Seq(filePath.toString), Some(conf), ignoreCorruptFiles).isEmpty if (isEmptyFile) { Iterator.empty } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 5a3fcd7a759c0..80e44ca504356 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.hive.orc +import java.io.IOException + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -46,7 +49,10 @@ private[hive] object OrcFileOperator extends Logging { * create the result reader from that file. If no such file is found, it returns `None`. * @todo Needs to consider all files when schema evolution is taken into account. */ - def getFileReader(basePath: String, config: Option[Configuration] = None): Option[Reader] = { + def getFileReader(basePath: String, + config: Option[Configuration] = None, + ignoreCorruptFiles: Boolean = false) + : Option[Reader] = { def isWithNonEmptySchema(path: Path, reader: Reader): Boolean = { reader.getObjectInspector match { case oi: StructObjectInspector if oi.getAllStructFieldRefs.size() == 0 => @@ -65,16 +71,28 @@ private[hive] object OrcFileOperator extends Logging { } listOrcFiles(basePath, conf).iterator.map { path => - path -> OrcFile.createReader(fs, path) + val reader = try { + Some(OrcFile.createReader(fs, path)) + } catch { + case e: IOException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $path", e) + None + } else { + throw new SparkException(s"Could not read footer for file: $path", e) + } + } + path -> reader }.collectFirst { - case (path, reader) if isWithNonEmptySchema(path, reader) => reader + case (path, Some(reader)) if isWithNonEmptySchema(path, reader) => reader } } - def readSchema(paths: Seq[String], conf: Option[Configuration]): Option[StructType] = { + def readSchema(paths: Seq[String], conf: Option[Configuration], ignoreCorruptFiles: Boolean) + : Option[StructType] = { // Take the first file where we can open a valid reader if we can find one. Otherwise just // return None to indicate we can't infer the schema. - paths.flatMap(getFileReader(_, conf)).headOption.map { reader => + paths.flatMap(getFileReader(_, conf, ignoreCorruptFiles)).headOption.map { reader => val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] val schema = readerInspector.getTypeName logDebug(s"Reading schema from file $paths, got Hive schema string: $schema") From b59808385cfe24ce768e5b3098b9034e64b99a5a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 15 Jan 2018 16:26:52 +0800 Subject: [PATCH 096/774] [SPARK-23023][SQL] Cast field data to strings in showString ## What changes were proposed in this pull request? The current `Datset.showString` prints rows thru `RowEncoder` deserializers like; ``` scala> Seq(Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))).toDF("a").show(false) +------------------------------------------------------------+ |a | +------------------------------------------------------------+ |[WrappedArray(1, 2), WrappedArray(3), WrappedArray(4, 5, 6)]| +------------------------------------------------------------+ ``` This result is incorrect because the correct one is; ``` scala> Seq(Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))).toDF("a").show(false) +------------------------+ |a | +------------------------+ |[[1, 2], [3], [4, 5, 6]]| +------------------------+ ``` So, this pr fixed code in `showString` to cast field data to strings before printing. ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #20214 from maropu/SPARK-23023. --- python/pyspark/sql/functions.py | 32 +++++++++---------- .../scala/org/apache/spark/sql/Dataset.scala | 21 ++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 28 ++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 12 +++---- 4 files changed, 61 insertions(+), 32 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e1ad6590554cf..f7b3f29764040 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1849,14 +1849,14 @@ def explode_outer(col): +---+----------+----+-----+ >>> df.select("id", "a_map", explode_outer("an_array")).show() - +---+-------------+----+ - | id| a_map| col| - +---+-------------+----+ - | 1|Map(x -> 1.0)| foo| - | 1|Map(x -> 1.0)| bar| - | 2| Map()|null| - | 3| null|null| - +---+-------------+----+ + +---+----------+----+ + | id| a_map| col| + +---+----------+----+ + | 1|[x -> 1.0]| foo| + | 1|[x -> 1.0]| bar| + | 2| []|null| + | 3| null|null| + +---+----------+----+ """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.explode_outer(_to_java_column(col)) @@ -1881,14 +1881,14 @@ def posexplode_outer(col): | 3| null|null|null| null| +---+----------+----+----+-----+ >>> df.select("id", "a_map", posexplode_outer("an_array")).show() - +---+-------------+----+----+ - | id| a_map| pos| col| - +---+-------------+----+----+ - | 1|Map(x -> 1.0)| 0| foo| - | 1|Map(x -> 1.0)| 1| bar| - | 2| Map()|null|null| - | 3| null|null|null| - +---+-------------+----+----+ + +---+----------+----+----+ + | id| a_map| pos| col| + +---+----------+----+----+ + | 1|[x -> 1.0]| 0| foo| + | 1|[x -> 1.0]| 1| bar| + | 2| []|null|null| + | 3| null|null|null| + +---+----------+----+----+ """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.posexplode_outer(_to_java_column(col)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 77e571272920a..34f0ab5aa6699 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -237,13 +237,20 @@ class Dataset[T] private[sql]( private[sql] def showString( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { val numRows = _numRows.max(0).min(Int.MaxValue - 1) - val takeResult = toDF().take(numRows + 1) + val newDf = toDF() + val castCols = newDf.logicalPlan.output.map { col => + // Since binary types in top-level schema fields have a specific format to print, + // so we do not cast them to strings here. + if (col.dataType == BinaryType) { + Column(col) + } else { + Column(col).cast(StringType) + } + } + val takeResult = newDf.select(castCols: _*).take(numRows + 1) val hasMoreData = takeResult.length > numRows val data = takeResult.take(numRows) - lazy val timeZone = - DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone) - // For array values, replace Seq and Array with square brackets // For cells that are beyond `truncate` characters, replace it with the // first `truncate-3` and "..." @@ -252,12 +259,6 @@ class Dataset[T] private[sql]( val str = cell match { case null => "null" case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") - case array: Array[_] => array.mkString("[", ", ", "]") - case seq: Seq[_] => seq.mkString("[", ", ", "]") - case d: Date => - DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d)) - case ts: Timestamp => - DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(ts), timeZone) case _ => cell.toString } if (truncate > 0 && str.length > truncate) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5e4c1a6a484fb..33707080c1301 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1255,6 +1255,34 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(1, vertical = true) === expectedAnswer) } + test("SPARK-23023 Cast rows to strings in showString") { + val df1 = Seq(Seq(1, 2, 3, 4)).toDF("a") + assert(df1.showString(10) === + s"""+------------+ + || a| + |+------------+ + ||[1, 2, 3, 4]| + |+------------+ + |""".stripMargin) + val df2 = Seq(Map(1 -> "a", 2 -> "b")).toDF("a") + assert(df2.showString(10) === + s"""+----------------+ + || a| + |+----------------+ + ||[1 -> a, 2 -> b]| + |+----------------+ + |""".stripMargin) + val df3 = Seq(((1, "a"), 0), ((2, "b"), 0)).toDF("a", "b") + assert(df3.showString(10) === + s"""+------+---+ + || a| b| + |+------+---+ + ||[1, a]| 0| + ||[2, b]| 0| + |+------+---+ + |""".stripMargin) + } + test("SPARK-7327 show with empty dataFrame") { val expectedAnswer = """+---+-----+ ||key|value| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 54893c184642b..49c59cf695dc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -958,12 +958,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ).toDS() val expected = - """+-------+ - || f| - |+-------+ - ||[foo,1]| - ||[bar,2]| - |+-------+ + """+--------+ + || f| + |+--------+ + ||[foo, 1]| + ||[bar, 2]| + |+--------+ |""".stripMargin checkShowString(ds, expected) From a38c887ac093d7cf343d807515147d87ca931ce7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 15 Jan 2018 07:49:34 -0600 Subject: [PATCH 097/774] [SPARK-19550][BUILD][FOLLOW-UP] Remove MaxPermSize for sql module ## What changes were proposed in this pull request? Remove `MaxPermSize` for `sql` module ## How was this patch tested? Manually tested. Author: Yuming Wang Closes #20268 from wangyum/SPARK-19550-MaxPermSize. --- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 839b929abd3cb..7d23637e28342 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -134,7 +134,7 @@ org.scalatest scalatest-maven-plugin - -ea -Xmx4g -Xss4m -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 744daa6079779..ef41837f89d68 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -195,7 +195,7 @@ org.scalatest scalatest-maven-plugin - -ea -Xmx4g -Xss4m -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} From bd08a9e7af4137bddca638e627ad2ae531bce20f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 15 Jan 2018 22:32:38 +0800 Subject: [PATCH 098/774] [SPARK-23070] Bump previousSparkVersion in MimaBuild.scala to be 2.2.0 ## What changes were proposed in this pull request? Bump previousSparkVersion in MimaBuild.scala to be 2.2.0 and add the missing exclusions to `v23excludes` in `MimaExcludes`. No item can be un-excluded in `v23excludes`. ## How was this patch tested? The existing tests. Author: gatorsmile Closes #20264 from gatorsmile/bump22. --- project/MimaBuild.scala | 2 +- project/MimaExcludes.scala | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 2ef0e7b40d940..adde213e361f0 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -88,7 +88,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "2.0.0" + val previousSparkVersion = "2.2.0" val project = projectRef.project val fullId = "spark-" + project + "_2.11" mimaDefaultSettings ++ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 32eb31f495979..d35c50e1d00fe 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -102,7 +102,40 @@ object MimaExcludes { // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), - ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter") + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter"), + + // [SPARK-21728][CORE] Allow SparkSubmit to use Logging + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFileList"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFile"), + + // [SPARK-21714][CORE][YARN] Avoiding re-uploading remote resources in yarn client mode + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.prepareSubmitEnvironment"), + + // [SPARK-22324][SQL][PYTHON] Upgrade Arrow to 0.8.0 + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.network.util.AbstractFileRegion.transfered"), + + // [SPARK-20643][CORE] Add listener implementation to collect app state + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$5"), + + // [SPARK-20648][CORE] Port JobsTab and StageTab to the new UI backend + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$12"), + + // [SPARK-21462][SS] Added batchId to StreamingQueryProgress.json + // [SPARK-21409][SS] Expose state store memory usage in SQL metrics and progress updates + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.this"), + + // [SPARK-22278][SS] Expose current event time watermark and current processing time in GroupState + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentWatermarkMs"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentProcessingTimeMs"), + + // [SPARK-20542][ML][SQL] Add an API to Bucketizer that can bin multiple columns + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.org$apache$spark$ml$param$shared$HasOutputCols$_setter_$outputCols_="), + + // [SPARK-18619][ML] Make QuantileDiscretizer/Bucketizer/StringIndexer/RFormula inherit from HasHandleInvalid + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.Bucketizer.getHandleInvalid"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexer.getHandleInvalid"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.getHandleInvalid"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.getHandleInvalid") ) // Exclude rules for 2.2.x From 6c81fe227a6233f5d9665d2efadf8a1cf09f700d Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Mon, 15 Jan 2018 23:13:15 +0800 Subject: [PATCH 099/774] [SPARK-23035][SQL] Fix improper information of TempTableAlreadyExistsException ## What changes were proposed in this pull request? Problem: it throw TempTableAlreadyExistsException and output "Temporary table '$table' already exists" when we create temp view by using org.apache.spark.sql.catalyst.catalog.GlobalTempViewManager#create, it's improper. So fix improper information about TempTableAlreadyExistsException when create temp view: change "Temporary table" to "Temporary view" ## How was this patch tested? test("rename temporary view - destination table already exists, with: CREATE TEMPORARY view") test("rename temporary view - destination table with database name,with:CREATE TEMPORARY view") Author: xubo245 <601450868@qq.com> Closes #20227 from xubo245/fixDeprecated. --- .../analysis/AlreadyExistException.scala | 2 +- .../catalog/SessionCatalogSuite.scala | 6 +- .../spark/sql/execution/SQLViewSuite.scala | 2 +- .../sql/execution/command/DDLSuite.scala | 75 ++++++++++++++++++- 4 files changed, 78 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index 57f7a80bedc6c..6d587abd8fd4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -31,7 +31,7 @@ class TableAlreadyExistsException(db: String, table: String) extends AnalysisException(s"Table or view '$table' already exists in database '$db'") class TempTableAlreadyExistsException(table: String) - extends AnalysisException(s"Temporary table '$table' already exists") + extends AnalysisException(s"Temporary view '$table' already exists") class PartitionAlreadyExistsException(db: String, table: String, spec: TablePartitionSpec) extends AnalysisException( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 95c87ffa20cb7..6abab0073cca3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -279,7 +279,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { } } - test("create temp table") { + test("create temp view") { withBasicCatalog { catalog => val tempTable1 = Range(1, 10, 1, 10) val tempTable2 = Range(1, 20, 2, 10) @@ -288,11 +288,11 @@ abstract class SessionCatalogSuite extends AnalysisTest { assert(catalog.getTempView("tbl1") == Option(tempTable1)) assert(catalog.getTempView("tbl2") == Option(tempTable2)) assert(catalog.getTempView("tbl3").isEmpty) - // Temporary table already exists + // Temporary view already exists intercept[TempTableAlreadyExistsException] { catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) } - // Temporary table already exists but we override it + // Temporary view already exists but we override it catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) assert(catalog.getTempView("tbl1") == Option(tempTable2)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 8c55758cfe38d..14082197ba0bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -293,7 +293,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") } - assert(e.message.contains("Temporary table") && e.message.contains("already exists")) + assert(e.message.contains("Temporary view") && e.message.contains("already exists")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 2b4b7c137428a..6ca21b5aa1595 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -835,6 +835,31 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename temporary view - destination table with database name,with:CREATE TEMPORARY view") { + withTempView("view1") { + sql( + """ + |CREATE TEMPORARY VIEW view1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + val e = intercept[AnalysisException] { + sql("ALTER TABLE view1 RENAME TO default.tab2") + } + assert(e.getMessage.contains( + "RENAME TEMPORARY VIEW from '`view1`' to '`default`.`tab2`': " + + "cannot specify database name 'default' in the destination table")) + + val catalog = spark.sessionState.catalog + assert(catalog.listTables("default") == Seq(TableIdentifier("view1"))) + } + } + test("rename temporary view") { withTempView("tab1", "tab2") { spark.range(10).createOrReplaceTempView("tab1") @@ -883,6 +908,42 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename temporary view - destination table already exists, with: CREATE TEMPORARY view") { + withTempView("view1", "view2") { + sql( + """ + |CREATE TEMPORARY VIEW view1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + sql( + """ + |CREATE TEMPORARY VIEW view2 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + val e = intercept[AnalysisException] { + sql("ALTER TABLE view1 RENAME TO view2") + } + assert(e.getMessage.contains( + "RENAME TEMPORARY VIEW from '`view1`' to '`view2`': destination table already exists")) + + val catalog = spark.sessionState.catalog + assert(catalog.listTables("default") == + Seq(TableIdentifier("view1"), TableIdentifier("view2"))) + } + } + test("alter table: bucketing is not supported") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -1728,12 +1789,22 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("block creating duplicate temp table") { - withView("t_temp") { + withTempView("t_temp") { sql("CREATE TEMPORARY VIEW t_temp AS SELECT 1, 2") val e = intercept[TempTableAlreadyExistsException] { sql("CREATE TEMPORARY TABLE t_temp (c3 int, c4 string) USING JSON") }.getMessage - assert(e.contains("Temporary table 't_temp' already exists")) + assert(e.contains("Temporary view 't_temp' already exists")) + } + } + + test("block creating duplicate temp view") { + withTempView("t_temp") { + sql("CREATE TEMPORARY VIEW t_temp AS SELECT 1, 2") + val e = intercept[TempTableAlreadyExistsException] { + sql("CREATE TEMPORARY VIEW t_temp (c3 int, c4 string) USING JSON") + }.getMessage + assert(e.contains("Temporary view 't_temp' already exists")) } } From 8ab2d7ea99b2cff8b54b2cb3a1dbf7580845986a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 16 Jan 2018 11:47:42 +0900 Subject: [PATCH 100/774] [SPARK-23080][SQL] Improve error message for built-in functions ## What changes were proposed in this pull request? When a user puts the wrong number of parameters in a function, an AnalysisException is thrown. If the function is a UDF, he user is told how many parameters the function expected and how many he/she put. If the function, instead, is a built-in one, no information about the number of parameters expected and the actual one is provided. This can help in some cases, to debug the errors (eg. bad quotes escaping may lead to a different number of parameters than expected, etc. etc.) The PR adds the information about the number of parameters passed and the expected one, analogously to what happens for UDF. ## How was this patch tested? modified existing UT + manual test Author: Marco Gaido Closes #20271 from mgaido91/SPARK-23080. --- .../spark/sql/catalyst/analysis/FunctionRegistry.scala | 10 +++++++++- .../resources/sql-tests/results/json-functions.sql.out | 4 ++-- .../src/test/scala/org/apache/spark/sql/UDFSuite.scala | 4 ++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5ddb39822617d..747016beb06e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -526,7 +526,15 @@ object FunctionRegistry { // Otherwise, find a constructor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - throw new AnalysisException(s"Invalid number of arguments for function $name") + val validParametersCount = constructors.map(_.getParameterCount).distinct.sorted + val expectedNumberOfParameters = if (validParametersCount.length == 1) { + validParametersCount.head.toString + } else { + validParametersCount.init.mkString("one of ", ", ", " and ") + + validParametersCount.last + } + throw new AnalysisException(s"Invalid number of arguments for function $name. " + + s"Expected: $expectedNumberOfParameters; Found: ${params.length}") } Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { case Success(e) => e diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index d9dc728a18e8d..581dddc89d0bb 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -129,7 +129,7 @@ select to_json() struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function to_json; line 1 pos 7 +Invalid number of arguments for function to_json. Expected: one of 1, 2 and 3; Found: 0; line 1 pos 7 -- !query 13 @@ -225,7 +225,7 @@ select from_json() struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function from_json; line 1 pos 7 +Invalid number of arguments for function from_json. Expected: one of 2, 3 and 4; Found: 0; line 1 pos 7 -- !query 22 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index db37be68e42e6..af6a10b425b9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -80,7 +80,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } - assert(e.getMessage.contains("Invalid number of arguments for function substr")) + assert(e.getMessage.contains("Invalid number of arguments for function substr. Expected:")) } test("error reporting for incorrect number of arguments - udf") { @@ -89,7 +89,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { spark.udf.register("foo", (_: String).length) df.selectExpr("foo(2, 3, 4)") } - assert(e.getMessage.contains("Invalid number of arguments for function foo")) + assert(e.getMessage.contains("Invalid number of arguments for function foo. Expected:")) } test("error reporting for undefined functions") { From c7572b79da0a29e502890d7618eaf805a1c9f474 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 16 Jan 2018 11:20:18 +0800 Subject: [PATCH 101/774] [SPARK-23000] Use fully qualified table names in HiveMetastoreCatalogSuite ## What changes were proposed in this pull request? In another attempt to fix DataSourceWithHiveMetastoreCatalogSuite, this patch uses qualified table names (`default.t`) in the individual tests. ## How was this patch tested? N/A (Test Only Change) Author: Sameer Agarwal Closes #20273 from sameeragarwal/flaky-test. --- .../sql/hive/HiveMetastoreCatalogSuite.scala | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index ba9b944e4a055..83b4c862e2546 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -166,13 +166,13 @@ class DataSourceWithHiveMetastoreCatalogSuite )) ).foreach { case (provider, (inputFormat, outputFormat, serde)) => test(s"Persist non-partitioned $provider relation into metastore as managed table") { - withTable("t") { + withTable("default.t") { withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { testDF .write .mode(SaveMode.Overwrite) .format(provider) - .saveAsTable("t") + .saveAsTable("default.t") } val hiveTable = sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) @@ -187,14 +187,15 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) - checkAnswer(table("t"), testDF) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + checkAnswer(table("default.t"), testDF) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === + Seq("1.1\t1", "2.1\t2")) } } test(s"Persist non-partitioned $provider relation into metastore as external table") { withTempPath { dir => - withTable("t") { + withTable("default.t") { val path = dir.getCanonicalFile withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { @@ -203,7 +204,7 @@ class DataSourceWithHiveMetastoreCatalogSuite .mode(SaveMode.Overwrite) .format(provider) .option("path", path.toString) - .saveAsTable("t") + .saveAsTable("default.t") } val hiveTable = @@ -219,8 +220,8 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) - checkAnswer(table("t"), testDF) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === + checkAnswer(table("default.t"), testDF) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === Seq("1.1\t1", "2.1\t2")) } } @@ -228,9 +229,9 @@ class DataSourceWithHiveMetastoreCatalogSuite test(s"Persist non-partitioned $provider relation into metastore as managed table using CTAS") { withTempPath { dir => - withTable("t") { + withTable("default.t") { sql( - s"""CREATE TABLE t USING $provider + s"""CREATE TABLE default.t USING $provider |OPTIONS (path '${dir.toURI}') |AS SELECT 1 AS d1, "val_1" AS d2 """.stripMargin) @@ -248,8 +249,9 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(IntegerType, StringType)) - checkAnswer(table("t"), Row(1, "val_1")) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + checkAnswer(table("default.t"), Row(1, "val_1")) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === + Seq("1\tval_1")) } } } From 07ae39d0ec1f03b1c73259373a8bb599694c7860 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 15 Jan 2018 22:01:14 -0800 Subject: [PATCH 102/774] [SPARK-22956][SS] Bug fix for 2 streams union failover scenario ## What changes were proposed in this pull request? This problem reported by yanlin-Lynn ivoson and LiangchangZ. Thanks! When we union 2 streams from kafka or other sources, while one of them have no continues data coming and in the same time task restart, this will cause an `IllegalStateException`. This mainly cause because the code in [MicroBatchExecution](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala#L190) , while one stream has no continues data, its comittedOffset same with availableOffset during `populateStartOffsets`, and `currentPartitionOffsets` not properly handled in KafkaSource. Also, maybe we should also consider this scenario in other Source. ## How was this patch tested? Add a UT in KafkaSourceSuite.scala Author: Yuanjian Li Closes #20150 from xuanyuanking/SPARK-22956. --- .../spark/sql/kafka010/KafkaSource.scala | 13 ++-- .../spark/sql/kafka010/KafkaSourceSuite.scala | 65 +++++++++++++++++++ .../streaming/MicroBatchExecution.scala | 6 +- .../sql/execution/streaming/memory.scala | 6 ++ 4 files changed, 81 insertions(+), 9 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index e9cff04ba5f2e..864a92b8f813f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -223,6 +223,14 @@ private[kafka010] class KafkaSource( logInfo(s"GetBatch called with start = $start, end = $end") val untilPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(end) + // On recovery, getBatch will get called before getOffset + if (currentPartitionOffsets.isEmpty) { + currentPartitionOffsets = Some(untilPartitionOffsets) + } + if (start.isDefined && start.get == end) { + return sqlContext.internalCreateDataFrame( + sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + } val fromPartitionOffsets = start match { case Some(prevBatchEndOffset) => KafkaSourceOffset.getPartitionOffsets(prevBatchEndOffset) @@ -305,11 +313,6 @@ private[kafka010] class KafkaSource( logInfo("GetBatch generating RDD of offset range: " + offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) - // On recovery, getBatch will get called before getOffset - if (currentPartitionOffsets.isEmpty) { - currentPartitionOffsets = Some(untilPartitionOffsets) - } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 2034b9be07f24..a0f5695fc485c 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -318,6 +318,71 @@ class KafkaSourceSuite extends KafkaSourceTest { ) } + test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") { + def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, range.map(_.toString).toArray, Some(0)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 5) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + + reader.load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(k => k.toInt) + } + + val df1 = getSpecificDF(0 to 9) + val df2 = getSpecificDF(100 to 199) + + val kafka = df1.union(df2) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(kafka)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 5 from smaller topic, 5 from bigger one + CheckLastBatch((0 to 4) ++ (100 to 104): _*), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // 5 from smaller topic, 5 from bigger one + CheckLastBatch((5 to 9) ++ (105 to 109): _*), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smaller topic empty, 5 from bigger one + CheckLastBatch(110 to 114: _*), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 5 from bigger one + CheckLastBatch(115 to 119: _*), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 5 from bigger one + CheckLastBatch(120 to 124: _*) + ) + } + test("cannot stop Kafka stream") { val topic = newTopic() testUtils.createTopic(topic, partitions = 5) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 42240eeb58d4b..70407f0580f97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -208,10 +208,8 @@ class MicroBatchExecution( * batch will be executed before getOffset is called again. */ availableOffsets.foreach { case (source: Source, end: Offset) => - if (committedOffsets.get(source).map(_ != end).getOrElse(true)) { - val start = committedOffsets.get(source) - source.getBatch(start, end) - } + val start = committedOffsets.get(source) + source.getBatch(start, end) case nonV1Tuple => // The V2 API does not have the same edge case requiring getBatch to be called // here, so we do nothing here. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 3041d4d703cb4..509a69dd922fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -119,9 +119,15 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) val newBlocks = synchronized { val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd") batches.slice(sliceStart, sliceEnd) } + if (newBlocks.isEmpty) { + return sqlContext.internalCreateDataFrame( + sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + } + logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal)) newBlocks From 66217dac4f8952a9923625908ad3dcb030763c81 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 15 Jan 2018 22:40:44 -0800 Subject: [PATCH 103/774] [SPARK-23020][CORE] Fix races in launcher code, test. The race in the code is because the handle might update its state to the wrong state if the connection handling thread is still processing incoming data; so the handle needs to wait for the connection to finish up before checking the final state. The race in the test is because when waiting for a handle to reach a final state, the waitFor() method needs to wait until all handle state is updated (which also includes waiting for the connection thread above to finish). Otherwise, waitFor() may return too early, which would cause a bunch of different races (like the listener not being yet notified of the state change, or being in the middle of being notified, or the handle not being properly disposed and causing postChecks() to assert). On top of that I found, by code inspection, a couple of potential races that could make a handle end up in the wrong state when being killed. Tested by running the existing unit tests a lot (and not seeing the errors I was seeing before). Author: Marcelo Vanzin Closes #20223 from vanzin/SPARK-23020. --- .../spark/launcher/SparkLauncherSuite.java | 49 ++++++++++++------- .../spark/launcher/AbstractAppHandle.java | 22 +++++++-- .../spark/launcher/ChildProcAppHandle.java | 18 ++++--- .../spark/launcher/InProcessAppHandle.java | 17 ++++--- .../spark/launcher/LauncherConnection.java | 14 +++--- .../apache/spark/launcher/LauncherServer.java | 46 ++++++++++++++--- .../org/apache/spark/launcher/BaseSuite.java | 42 +++++++++++++--- .../spark/launcher/LauncherServerSuite.java | 20 +++----- 8 files changed, 156 insertions(+), 72 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 9d2f563b2e367..a042375c6ae91 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.Arrays; import java.util.ArrayList; import java.util.HashMap; @@ -31,6 +32,7 @@ import static org.mockito.Mockito.*; import org.apache.spark.SparkContext; +import org.apache.spark.SparkContext$; import org.apache.spark.internal.config.package$; import org.apache.spark.util.Utils; @@ -137,7 +139,9 @@ public void testInProcessLauncher() throws Exception { // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. // See SPARK-23019 and SparkContext.stop() for details. - TimeUnit.MILLISECONDS.sleep(500); + eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { + assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty()); + }); } } @@ -146,26 +150,35 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle.Listener listener = mock(SparkAppHandle.Listener.class); doAnswer(invocation -> { SparkAppHandle h = (SparkAppHandle) invocation.getArguments()[0]; - transitions.add(h.getState()); + synchronized (transitions) { + transitions.add(h.getState()); + } return null; }).when(listener).stateChanged(any(SparkAppHandle.class)); - SparkAppHandle handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); - - waitFor(handle); - assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); - - // Matches the behavior of LocalSchedulerBackend. - List expected = Arrays.asList( - SparkAppHandle.State.CONNECTED, - SparkAppHandle.State.RUNNING, - SparkAppHandle.State.FINISHED); - assertEquals(expected, transitions); + SparkAppHandle handle = null; + try { + handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + waitFor(handle); + assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); + + // Matches the behavior of LocalSchedulerBackend. + List expected = Arrays.asList( + SparkAppHandle.State.CONNECTED, + SparkAppHandle.State.RUNNING, + SparkAppHandle.State.FINISHED); + assertEquals(expected, transitions); + } finally { + if (handle != null) { + handle.kill(); + } + } } public static class SparkLauncherTestApp { diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index df1e7316861d4..daf0972f824dd 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -33,7 +33,7 @@ abstract class AbstractAppHandle implements SparkAppHandle { private List listeners; private State state; private String appId; - private boolean disposed; + private volatile boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; @@ -70,8 +70,7 @@ public void stop() { @Override public synchronized void disconnect() { - if (!disposed) { - disposed = true; + if (!isDisposed()) { if (connection != null) { try { connection.close(); @@ -79,7 +78,7 @@ public synchronized void disconnect() { // no-op. } } - server.unregister(this); + dispose(); } } @@ -95,6 +94,21 @@ boolean isDisposed() { return disposed; } + /** + * Mark the handle as disposed, and set it as LOST in case the current state is not final. + */ + synchronized void dispose() { + if (!isDisposed()) { + // Unregister first to make sure that the connection with the app has been really + // terminated. + server.unregister(this); + if (!getState().isFinal()) { + setState(State.LOST); + } + this.disposed = true; + } + } + void setState(State s) { setState(s, false); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 8b3f427b7750e..2b99461652e1f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -48,14 +48,16 @@ public synchronized void disconnect() { @Override public synchronized void kill() { - disconnect(); - if (childProc != null) { - if (childProc.isAlive()) { - childProc.destroyForcibly(); + if (!isDisposed()) { + setState(State.KILLED); + disconnect(); + if (childProc != null) { + if (childProc.isAlive()) { + childProc.destroyForcibly(); + } + childProc = null; } - childProc = null; } - setState(State.KILLED); } void setChildProc(Process childProc, String loggerName, InputStream logStream) { @@ -94,8 +96,6 @@ void monitorChild() { return; } - disconnect(); - int ec; try { ec = proc.exitValue(); @@ -118,6 +118,8 @@ void monitorChild() { if (newState != null) { setState(newState, true); } + + disconnect(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index acd64c962604f..f04263cb74a58 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -39,15 +39,16 @@ class InProcessAppHandle extends AbstractAppHandle { @Override public synchronized void kill() { - LOG.warning("kill() may leave the underlying app running in in-process mode."); - disconnect(); - - // Interrupt the thread. This is not guaranteed to kill the app, though. - if (app != null) { - app.interrupt(); + if (!isDisposed()) { + LOG.warning("kill() may leave the underlying app running in in-process mode."); + setState(State.KILLED); + disconnect(); + + // Interrupt the thread. This is not guaranteed to kill the app, though. + if (app != null) { + app.interrupt(); + } } - - setState(State.KILLED); } synchronized void start(String appName, Method main, String[] args) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java index b4a8719e26053..fd6f229b2349c 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -95,15 +95,15 @@ protected synchronized void send(Message msg) throws IOException { } @Override - public void close() throws IOException { + public synchronized void close() throws IOException { if (!closed) { - synchronized (this) { - if (!closed) { - closed = true; - socket.close(); - } - } + closed = true; + socket.close(); } } + boolean isOpen() { + return !closed; + } + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index b8999a1d7a4f4..660c4443b20b9 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -217,6 +217,33 @@ void unregister(AbstractAppHandle handle) { break; } } + + // If there is a live connection for this handle, we need to wait for it to finish before + // returning, otherwise there might be a race between the connection thread processing + // buffered data and the handle cleaning up after itself, leading to potentially the wrong + // state being reported for the handle. + ServerConnection conn = null; + synchronized (clients) { + for (ServerConnection c : clients) { + if (c.handle == handle) { + conn = c; + break; + } + } + } + + if (conn != null) { + synchronized (conn) { + if (conn.isOpen()) { + try { + conn.wait(); + } catch (InterruptedException ie) { + // Ignore. + } + } + } + } + unref(); } @@ -288,7 +315,7 @@ private String createSecret() { private class ServerConnection extends LauncherConnection { private TimerTask timeout; - private AbstractAppHandle handle; + volatile AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); @@ -338,16 +365,21 @@ protected void handle(Message msg) throws IOException { @Override public void close() throws IOException { + if (!isOpen()) { + return; + } + synchronized (clients) { clients.remove(this); } - super.close(); + + synchronized (this) { + super.close(); + notifyAll(); + } + if (handle != null) { - if (!handle.getState().isFinal()) { - LOG.log(Level.WARNING, "Lost connection to spark application."); - handle.setState(SparkAppHandle.State.LOST); - } - handle.disconnect(); + handle.dispose(); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3e1a90eae98d4..3722a59d9438e 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.concurrent.TimeUnit; import org.junit.After; @@ -47,19 +48,46 @@ public void postChecks() { assertNull(server); } - protected void waitFor(SparkAppHandle handle) throws Exception { - long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + protected void waitFor(final SparkAppHandle handle) throws Exception { try { - while (!handle.getState().isFinal()) { - assertTrue("Timed out waiting for handle to transition to final state.", - System.nanoTime() < deadline); - TimeUnit.MILLISECONDS.sleep(10); - } + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is not in final state.", handle.getState().isFinal()); + }); } finally { if (!handle.getState().isFinal()) { handle.kill(); } } + + // Wait until the handle has been marked as disposed, to make sure all cleanup tasks + // have been performed. + AbstractAppHandle ahandle = (AbstractAppHandle) handle; + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is still not marked as disposed.", ahandle.isDisposed()); + }); + } + + /** + * Call a closure that performs a check every "period" until it succeeds, or the timeout + * elapses. + */ + protected void eventually(Duration timeout, Duration period, Runnable check) throws Exception { + assertTrue("Timeout needs to be larger than period.", timeout.compareTo(period) > 0); + long deadline = System.nanoTime() + timeout.toNanos(); + int count = 0; + while (true) { + try { + count++; + check.run(); + return; + } catch (Throwable t) { + if (System.nanoTime() >= deadline) { + String msg = String.format("Failed check after %d tries: %s.", count, t.getMessage()); + throw new IllegalStateException(msg, t); + } + Thread.sleep(period.toMillis()); + } + } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 7e2b09ce25c9b..75c1af0c71e2a 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -23,12 +23,14 @@ import java.net.InetAddress; import java.net.Socket; import java.net.SocketException; +import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import static org.junit.Assert.*; @@ -197,28 +199,20 @@ private void close(Closeable c) { * server-side close immediately. */ private void waitForError(TestClient client, String secret) throws Exception { - boolean helloSent = false; - int maxTries = 10; - for (int i = 0; i < maxTries; i++) { + final AtomicBoolean helloSent = new AtomicBoolean(); + eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { try { - if (!helloSent) { + if (!helloSent.get()) { client.send(new Hello(secret, "1.4.0")); - helloSent = true; + helloSent.set(true); } else { client.send(new SetAppId("appId")); } fail("Expected error but message went through."); } catch (IllegalStateException | IOException e) { // Expected. - break; - } catch (AssertionError e) { - if (i < maxTries - 1) { - Thread.sleep(100); - } else { - throw new AssertionError("Test failed after " + maxTries + " attempts.", e); - } } - } + }); } private static class TestClient extends LauncherConnection { From b85eb946ac298e711dad25db0d04eee41d7fd236 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 16 Jan 2018 20:20:33 +0900 Subject: [PATCH 104/774] [SPARK-22978][PYSPARK] Register Vectorized UDFs for SQL Statement ## What changes were proposed in this pull request? Register Vectorized UDFs for SQL Statement. For example, ```Python >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> pandas_udf("integer", PandasUDFType.SCALAR) ... def add_one(x): ... return x + 1 ... >>> _ = spark.udf.register("add_one", add_one) >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] ``` ## How was this patch tested? Added test cases Author: gatorsmile Closes #20171 from gatorsmile/supportVectorizedUDF. --- python/pyspark/sql/catalog.py | 75 ++++++++++++++++++++++++---------- python/pyspark/sql/context.py | 51 ++++++++++++++++------- python/pyspark/sql/tests.py | 76 ++++++++++++++++++++++++++++++----- 3 files changed, 155 insertions(+), 47 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 156603128d063..35fbe9e669adb 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -226,18 +226,23 @@ def dropGlobalTempView(self, viewName): @ignore_unicode_prefix @since(2.0) - def registerFunction(self, name, f, returnType=StringType()): + def registerFunction(self, name, f, returnType=None): """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statement. + as a UDF. The registered UDF can be used in SQL statements. - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. + :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`. - :param name: name of the UDF - :param f: a Python function, or a wrapped/native UserDefinedFunction - :param returnType: a :class:`pyspark.sql.types.DataType` object - :return: a wrapped :class:`UserDefinedFunction` + In addition to a name and the function itself, `returnType` can be optionally specified. + 1) When f is a Python function, `returnType` defaults to a string. The produced object must + match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return + type of the given UDF as the return type of the registered UDF. The input parameter + `returnType` is None by default. If given by users, the value must be None. + + :param name: name of the UDF in SQL statements. + :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either + row-at-a-time or vectorized. + :param returnType: the return type of the registered UDF. + :return: a wrapped/native :class:`UserDefinedFunction` >>> strlen = spark.catalog.registerFunction("stringLengthString", len) >>> spark.sql("SELECT stringLengthString('test')").collect() @@ -256,27 +261,55 @@ def registerFunction(self, name, f, returnType=StringType()): >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = spark.udf.register("slen", slen) + >>> spark.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + >>> import random >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType, StringType + >>> from pyspark.sql.types import IntegerType >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType()) + >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf) >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=u'82')] - >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP - [Row(random_udf()=u'62')] + [Row(random_udf()=82)] + >>> spark.range(1).select(new_random_udf()).collect() # doctest: +SKIP + [Row(()=26)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP + >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] """ # This is to check whether the input function is a wrapped/native UserDefinedFunction if hasattr(f, 'asNondeterministic'): - udf = UserDefinedFunction(f.func, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF, - deterministic=f.deterministic) + if returnType is not None: + raise TypeError( + "Invalid returnType: None is expected when f is a UserDefinedFunction, " + "but got %s." % returnType) + if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF]: + raise ValueError( + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") + register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, + evalType=f.evalType, + deterministic=f.deterministic) + return_udf = f else: - udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) - self._jsparkSession.udf().registerPython(name, udf._judf) - return udf._wrapped() + if returnType is None: + returnType = StringType() + register_udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) + return_udf = register_udf._wrapped() + self._jsparkSession.udf().registerPython(name, register_udf._judf) + return return_udf @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b8d86cc098e94..85479095af594 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -174,18 +174,23 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) - def registerFunction(self, name, f, returnType=StringType()): + def registerFunction(self, name, f, returnType=None): """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statement. + as a UDF. The registered UDF can be used in SQL statements. - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. + :func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`. - :param name: name of the UDF - :param f: a Python function, or a wrapped/native UserDefinedFunction - :param returnType: a :class:`pyspark.sql.types.DataType` object - :return: a wrapped :class:`UserDefinedFunction` + In addition to a name and the function itself, `returnType` can be optionally specified. + 1) When f is a Python function, `returnType` defaults to a string. The produced object must + match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return + type of the given UDF as the return type of the registered UDF. The input parameter + `returnType` is None by default. If given by users, the value must be None. + + :param name: name of the UDF in SQL statements. + :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either + row-at-a-time or vectorized. + :param returnType: the return type of the registered UDF. + :return: a wrapped/native :class:`UserDefinedFunction` >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() @@ -204,15 +209,31 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = sqlContext.udf.register("slen", slen) + >>> sqlContext.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + >>> import random >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType, StringType + >>> from pyspark.sql.types import IntegerType >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType()) + >>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf) >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=u'82')] - >>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP - [Row(random_udf()=u'62')] + [Row(random_udf()=82)] + >>> sqlContext.range(1).select(new_random_udf()).collect() # doctest: +SKIP + [Row(()=26)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = sqlContext.udf.register("add_one", add_one) # doctest: +SKIP + >>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] """ return self.sparkSession.catalog.registerFunction(name, f, returnType) @@ -575,7 +596,7 @@ class UDFRegistration(object): def __init__(self, sqlContext): self.sqlContext = sqlContext - def register(self, name, f, returnType=StringType()): + def register(self, name, f, returnType=None): return self.sqlContext.registerFunction(name, f, returnType) def registerJavaFunction(self, name, javaClassName, returnType=None): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 80a94a91a87b3..8906618666b14 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -380,12 +380,25 @@ def test_udf2(self): self.assertEqual(4, res[0]) def test_udf3(self): - twoargs = self.spark.catalog.registerFunction( - "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType()) - self.assertEqual(twoargs.deterministic, True) + two_args = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)) + self.assertEqual(two_args.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], u'5') + + def test_udf_registration_return_type_none(self): + two_args = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None) + self.assertEqual(two_args.deterministic, True) [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + def test_udf_registration_return_type_not_none(self): + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, "Invalid returnType"): + self.spark.catalog.registerFunction( + "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType()) + def test_nondeterministic_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations from pyspark.sql.functions import udf @@ -402,12 +415,12 @@ def test_nondeterministic_udf2(self): from pyspark.sql.functions import udf random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() self.assertEqual(random_udf.deterministic, False) - random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType()) + random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf) self.assertEqual(random_udf1.deterministic, False) [row] = self.spark.sql("SELECT randInt()").collect() - self.assertEqual(row[0], "6") + self.assertEqual(row[0], 6) [row] = self.spark.range(1).select(random_udf1()).collect() - self.assertEqual(row[0], "6") + self.assertEqual(row[0], 6) [row] = self.spark.range(1).select(random_udf()).collect() self.assertEqual(row[0], 6) # render_doc() reproduces the help() exception without printing output @@ -3691,7 +3704,7 @@ def tearDownClass(cls): ReusedSQLTestCase.tearDownClass() @property - def random_udf(self): + def nondeterministic_vectorized_udf(self): from pyspark.sql.functions import pandas_udf @pandas_udf('double') @@ -3726,6 +3739,21 @@ def test_vectorized_udf_basic(self): bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) + def test_register_nondeterministic_vectorized_udf_basic(self): + from pyspark.sql.functions import pandas_udf + from pyspark.rdd import PythonEvalType + import random + random_pandas_udf = pandas_udf( + lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic() + self.assertEqual(random_pandas_udf.deterministic, False) + self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + nondeterministic_pandas_udf = self.spark.catalog.registerFunction( + "randomPandasUDF", random_pandas_udf) + self.assertEqual(nondeterministic_pandas_udf.deterministic, False) + self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect() + self.assertEqual(row[0], 7) + def test_vectorized_udf_null_boolean(self): from pyspark.sql.functions import pandas_udf, col data = [(True,), (True,), (None,), (False,)] @@ -4085,14 +4113,14 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) - def test_nondeterministic_udf(self): + def test_nondeterministic_vectorized_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations from pyspark.sql.functions import udf, pandas_udf, col @pandas_udf('double') def plus_ten(v): return v + 10 - random_udf = self.random_udf + random_udf = self.nondeterministic_vectorized_udf df = self.spark.range(10).withColumn('rand', random_udf(col('id'))) result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas() @@ -4100,11 +4128,11 @@ def plus_ten(v): self.assertEqual(random_udf.deterministic, False) self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) - def test_nondeterministic_udf_in_aggregate(self): + def test_nondeterministic_vectorized_udf_in_aggregate(self): from pyspark.sql.functions import pandas_udf, sum df = self.spark.range(10) - random_udf = self.random_udf + random_udf = self.nondeterministic_vectorized_udf with QuietTest(self.sc): with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): @@ -4112,6 +4140,23 @@ def test_nondeterministic_udf_in_aggregate(self): with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): df.agg(sum(random_udf(df.id))).collect() + def test_register_vectorized_udf_basic(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, col, expr + df = self.spark.range(10).select( + col('id').cast('int').alias('a'), + col('id').cast('int').alias('b')) + original_add = pandas_udf(lambda x, y: x + y, IntegerType()) + self.assertEqual(original_add.deterministic, True) + self.assertEqual(original_add.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + new_add = self.spark.catalog.registerFunction("add1", original_add) + res1 = df.select(new_add(col('a'), col('b'))) + res2 = self.spark.sql( + "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t") + expected = df.select(expr('a + b')) + self.assertEquals(expected.collect(), res1.collect()) + self.assertEquals(expected.collect(), res2.collect()) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedSQLTestCase): @@ -4147,6 +4192,15 @@ def test_simple(self): expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertFramesEqual(expected, result) + def test_register_group_map_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUP_MAP) + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or ' + 'SQL_PANDAS_SCALAR_UDF'): + self.spark.catalog.registerFunction("foo_udf", foo_udf) + def test_decorator(self): from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data From 75db14864d2bd9b8e13154226e94d466e3a7e0a0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 16 Jan 2018 22:41:30 +0800 Subject: [PATCH 105/774] [SPARK-22392][SQL] data source v2 columnar batch reader ## What changes were proposed in this pull request? a new Data Source V2 interface to allow the data source to return `ColumnarBatch` during the scan. ## How was this patch tested? new tests Author: Wenchen Fan Closes #20153 from cloud-fan/columnar-reader. --- .../sources/v2/reader/DataSourceV2Reader.java | 5 +- .../v2/reader/SupportsScanColumnarBatch.java | 52 ++++++++ .../v2/reader/SupportsScanUnsafeRow.java | 2 +- .../sql/execution/ColumnarBatchScan.scala | 37 +++++- .../sql/execution/DataSourceScanExec.scala | 39 ++---- .../columnar/InMemoryTableScanExec.scala | 101 +++++++++------- .../datasources/v2/DataSourceRDD.scala | 20 ++-- .../datasources/v2/DataSourceV2ScanExec.scala | 72 ++++++----- .../ContinuousDataSourceRDDIter.scala | 4 +- .../sql/sources/v2/JavaBatchDataSourceV2.java | 112 ++++++++++++++++++ .../execution/WholeStageCodegenSuite.scala | 28 ++--- .../sql/sources/v2/DataSourceV2Suite.scala | 72 ++++++++++- 12 files changed, 400 insertions(+), 144 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index 95ee4a8278322..f23c3842bf1b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -38,7 +38,10 @@ * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. * Names of these interfaces start with `SupportsReporting`. * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. - * Names of these interfaces start with `SupportsScan`. + * Names of these interfaces start with `SupportsScan`. Note that a reader should only + * implement at most one of the special scans, if more than one special scans are implemented, + * only one of them would be respected, according to the priority list from high to low: + * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. * * If an exception was throw when applying any of these query optimizations, the action would fail * and no Spark job was submitted. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java new file mode 100644 index 0000000000000..27cf3a77724f0 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.util.List; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to output {@link ColumnarBatch} and make the scan faster. + */ +@InterfaceStability.Evolving +public interface SupportsScanColumnarBatch extends DataSourceV2Reader { + @Override + default List> createReadTasks() { + throw new IllegalStateException( + "createReadTasks not supported by default within SupportsScanColumnarBatch."); + } + + /** + * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns columnar data in batches. + */ + List> createBatchReadTasks(); + + /** + * Returns true if the concrete data source reader can read data in batch according to the scan + * properties like required columns, pushes filters, etc. It's possible that the implementation + * can only support some certain columns with certain types. Users can overwrite this method and + * {@link #createReadTasks()} to fallback to normal read path under some conditions. + */ + default boolean enableBatchRead() { + return true; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index b90ec880dc85e..2d3ad0eee65ff 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -35,7 +35,7 @@ public interface SupportsScanUnsafeRow extends DataSourceV2Reader { @Override default List> createReadTasks() { throw new IllegalStateException( - "createReadTasks should not be called with SupportsScanUnsafeRow."); + "createReadTasks not supported by default within SupportsScanUnsafeRow"); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 5617046e1396e..dd68df9686691 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType @@ -25,13 +25,16 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** - * Helper trait for abstracting scan functionality using - * [[ColumnarBatch]]es. + * Helper trait for abstracting scan functionality using [[ColumnarBatch]]es. */ private[sql] trait ColumnarBatchScan extends CodegenSupport { def vectorTypes: Option[Seq[String]] = None + protected def supportsBatch: Boolean = true + + protected def needsUnsafeRowConversion: Boolean = true + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) @@ -71,7 +74,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { // PhysicalRDD always just has one input val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") + if (supportsBatch) { + produceBatches(ctx, input) + } else { + produceRows(ctx, input) + } + } + private def produceBatches(ctx: CodegenContext, input: String): String = { // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") @@ -137,4 +147,25 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { """.stripMargin } + private def produceRows(ctx: CodegenContext, input: String): String = { + val numOutputRows = metricTerm(ctx, "numOutputRows") + val row = ctx.freshName("row") + + ctx.INPUT_ROW = row + ctx.currentVars = null + // Always provide `outputVars`, so that the framework can help us build unsafe row if the input + // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. + val outputVars = output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } + val inputRow = if (needsUnsafeRowConversion) null else row + s""" + |while ($input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | $numOutputRows.add(1); + | ${consume(ctx, outputVars, inputRow).trim} + | if (shouldStop()) return; + |} + """.stripMargin + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index d1ff82c7c06bc..7c7d79c2bbd7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -164,13 +164,15 @@ case class FileSourceScanExec( override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { - val supportsBatch: Boolean = relation.fileFormat.supportBatch( + override val supportsBatch: Boolean = relation.fileFormat.supportBatch( relation.sparkSession, StructType.fromAttributes(output)) - val needsUnsafeRowConversion: Boolean = if (relation.fileFormat.isInstanceOf[ParquetSource]) { - SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled - } else { - false + override val needsUnsafeRowConversion: Boolean = { + if (relation.fileFormat.isInstanceOf[ParquetSource]) { + SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled + } else { + false + } } override def vectorTypes: Option[Seq[String]] = @@ -346,33 +348,6 @@ case class FileSourceScanExec( override val nodeNamePrefix: String = "File" - override protected def doProduce(ctx: CodegenContext): String = { - if (supportsBatch) { - return super.doProduce(ctx) - } - val numOutputRows = metricTerm(ctx, "numOutputRows") - // PhysicalRDD always just has one input - val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") - val row = ctx.freshName("row") - - ctx.INPUT_ROW = row - ctx.currentVars = null - // Always provide `outputVars`, so that the framework can help us build unsafe row if the input - // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. - val outputVars = output.zipWithIndex.map{ case (a, i) => - BoundReference(i, a.dataType, a.nullable).genCode(ctx) - } - val inputRow = if (needsUnsafeRowConversion) null else row - s""" - |while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | $numOutputRows.add(1); - | ${consume(ctx, outputVars, inputRow).trim} - | if (shouldStop()) return; - |} - """.stripMargin - } - /** * Create an RDD for bucketed reads. * The non-bucketed variant of this function is [[createNonBucketedReadRDD]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 933b9753faa61..3565ee3af1b9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -49,9 +49,9 @@ case class InMemoryTableScanExec( /** * If true, get data from ColumnVector in ColumnarBatch, which are generally faster. - * If false, get data from UnsafeRow build from ColumnVector + * If false, get data from UnsafeRow build from CachedBatch */ - override val supportCodegen: Boolean = { + override val supportsBatch: Boolean = { // In the initial implementation, for ease of review // support only primitive data types and # of fields is less than wholeStageMaxNumFields relation.schema.fields.forall(f => f.dataType match { @@ -61,6 +61,8 @@ case class InMemoryTableScanExec( }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) } + override protected def needsUnsafeRowConversion: Boolean = false + private val columnIndices = attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray @@ -90,14 +92,56 @@ case class InMemoryTableScanExec( columnarBatch } - override def inputRDDs(): Seq[RDD[InternalRow]] = { - assert(supportCodegen) + private lazy val inputRDD: RDD[InternalRow] = { val buffers = filteredCachedBatches() - // HACK ALERT: This is actually an RDD[ColumnarBatch]. - // We're taking advantage of Scala's type erasure here to pass these batches along. - Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]]) + if (supportsBatch) { + // HACK ALERT: This is actually an RDD[ColumnarBatch]. + // We're taking advantage of Scala's type erasure here to pass these batches along. + buffers.map(createAndDecompressColumn).asInstanceOf[RDD[InternalRow]] + } else { + val numOutputRows = longMetric("numOutputRows") + + if (enableAccumulatorsForTest) { + readPartitions.setValue(0) + readBatches.setValue(0) + } + + // Using these variables here to avoid serialization of entire objects (if referenced + // directly) within the map Partitions closure. + val relOutput: AttributeSeq = relation.output + + filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => + // Find the ordinals and data types of the requested columns. + val (requestedColumnIndices, requestedColumnDataTypes) = + attributes.map { a => + relOutput.indexOf(a.exprId) -> a.dataType + }.unzip + + // update SQL metrics + val withMetrics = cachedBatchIterator.map { batch => + if (enableAccumulatorsForTest) { + readBatches.add(1) + } + numOutputRows += batch.numRows + batch + } + + val columnTypes = requestedColumnDataTypes.map { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + }.toArray + val columnarIterator = GenerateColumnAccessor.generate(columnTypes) + columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) + if (enableAccumulatorsForTest && columnarIterator.hasNext) { + readPartitions.add(1) + } + columnarIterator + } + } } + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) + override def output: Seq[Attribute] = attributes private def updateAttribute(expr: Expression): Expression = { @@ -185,7 +229,7 @@ case class InMemoryTableScanExec( } } - lazy val enableAccumulators: Boolean = + lazy val enableAccumulatorsForTest: Boolean = sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean // Accumulators used for testing purposes @@ -230,43 +274,10 @@ case class InMemoryTableScanExec( } protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - if (enableAccumulators) { - readPartitions.setValue(0) - readBatches.setValue(0) - } - - // Using these variables here to avoid serialization of entire objects (if referenced directly) - // within the map Partitions closure. - val relOutput: AttributeSeq = relation.output - - filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => - // Find the ordinals and data types of the requested columns. - val (requestedColumnIndices, requestedColumnDataTypes) = - attributes.map { a => - relOutput.indexOf(a.exprId) -> a.dataType - }.unzip - - // update SQL metrics - val withMetrics = cachedBatchIterator.map { batch => - if (enableAccumulators) { - readBatches.add(1) - } - numOutputRows += batch.numRows - batch - } - - val columnTypes = requestedColumnDataTypes.map { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - }.toArray - val columnarIterator = GenerateColumnAccessor.generate(columnTypes) - columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) - if (enableAccumulators && columnarIterator.hasNext) { - readPartitions.add(1) - } - columnarIterator + if (supportsBatch) { + WholeStageCodegenExec(this).execute() + } else { + inputRDD } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 5f30be5ed4af1..ac104d7cd0cb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -18,19 +18,19 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.sources.v2.reader.ReadTask -class DataSourceRDDPartition(val index: Int, val readTask: ReadTask[UnsafeRow]) +class DataSourceRDDPartition[T : ClassTag](val index: Int, val readTask: ReadTask[T]) extends Partition with Serializable -class DataSourceRDD( +class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { + @transient private val readTasks: java.util.List[ReadTask[T]]) + extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { readTasks.asScala.zipWithIndex.map { @@ -38,10 +38,10 @@ class DataSourceRDD( }.toArray } - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readTask.createDataReader() context.addTaskCompletionListener(_ => reader.close()) - val iter = new Iterator[UnsafeRow] { + val iter = new Iterator[T] { private[this] var valuePrepared = false override def hasNext: Boolean = { @@ -51,7 +51,7 @@ class DataSourceRDD( valuePrepared } - override def next(): UnsafeRow = { + override def next(): T = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -63,6 +63,6 @@ class DataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[T]].readTask.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 49c506bc560cf..8c64df080242f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -24,10 +24,8 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.LeafExecNode -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousDataSourceRDD, ContinuousExecution, EpochCoordinatorRef, SetReaderPartitions} +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader import org.apache.spark.sql.types.StructType @@ -37,40 +35,56 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( fullOutput: Seq[AttributeReference], - @transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder { + @transient reader: DataSourceV2Reader) + extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] - override def references: AttributeSet = AttributeSet.empty + override def producedAttributes: AttributeSet = AttributeSet(fullOutput) - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() + case _ => + reader.createReadTasks().asScala.map { + new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] + }.asJava + } - override protected def doExecute(): RDD[InternalRow] = { - val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() - case _ => - reader.createReadTasks().asScala.map { - new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] - }.asJava - } + private lazy val inputRDD: RDD[InternalRow] = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() => + assert(!reader.isInstanceOf[ContinuousReader], + "continuous stream reader does not support columnar read yet.") + new DataSourceRDD(sparkContext, r.createBatchReadTasks()).asInstanceOf[RDD[InternalRow]] + + case _: ContinuousReader => + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + .askSync[Unit](SetReaderPartitions(readTasks.size())) + new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + .asInstanceOf[RDD[InternalRow]] + + case _ => + new DataSourceRDD(sparkContext, readTasks).asInstanceOf[RDD[InternalRow]] + } - val inputRDD = reader match { - case _: ContinuousReader => - EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readTasks.size())) + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + override val supportsBatch: Boolean = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() => true + case _ => false + } - case _ => - new DataSourceRDD(sparkContext, readTasks) - } + override protected def needsUnsafeRowConversion: Boolean = false - val numOutputRows = longMetric("numOutputRows") - inputRDD.asInstanceOf[RDD[InternalRow]].map { r => - numOutputRows += 1 - r + override protected def doExecute(): RDD[InternalRow] = { + if (supportsBatch) { + WholeStageCodegenExec(this).execute() + } else { + val numOutputRows = longMetric("numOutputRows") + inputRDD.map { r => + numOutputRows += 1 + r + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index d79e4bd65f563..b3f1a1a1aaab3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -52,7 +52,7 @@ class ContinuousDataSourceRDD( } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) @@ -132,7 +132,7 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.preferredLocations() } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java new file mode 100644 index 0000000000000..44e5146d7c553 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.List; + +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader, SupportsScanColumnarBatch { + private final StructType schema = new StructType().add("i", "int").add("j", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createBatchReadTasks() { + return java.util.Arrays.asList(new JavaBatchReadTask(0, 50), new JavaBatchReadTask(50, 90)); + } + } + + static class JavaBatchReadTask implements ReadTask, DataReader { + private int start; + private int end; + + private static final int BATCH_SIZE = 20; + + private OnHeapColumnVector i; + private OnHeapColumnVector j; + private ColumnarBatch batch; + + JavaBatchReadTask(int start, int end) { + this.start = start; + this.end = end; + } + + @Override + public DataReader createDataReader() { + this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + ColumnVector[] vectors = new ColumnVector[2]; + vectors[0] = i; + vectors[1] = j; + this.batch = new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); + return this; + } + + @Override + public boolean next() { + i.reset(); + j.reset(); + int count = 0; + while (start < end && count < BATCH_SIZE) { + i.putInt(count, start); + j.putInt(count, -start); + start += 1; + count += 1; + } + + if (count == 0) { + return false; + } else { + batch.setNumRows(count); + return true; + } + } + + @Override + public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + batch.close(); + } + } + + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index bc05dca578c47..22ca128c27768 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -121,31 +121,23 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") { import testImplicits._ - val dsInt = spark.range(3).cache - dsInt.count + val dsInt = spark.range(3).cache() + dsInt.count() val dsIntFilter = dsInt.filter(_ > 0) val planInt = dsIntFilter.queryExecution.executedPlan - assert(planInt.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && - p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .isInstanceOf[InMemoryTableScanExec] && - p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined - ) + assert(planInt.collect { + case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if i.supportsBatch => () + }.length == 1) assert(dsIntFilter.collect() === Array(1, 2)) // cache for string type is not supported for InMemoryTableScanExec - val dsString = spark.range(3).map(_.toString).cache - dsString.count + val dsString = spark.range(3).map(_.toString).cache() + dsString.count() val dsStringFilter = dsString.filter(_ == "1") val planString = dsStringFilter.queryExecution.executedPlan - assert(planString.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && - !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .isInstanceOf[InMemoryTableScanExec]).isDefined - ) + assert(planString.collect { + case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if !i.supportsBatch => () + }.length == 1) assert(dsStringFilter.collect() === Array("1")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index ab37e4984bd1f..a89f7c55bf4f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,10 +24,12 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -56,7 +58,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("unsafe row implementation") { + test("unsafe row scan implementation") { Seq(classOf[UnsafeRowDataSourceV2], classOf[JavaUnsafeRowDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() @@ -67,6 +69,17 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("columnar batch scan implementation") { + Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 90).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 90).map(i => Row(-i))) + checkAnswer(df.filter('i > 50), (51 until 90).map(i => Row(i, -i))) + } + } + } + test("schema required data source") { Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => withClue(cls.getName) { @@ -275,7 +288,7 @@ class UnsafeRowReadTask(start: Int, end: Int) private var current = start - 1 - override def createDataReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end) + override def createDataReader(): DataReader[UnsafeRow] = this override def next(): Boolean = { current += 1 @@ -300,3 +313,56 @@ class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { override def createReader(schema: StructType, options: DataSourceV2Options): DataSourceV2Reader = new Reader(schema) } + +class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] = { + java.util.Arrays.asList(new BatchReadTask(0, 50), new BatchReadTask(50, 90)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class BatchReadTask(start: Int, end: Int) + extends ReadTask[ColumnarBatch] with DataReader[ColumnarBatch] { + + private final val BATCH_SIZE = 20 + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val batch = new ColumnarBatch( + new StructType().add("i", "int").add("j", "int"), Array(i, j), BATCH_SIZE) + + private var current = start + + override def createDataReader(): DataReader[ColumnarBatch] = this + + override def next(): Boolean = { + i.reset() + j.reset() + + var count = 0 + while (current < end && count < BATCH_SIZE) { + i.putInt(count, current) + j.putInt(count, -current) + current += 1 + count += 1 + } + + if (count == 0) { + false + } else { + batch.setNumRows(count) + true + } + } + + override def get(): ColumnarBatch = { + batch + } + + override def close(): Unit = batch.close() +} From 12db365b4faf7a185708648d246fc4a2aae0c2c0 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Tue, 16 Jan 2018 11:41:08 -0800 Subject: [PATCH 106/774] [SPARK-16139][TEST] Add logging functionality for leaked threads in tests ## What changes were proposed in this pull request? Lots of our tests don't properly shutdown everything they create, and end up leaking lots of threads. For example, `TaskSetManagerSuite` doesn't stop the extra `TaskScheduler` and `DAGScheduler` it creates. There are a couple more instances, eg. in `DAGSchedulerSuite`. This PR adds the possibility to print out the not properly stopped thread list after a test suite executed. The format is the following: ``` ===== FINISHED o.a.s.scheduler.DAGSchedulerSuite: 'task end event should have updated accumulators (SPARK-20342)' ===== ... ===== Global thread whitelist loaded with name /thread_whitelist from classpath: rpc-client.*, rpc-server.*, shuffle-client.*, shuffle-server.*' ===== ScalaTest-run: ===== THREADS NOT STOPPED PROPERLY ===== ScalaTest-run: dag-scheduler-event-loop ScalaTest-run: globalEventExecutor-2-5 ScalaTest-run: ===== END OF THREAD DUMP ===== ScalaTest-run: ===== EITHER PUT THREAD NAME INTO THE WHITELIST FILE OR SHUT IT DOWN PROPERLY ===== ``` With the help of this leaking threads has been identified in TaskSetManagerSuite. My intention is to hunt down and fix such bugs in later PRs. ## How was this patch tested? Manual: TaskSetManagerSuite test executed and found out where are the leaking threads. Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #19893 from gaborgsomogyi/SPARK-16139. --- .../org/apache/spark/SparkFunSuite.scala | 34 +++++++ .../scala/org/apache/spark/ThreadAudit.scala | 99 +++++++++++++++++++ .../spark/scheduler/TaskSetManagerSuite.scala | 7 +- .../apache/spark/sql/SessionStateSuite.scala | 1 + .../sql/sources/DataSourceAnalysisSuite.scala | 1 + .../spark/sql/test/SharedSQLContext.scala | 23 ++++- .../hive/HiveContextCompatibilitySuite.scala | 1 + .../sql/hive/HiveSessionStateSuite.scala | 1 + .../spark/sql/hive/HiveSparkSubmitSuite.scala | 2 + .../sql/hive/client/HiveClientSuite.scala | 1 + .../sql/hive/client/HiveVersionSuite.scala | 1 + .../spark/sql/hive/client/VersionsSuite.scala | 2 + .../hive/execution/HiveComparisonTest.scala | 2 + .../hive/orc/OrcHadoopFsRelationSuite.scala | 1 + .../sql/hive/test/TestHiveSingleton.scala | 1 + 15 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/ThreadAudit.scala diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 18077c08c9dcc..3af9d82393bc4 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -27,19 +27,53 @@ import org.apache.spark.util.AccumulatorContext /** * Base abstract class for all unit tests in Spark for handling common functionality. + * + * Thread audit happens normally here automatically when a new test suite created. + * The only prerequisite for that is that the test class must extend [[SparkFunSuite]]. + * + * It is possible to override the default thread audit behavior by setting enableAutoThreadAudit + * to false and manually calling the audit methods, if desired. For example: + * + * class MyTestSuite extends SparkFunSuite { + * + * override val enableAutoThreadAudit = false + * + * protected override def beforeAll(): Unit = { + * doThreadPreAudit() + * super.beforeAll() + * } + * + * protected override def afterAll(): Unit = { + * super.afterAll() + * doThreadPostAudit() + * } + * } */ abstract class SparkFunSuite extends FunSuite with BeforeAndAfterAll + with ThreadAudit with Logging { // scalastyle:on + protected val enableAutoThreadAudit = true + + protected override def beforeAll(): Unit = { + if (enableAutoThreadAudit) { + doThreadPreAudit() + } + super.beforeAll() + } + protected override def afterAll(): Unit = { try { // Avoid leaking map entries in tests that use accumulators without SparkContext AccumulatorContext.clear() } finally { super.afterAll() + if (enableAutoThreadAudit) { + doThreadPostAudit() + } } } diff --git a/core/src/test/scala/org/apache/spark/ThreadAudit.scala b/core/src/test/scala/org/apache/spark/ThreadAudit.scala new file mode 100644 index 0000000000000..b3cea9de8f304 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ThreadAudit.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging + +/** + * Thread audit for test suites. + */ +trait ThreadAudit extends Logging { + + val threadWhiteList = Set( + /** + * Netty related internal threads. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "netty.*", + + /** + * Netty related internal threads. + * A Single-thread singleton EventExecutor inside netty which creates such threads. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "globalEventExecutor.*", + + /** + * Netty related internal threads. + * Checks if a thread is alive periodically and runs a task when a thread dies. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "threadDeathWatcher.*", + + /** + * During [[SparkContext]] creation [[org.apache.spark.rpc.netty.NettyRpcEnv]] + * creates event loops. One is wrapped inside + * [[org.apache.spark.network.server.TransportServer]] + * the other one is inside [[org.apache.spark.network.client.TransportClient]]. + * The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]]. + * Manually checked and all of them stopped properly. + */ + "rpc-client.*", + "rpc-server.*", + + /** + * During [[SparkContext]] creation BlockManager creates event loops. One is wrapped inside + * [[org.apache.spark.network.server.TransportServer]] + * the other one is inside [[org.apache.spark.network.client.TransportClient]]. + * The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]]. + * Manually checked and all of them stopped properly. + */ + "shuffle-client.*", + "shuffle-server.*" + ) + private var threadNamesSnapshot: Set[String] = Set.empty + + protected def doThreadPreAudit(): Unit = { + threadNamesSnapshot = runningThreadNames() + } + + protected def doThreadPostAudit(): Unit = { + val shortSuiteName = this.getClass.getName.replaceAll("org.apache.spark", "o.a.s") + + if (threadNamesSnapshot.nonEmpty) { + val remainingThreadNames = runningThreadNames().diff(threadNamesSnapshot) + .filterNot { s => threadWhiteList.exists(s.matches(_)) } + if (remainingThreadNames.nonEmpty) { + logWarning(s"\n\n===== POSSIBLE THREAD LEAK IN SUITE $shortSuiteName, " + + s"thread names: ${remainingThreadNames.mkString(", ")} =====\n") + } + } else { + logWarning("\n\n===== THREAD AUDIT POST ACTION CALLED " + + s"WITHOUT PRE ACTION IN SUITE $shortSuiteName =====\n") + } + } + + private def runningThreadNames(): Set[String] = { + Thread.getAllStackTraces.keySet().asScala.map(_.getName).toSet + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 2ce81ae27daf6..ca6a7e5db3b17 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -683,7 +683,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val conf = new SparkConf().set("spark.speculation", "true") sc = new SparkContext("local", "test", conf) - val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) sched.initialize(new FakeSchedulerBackend() { override def killTask( taskId: Long, @@ -709,6 +709,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } } } + sched.dagScheduler.stop() sched.setDAGScheduler(dagScheduler) val singleTask = new ShuffleMapTask(0, 0, null, new Partition { @@ -754,7 +755,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sc.conf.set("spark.speculation", "true") var killTaskCalled = false - val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) sched.initialize(new FakeSchedulerBackend() { override def killTask( @@ -789,6 +790,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } } } + sched.dagScheduler.stop() sched.setDAGScheduler(dagScheduler) val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0, @@ -1183,6 +1185,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sc = new SparkContext("local", "test") sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val mockDAGScheduler = mock(classOf[DAGScheduler]) + sched.dagScheduler.stop() sched.dagScheduler = mockDAGScheduler val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index c01666770720c..5d75f5835bf9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -39,6 +39,7 @@ class SessionStateSuite extends SparkFunSuite protected var activeSession: SparkSession = _ override def beforeAll(): Unit = { + super.beforeAll() activeSession = SparkSession.builder().master("local").getOrCreate() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala index 735e07c21373a..e1022e377132c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -33,6 +33,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { private var targetPartitionSchema: StructType = _ override def beforeAll(): Unit = { + super.beforeAll() targetAttributes = Seq('a.int, 'd.int, 'b.int, 'c.int) targetPartitionSchema = new StructType() .add("b", IntegerType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 4d578e21f5494..e6c7648c986ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,4 +17,25 @@ package org.apache.spark.sql.test -trait SharedSQLContext extends SQLTestUtils with SharedSparkSession +trait SharedSQLContext extends SQLTestUtils with SharedSparkSession { + + /** + * Suites extending [[SharedSQLContext]] are sharing resources (eg. SparkSession) in their tests. + * That trait initializes the spark session in its [[beforeAll()]] implementation before the + * automatic thread snapshot is performed, so the audit code could fail to report threads leaked + * by that shared session. + * + * The behavior is overridden here to take the snapshot before the spark session is initialized. + */ + override protected val enableAutoThreadAudit = false + + protected override def beforeAll(): Unit = { + doThreadPreAudit() + super.beforeAll() + } + + protected override def afterAll(): Unit = { + super.afterAll() + doThreadPostAudit() + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala index 8a7423663f28d..a80db765846e9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEach { + override protected val enableAutoThreadAudit = false private var sc: SparkContext = null private var hc: HiveContext = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index 958ad3e1c3ce8..f7da3c4cbb0aa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -30,6 +30,7 @@ class HiveSessionStateSuite extends SessionStateSuite override def beforeAll(): Unit = { // Reuse the singleton session + super.beforeAll() activeSession = spark } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 21b3e281490cf..10204f4694663 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -44,6 +44,8 @@ class HiveSparkSubmitSuite with BeforeAndAfterEach with ResetSystemProperties { + override protected val enableAutoThreadAudit = false + // TODO: rewrite these or mark them as slow tests to be run sparingly override def beforeEach() { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index ce53acef51503..a5dfd89b3a574 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -67,6 +67,7 @@ class HiveClientSuite(version: String) } override def beforeAll() { + super.beforeAll() client = init(true) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index 951ebfad4590e..bb8a4697b0a13 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.hive.HiveUtils private[client] abstract class HiveVersionSuite(version: String) extends SparkFunSuite { + override protected val enableAutoThreadAudit = false protected var client: HiveClient = null protected def buildClient(hadoopConf: Configuration): HiveClient = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index e64389e56b5a1..72536b833481a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -50,6 +50,8 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils} @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { + override protected val enableAutoThreadAudit = false + import HiveClientBuilder.buildClient /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index cee82cda4628a..272e6f51f5002 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -48,6 +48,8 @@ import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} abstract class HiveComparisonTest extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen { + override protected val enableAutoThreadAudit = false + /** * Path to the test datasets. We find this by looking up "hive-test-path-helper.txt" file. * diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index f87162f94c01a..a1f054b8e3f44 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { import testImplicits._ + override protected val enableAutoThreadAudit = false override val dataSourceName: String = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index df7988f542b71..d3fff37c3424d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.hive.client.HiveClient trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { + override protected val enableAutoThreadAudit = false protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive protected val hiveClient: HiveClient = From 4371466b3f06ca171b10568e776c9446f7bae6dd Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Tue, 16 Jan 2018 12:56:57 -0800 Subject: [PATCH 107/774] [SPARK-23045][ML][SPARKR] Update RFormula to use OneHotEncoderEstimator. ## What changes were proposed in this pull request? RFormula should use VectorSizeHint & OneHotEncoderEstimator in its pipeline to avoid using the deprecated OneHotEncoder & to ensure the model produced can be used in streaming. ## How was this patch tested? Unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bago Amirbekian Closes #20229 from MrBago/rFormula. --- R/pkg/R/mllib_utils.R | 1 - .../apache/spark/ml/feature/RFormula.scala | 20 +++++-- .../spark/ml/feature/RFormulaSuite.scala | 53 +++++++++++-------- 3 files changed, 46 insertions(+), 28 deletions(-) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index 23dda42c325be..a53c92c2c4815 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -130,4 +130,3 @@ read.ml <- function(path) { stop("Unsupported model: ", jobj) } } - diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index f384ffbf578bc..1155ea5fdd85b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -199,6 +199,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) val encoderStages = ArrayBuffer[PipelineStage]() + val oneHotEncodeColumns = ArrayBuffer[(String, String)]() val prefixesToRewrite = mutable.Map[String, String]() val tempColumns = ArrayBuffer[String]() @@ -242,16 +243,17 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) val encodedTerms = resolvedFormula.terms.map { case Seq(term) if dataset.schema(term).dataType == StringType => val encodedCol = tmpColumn("onehot") - var encoder = new OneHotEncoder() - .setInputCol(indexed(term)) - .setOutputCol(encodedCol) // Formula w/o intercept, one of the categories in the first category feature is // being used as reference category, we will not drop any category for that feature. if (!hasIntercept && !keepReferenceCategory) { - encoder = encoder.setDropLast(false) + encoderStages += new OneHotEncoderEstimator(uid) + .setInputCols(Array(indexed(term))) + .setOutputCols(Array(encodedCol)) + .setDropLast(false) keepReferenceCategory = true + } else { + oneHotEncodeColumns += indexed(term) -> encodedCol } - encoderStages += encoder prefixesToRewrite(encodedCol + "_") = term + "_" encodedCol case Seq(term) => @@ -265,6 +267,14 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) interactionCol } + if (oneHotEncodeColumns.nonEmpty) { + val (inputCols, outputCols) = oneHotEncodeColumns.toArray.unzip + encoderStages += new OneHotEncoderEstimator(uid) + .setInputCols(inputCols) + .setOutputCols(outputCols) + .setDropLast(true) + } + encoderStages += new VectorAssembler(uid) .setInputCols(encodedTerms.toArray) .setOutputCol($(featuresCol)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index f3f4b5a3d0233..bfe38d32dd77d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -29,6 +29,17 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ + def testRFormulaTransform[A: Encoder]( + dataframe: DataFrame, + formulaModel: RFormulaModel, + expected: DataFrame): Unit = { + val (first +: rest) = expected.schema.fieldNames.toSeq + val expectedRows = expected.collect() + testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first, rest: _*) { rows => + assert(rows === expectedRows) + } + } + test("params") { ParamsSuite.checkParams(new RFormula()) } @@ -47,7 +58,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString assert(result.schema.toString == resultSchema.toString) assert(resultSchema == expected.schema) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, Double, Double)](original, model, expected) } test("features column already exists") { @@ -109,7 +120,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)) ).toDF("id", "a", "b", "features") assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, Double, Double)](original, model, expected) } test("encodes string terms") { @@ -126,7 +137,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) ).toDF("id", "a", "b", "features", "label") assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected) } test("encodes string terms with string indexer order type") { @@ -167,7 +178,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected(idx).collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected(idx)) idx += 1 } } @@ -210,7 +221,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected) } test("formula w/o intercept, we should output reference category when encoding string terms") { @@ -253,7 +264,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) ).toDF("id", "a", "b", "c", "features", "label") assert(result1.schema.toString == resultSchema1.toString) - assert(result1.collect() === expected1.collect()) + testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1) val attrs1 = AttributeGroup.fromStructField(result1.schema("features")) val expectedAttrs1 = new AttributeGroup( @@ -280,7 +291,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0) ).toDF("id", "a", "b", "c", "features", "label") assert(result2.schema.toString == resultSchema2.toString) - assert(result2.collect() === expected2.collect()) + testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2) val attrs2 = AttributeGroup.fromStructField(result2.schema("features")) val expectedAttrs2 = new AttributeGroup( @@ -302,7 +313,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), @@ -310,7 +320,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0) ).toDF("id", "a", "b", "features", "label") // assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(String, String, Int)](original, model, expected) } test("force to index label even it is numeric type") { @@ -319,7 +329,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = spark.createDataFrame( Seq( (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0), @@ -327,7 +336,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0), (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0)) ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Double, String, Int)](original, model, expected) } test("attribute generation") { @@ -391,7 +400,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (1, 2, 4, 2, Vectors.dense(16.0), 1.0), (2, 3, 4, 1, Vectors.dense(12.0), 2.0) ).toDF("a", "b", "c", "d", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -414,7 +423,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0) ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -436,7 +445,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0) ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, String)](original, model, expected) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -511,8 +520,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { intercept[SparkException] { formula1.fit(df1).transform(df2).collect() } - val result1 = formula1.setHandleInvalid("skip").fit(df1).transform(df2) - val result2 = formula1.setHandleInvalid("keep").fit(df1).transform(df2) + val model1 = formula1.setHandleInvalid("skip").fit(df1) + val model2 = formula1.setHandleInvalid("keep").fit(df1) val expected1 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0), 1.0), @@ -524,16 +533,16 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (3, "bar", "zy", Vectors.dense(1.0, 0.0, 0.0, 0.0), 3.0) ).toDF("id", "a", "b", "features", "label") - assert(result1.collect() === expected1.collect()) - assert(result2.collect() === expected2.collect()) + testRFormulaTransform[(Int, String, String)](df2, model1, expected1) + testRFormulaTransform[(Int, String, String)](df2, model2, expected2) // Handle unseen labels. val formula2 = new RFormula().setFormula("b ~ a + id") intercept[SparkException] { formula2.fit(df1).transform(df2).collect() } - val result3 = formula2.setHandleInvalid("skip").fit(df1).transform(df2) - val result4 = formula2.setHandleInvalid("keep").fit(df1).transform(df2) + val model3 = formula2.setHandleInvalid("skip").fit(df1) + val model4 = formula2.setHandleInvalid("keep").fit(df1) val expected3 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0), @@ -545,8 +554,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0) ).toDF("id", "a", "b", "features", "label") - assert(result3.collect() === expected3.collect()) - assert(result4.collect() === expected4.collect()) + testRFormulaTransform[(Int, String, String)](df2, model3, expected3) + testRFormulaTransform[(Int, String, String)](df2, model4, expected4) } test("Use Vectors as inputs to formula.") { From 5ae333391bd73331b5b90af71a3de52cdbb24109 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 16 Jan 2018 16:25:10 -0800 Subject: [PATCH 108/774] [SPARK-23044] Error handling for jira assignment ## What changes were proposed in this pull request? * If there is any error while trying to assign the jira, prompt again * Filter out the "Apache Spark" choice * allow arbitrary user ids to be entered ## How was this patch tested? Couldn't really test the error case, just some testing of similar-ish code in python shell. Haven't run a merge yet. Author: Imran Rashid Closes #20236 from squito/SPARK-23044. --- dev/merge_spark_pr.py | 50 +++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 57ca8400b6f3d..6b244d8184b2c 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -30,6 +30,7 @@ import re import subprocess import sys +import traceback import urllib2 try: @@ -298,24 +299,37 @@ def choose_jira_assignee(issue, asf_jira): Prompt the user to choose who to assign the issue to in jira, given a list of candidates, including the original reporter and all commentors """ - reporter = issue.fields.reporter - commentors = map(lambda x: x.author, issue.fields.comment.comments) - candidates = set(commentors) - candidates.add(reporter) - candidates = list(candidates) - print("JIRA is unassigned, choose assignee") - for idx, author in enumerate(candidates): - annotations = ["Reporter"] if author == reporter else [] - if author in commentors: - annotations.append("Commentor") - print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) - assignee = raw_input("Enter number of user to assign to (blank to leave unassigned):") - if assignee == "": - return None - else: - assignee = candidates[int(assignee)] - asf_jira.assign_issue(issue.key, assignee.key) - return assignee + while True: + try: + reporter = issue.fields.reporter + commentors = map(lambda x: x.author, issue.fields.comment.comments) + candidates = set(commentors) + candidates.add(reporter) + candidates = list(candidates) + print("JIRA is unassigned, choose assignee") + for idx, author in enumerate(candidates): + if author.key == "apachespark": + continue + annotations = ["Reporter"] if author == reporter else [] + if author in commentors: + annotations.append("Commentor") + print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) + raw_assignee = raw_input( + "Enter number of user, or userid, to assign to (blank to leave unassigned):") + if raw_assignee == "": + return None + else: + try: + id = int(raw_assignee) + assignee = candidates[id] + except: + # assume it's a user id, and try to assign (might fail, we just prompt again) + assignee = asf_jira.user(raw_assignee) + asf_jira.assign_issue(issue.key, assignee.key) + return assignee + except: + traceback.print_exc() + print("Error assigning JIRA, try again (or leave blank and fix manually)") def resolve_jira_issues(title, merge_branches, comment): From 0c2ba427bc7323729e6ffb34f1f06a97f0bf0c1d Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 17 Jan 2018 09:57:30 +0800 Subject: [PATCH 109/774] [SPARK-23095][SQL] Decorrelation of scalar subquery fails with java.util.NoSuchElementException ## What changes were proposed in this pull request? The following SQL involving scalar correlated query returns a map exception. ``` SQL SELECT t1a FROM t1 WHERE t1a = (SELECT count(*) FROM t2 WHERE t2c = t1c HAVING count(*) >= 1) ``` ``` SQL key not found: ExprId(278,786682bb-41f9-4bd5-a397-928272cc8e4e) java.util.NoSuchElementException: key not found: ExprId(278,786682bb-41f9-4bd5-a397-928272cc8e4e) at scala.collection.MapLike$class.default(MapLike.scala:228) at scala.collection.AbstractMap.default(Map.scala:59) at scala.collection.MapLike$class.apply(MapLike.scala:141) at scala.collection.AbstractMap.apply(Map.scala:59) at org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery$.org$apache$spark$sql$catalyst$optimizer$RewriteCorrelatedScalarSubquery$$evalSubqueryOnZeroTups(subquery.scala:378) at org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery$$anonfun$org$apache$spark$sql$catalyst$optimizer$RewriteCorrelatedScalarSubquery$$constructLeftJoins$1.apply(subquery.scala:430) at org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery$$anonfun$org$apache$spark$sql$catalyst$optimizer$RewriteCorrelatedScalarSubquery$$constructLeftJoins$1.apply(subquery.scala:426) ``` In this case, after evaluating the HAVING clause "count(*) > 1" statically against the binding of aggregtation result on empty input, we determine that this query will not have a the count bug. We should simply return the evalSubqueryOnZeroTups with empty value. (Please fill in changes proposed in this fix) ## How was this patch tested? A new test was added in the Subquery bucket. Author: Dilip Biswal Closes #20283 from dilipbiswal/scalar-count-defect. --- .../sql/catalyst/optimizer/subquery.scala | 5 +- .../scalar-subquery-predicate.sql | 10 ++++ .../scalar-subquery-predicate.sql.out | 57 ++++++++++++------- 3 files changed, 49 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 2673bea648d09..709db6d8bec7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -369,13 +369,14 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { case ne => (ne.exprId, evalAggOnZeroTups(ne)) }.toMap - case _ => sys.error(s"Unexpected operator in scalar subquery: $lp") + case _ => + sys.error(s"Unexpected operator in scalar subquery: $lp") } val resultMap = evalPlan(plan) // By convention, the scalar subquery result is the leftmost field. - resultMap(plan.output.head.exprId) + resultMap.getOrElse(plan.output.head.exprId, None) } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql index fb0d07fbdace7..1661209093fc4 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql @@ -173,6 +173,16 @@ WHERE t1a = (SELECT max(t2a) HAVING count(*) >= 0) OR t1i > '2014-12-31'; +-- TC 02.03.01 +SELECT t1a +FROM t1 +WHERE t1a = (SELECT max(t2a) + FROM t2 + WHERE t2c = t1c + GROUP BY t2c + HAVING count(*) >= 1) +OR t1i > '2014-12-31'; + -- TC 02.04 -- t1 on the right of an outer join -- can be reduced to inner join diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index 8b29300e71f90..a2b86db3e4f4c 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 29 -- !query 0 @@ -293,6 +293,21 @@ val1d -- !query 19 +SELECT t1a +FROM t1 +WHERE t1a = (SELECT max(t2a) + FROM t2 + WHERE t2c = t1c + GROUP BY t2c + HAVING count(*) >= 1) +OR t1i > '2014-12-31' +-- !query 19 schema +struct +-- !query 19 output +val1c +val1d + +-- !query 22 SELECT count(t1a) FROM t1 RIGHT JOIN t2 ON t1d = t2d @@ -300,13 +315,13 @@ WHERE t1a < (SELECT max(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 19 schema +-- !query 22 schema struct --- !query 19 output +-- !query 22 output 7 --- !query 20 +-- !query 23 SELECT t1a FROM t1 WHERE t1b <= (SELECT max(t2b) @@ -317,14 +332,14 @@ AND t1b >= (SELECT min(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 20 schema +-- !query 23 schema struct --- !query 20 output +-- !query 23 output val1b val1c --- !query 21 +-- !query 24 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -338,14 +353,14 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 21 schema +-- !query 24 schema struct --- !query 21 output +-- !query 24 output val1b val1c --- !query 22 +-- !query 25 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -359,9 +374,9 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 22 schema +-- !query 25 schema struct --- !query 22 output +-- !query 25 output val1a val1a val1b @@ -372,7 +387,7 @@ val1d val1d --- !query 23 +-- !query 26 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -386,16 +401,16 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 23 schema +-- !query 26 schema struct --- !query 23 output +-- !query 26 output val1a val1b val1c val1d --- !query 24 +-- !query 27 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -409,13 +424,13 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 24 schema +-- !query 27 schema struct --- !query 24 output +-- !query 27 output val1a --- !query 25 +-- !query 28 SELECT t1a FROM t1 GROUP BY t1a, t1c @@ -423,8 +438,8 @@ HAVING max(t1b) <= (SELECT max(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 25 schema +-- !query 28 schema struct --- !query 25 output +-- !query 28 output val1b val1c From a9b845ebb5b51eb619cfa7d73b6153024a6a420d Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 17 Jan 2018 10:03:25 +0800 Subject: [PATCH 110/774] [SPARK-22361][SQL][TEST] Add unit test for Window Frames ## What changes were proposed in this pull request? There are already quite a few integration tests using window frames, but the unit tests coverage is not ideal. In this PR the already existing tests are reorganized, extended and where gaps found additional cases added. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #20019 from gaborgsomogyi/SPARK-22361. --- .../parser/ExpressionParserSuite.scala | 57 ++- .../sql/DataFrameWindowFramesSuite.scala | 405 ++++++++++++++++++ .../sql/DataFrameWindowFunctionsSuite.scala | 243 ----------- 3 files changed, 454 insertions(+), 251 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 2b9783a3295c6..cb8a1fecb80a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -249,8 +249,8 @@ class ExpressionParserSuite extends PlanTest { assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) - assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc))) + assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc))) assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) @@ -263,21 +263,62 @@ class ExpressionParserSuite extends PlanTest { "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", WindowExpression('sum.function('product + 1), WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + } + + test("range/rows window function expressions") { + val func = 'foo.function(star()) + def windowed( + partitioning: Seq[Expression] = Seq.empty, + ordering: Seq[SortOrder] = Seq.empty, + frame: WindowFrame = UnspecifiedFrame): Expression = { + WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) + } - // Range/Row val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) val boundaries = Seq( - ("10 preceding", -Literal(10), CurrentRow), + // No between combinations + ("unbounded preceding", UnboundedPreceding, CurrentRow), ("2147483648 preceding", -Literal(2147483648L), CurrentRow), + ("10 preceding", -Literal(10), CurrentRow), + ("3 + 1 preceding", -Add(Literal(3), Literal(1)), CurrentRow), + ("0 preceding", -Literal(0), CurrentRow), + ("current row", CurrentRow, CurrentRow), + ("0 following", Literal(0), CurrentRow), ("3 + 1 following", Add(Literal(3), Literal(1)), CurrentRow), - ("unbounded preceding", UnboundedPreceding, CurrentRow), + ("10 following", Literal(10), CurrentRow), + ("2147483649 following", Literal(2147483649L), CurrentRow), ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis + + // Between combinations + ("between unbounded preceding and 5 following", + UnboundedPreceding, Literal(5)), + ("between unbounded preceding and 3 + 1 following", + UnboundedPreceding, Add(Literal(3), Literal(1))), + ("between unbounded preceding and 2147483649 following", + UnboundedPreceding, Literal(2147483649L)), ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), - ("between unbounded preceding and unbounded following", - UnboundedPreceding, UnboundedFollowing), + ("between 2147483648 preceding and current row", -Literal(2147483648L), CurrentRow), ("between 10 preceding and current row", -Literal(10), CurrentRow), + ("between 3 + 1 preceding and current row", -Add(Literal(3), Literal(1)), CurrentRow), + ("between 0 preceding and current row", -Literal(0), CurrentRow), + ("between current row and current row", CurrentRow, CurrentRow), + ("between current row and 0 following", CurrentRow, Literal(0)), ("between current row and 5 following", CurrentRow, Literal(5)), - ("between 10 preceding and 5 following", -Literal(10), Literal(5)) + ("between current row and 3 + 1 following", CurrentRow, Add(Literal(3), Literal(1))), + ("between current row and 2147483649 following", CurrentRow, Literal(2147483649L)), + ("between current row and unbounded following", CurrentRow, UnboundedFollowing), + ("between 2147483648 preceding and unbounded following", + -Literal(2147483648L), UnboundedFollowing), + ("between 10 preceding and unbounded following", + -Literal(10), UnboundedFollowing), + ("between 3 + 1 preceding and unbounded following", + -Add(Literal(3), Literal(1)), UnboundedFollowing), + ("between 0 preceding and unbounded following", -Literal(0), UnboundedFollowing), + + // Between partial and full range + ("between 10 preceding and 5 following", -Literal(10), Literal(5)), + ("between unbounded preceding and unbounded following", + UnboundedPreceding, UnboundedFollowing) ) frameTypes.foreach { case (frameTypeSql, frameType) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala new file mode 100644 index 0000000000000..0ee9b0edc02b2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Window frame testing for DataFrame API. + */ +class DataFrameWindowFramesSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("lead/lag with empty data frame") { + val df = Seq.empty[(Int, String)].toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + lead("value", 1).over(window), + lag("value", 1).over(window)), + Nil) + } + + test("lead/lag with positive offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", 1).over(window), + lag("value", 1).over(window)), + Row(1, "3", null) :: Row(1, null, "1") :: Row(2, "4", null) :: Row(2, null, "2") :: Nil) + } + + test("reverse lead/lag with positive offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value".desc) + + checkAnswer( + df.select( + $"key", + lead("value", 1).over(window), + lag("value", 1).over(window)), + Row(1, "1", null) :: Row(1, null, "3") :: Row(2, "2", null) :: Row(2, null, "4") :: Nil) + } + + test("lead/lag with negative offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", -1).over(window), + lag("value", -1).over(window)), + Row(1, null, "3") :: Row(1, "1", null) :: Row(2, null, "4") :: Row(2, "2", null) :: Nil) + } + + test("reverse lead/lag with negative offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value".desc) + + checkAnswer( + df.select( + $"key", + lead("value", -1).over(window), + lag("value", -1).over(window)), + Row(1, null, "1") :: Row(1, "3", null) :: Row(2, null, "2") :: Row(2, "4", null) :: Nil) + } + + test("lead/lag with default value") { + val default = "n/a" + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4"), (2, "5")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", 2, default).over(window), + lag("value", 2, default).over(window), + lead("value", -2, default).over(window), + lag("value", -2, default).over(window)), + Row(1, default, default, default, default) :: Row(1, default, default, default, default) :: + Row(2, "5", default, default, "5") :: Row(2, default, "2", "2", default) :: + Row(2, default, default, default, default) :: Nil) + } + + test("rows/range between with empty data frame") { + val df = Seq.empty[(String, Int)].toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + 'key, + first("value").over( + window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + first("value").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Nil) + } + + test("rows between should accept int/long values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), + Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + + val e = intercept[AnalysisException]( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) + assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) + } + + test("range between should accept at most one ORDER BY expression when unbounded") { + val df = Seq((1, 1)).toDF("key", "value") + val window = Window.orderBy($"key", $"value") + + checkAnswer( + df.select( + $"key", + min("key").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Seq(Row(1, 1)) + ) + + val e1 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(Window.unboundedPreceding, 1)))) + assert(e1.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + + val e2 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(-1, Window.unboundedFollowing)))) + assert(e2.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + + val e3 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(-1, 1)))) + assert(e3.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + } + + test("range between should accept numeric values only when bounded") { + val df = Seq("non_numeric").toDF("value") + val window = Window.orderBy($"value") + + checkAnswer( + df.select( + $"value", + min("value").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Row("non_numeric", "non_numeric") :: Nil) + + val e1 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(Window.unboundedPreceding, 1)))) + assert(e1.message.contains("The data type of the upper bound 'string' " + + "does not match the expected data type")) + + val e2 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(-1, Window.unboundedFollowing)))) + assert(e2.message.contains("The data type of the lower bound 'string' " + + "does not match the expected data type")) + + val e3 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(-1, 1)))) + assert(e3.message.contains("The data type of the lower bound 'string' " + + "does not match the expected data type")) + } + + test("range between should accept int/long values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), + Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), + Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) + ) + + def dt(date: String): Date = Date.valueOf(date) + + val df2 = Seq((dt("2017-08-01"), "1"), (dt("2017-08-01"), "1"), (dt("2020-12-31"), "1"), + (dt("2017-08-03"), "2"), (dt("2017-08-02"), "1"), (dt("2020-12-31"), "2")) + .toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(lit(0), lit(2)) + + checkAnswer( + df2.select( + $"key", + count("key").over(window)), + Seq(Row(dt("2017-08-01"), 3), Row(dt("2017-08-01"), 3), Row(dt("2020-12-31"), 1), + Row(dt("2017-08-03"), 1), Row(dt("2017-08-02"), 1), Row(dt("2020-12-31"), 1)) + ) + } + + test("range between should accept double values as boundary") { + val df = Seq((1.0D, "1"), (1.0D, "1"), (100.001D, "1"), (3.3D, "2"), (2.02D, "1"), + (100.001D, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(currentRow, lit(2.5D)) + + checkAnswer( + df.select( + $"key", + count("key").over(window)), + Seq(Row(1.0, 3), Row(1.0, 3), Row(100.001, 1), Row(3.3, 1), Row(2.02, 1), Row(100.001, 1)) + ) + } + + test("range between should accept interval values as boundary") { + def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000) + + val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"), + (ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2")) + .toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key") + .rangeBetween(currentRow, lit(CalendarInterval.fromString("interval 23 days 4 hours"))) + + checkAnswer( + df.select( + $"key", + count("key").over(window)), + Seq(Row(ts(1501545600), 3), Row(ts(1501545600), 3), Row(ts(1609372800), 1), + Row(ts(1503000000), 1), Row(ts(1502000000), 1), Row(ts(1609372800), 1)) + ) + } + + test("unbounded rows/range between with aggregation") { + val df = Seq(("one", 1), ("two", 2), ("one", 3), ("two", 4)).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + 'key, + sum("value").over(window. + rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + sum("value").over(window. + rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Row("one", 4, 4) :: Row("one", 4, 4) :: Row("two", 6, 6) :: Row("two", 6, 6) :: Nil) + } + + test("unbounded preceding/following rows between with aggregation") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key") + + checkAnswer( + df.select( + $"key", + last("key").over( + window.rowsBetween(Window.currentRow, Window.unboundedFollowing)), + last("key").over( + window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), + Row(1, 1, 1) :: Row(2, 3, 2) :: Row(3, 3, 3) :: Row(1, 4, 1) :: Row(2, 4, 2) :: + Row(4, 4, 4) :: Nil) + } + + test("reverse unbounded preceding/following rows between with aggregation") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key".desc) + + checkAnswer( + df.select( + $"key", + last("key").over( + window.rowsBetween(Window.currentRow, Window.unboundedFollowing)), + last("key").over( + window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), + Row(1, 1, 1) :: Row(3, 2, 3) :: Row(2, 2, 2) :: Row(4, 1, 4) :: Row(2, 1, 2) :: + Row(1, 1, 1) :: Nil) + } + + test("unbounded preceding/following range between with aggregation") { + val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy("value").orderBy("key") + + checkAnswer( + df.select( + $"key", + avg("key").over(window.rangeBetween(Window.unboundedPreceding, 1)) + .as("avg_key1"), + avg("key").over(window.rangeBetween(Window.currentRow, Window.unboundedFollowing)) + .as("avg_key2")), + Row(3, 3.0d, 4.0d) :: Row(5, 4.0d, 5.0d) :: Row(2, 2.0d, 17.0d / 4.0d) :: + Row(4, 11.0d / 3.0d, 5.0d) :: Row(5, 17.0d / 4.0d, 11.0d / 2.0d) :: + Row(6, 17.0d / 4.0d, 6.0d) :: Nil) + } + + // This is here to illustrate the fact that reverse order also reverses offsets. + test("reverse preceding/following range between with aggregation") { + val df = Seq(1, 2, 4, 3, 2, 1).toDF("value") + val window = Window.orderBy($"value".desc) + + checkAnswer( + df.select( + $"value", + sum($"value").over(window.rangeBetween(Window.unboundedPreceding, 1)), + sum($"value").over(window.rangeBetween(1, Window.unboundedFollowing))), + Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: Row(3, 11, 6) :: + Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + } + + test("sliding rows between with aggregation") { + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 4.0d / 3.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 3.0d / 2.0d) :: Row(2, 2.0d) :: + Row(2, 2.0d) :: Nil) + } + + test("reverse sliding rows between with aggregation") { + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key".desc).rowsBetween(-1, 2) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 1.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 4.0d / 3.0d) :: Row(2, 2.0d) :: + Row(2, 2.0d) :: Nil) + } + + test("sliding range between with aggregation") { + val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 4.0d / 3.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 7.0d / 4.0d) :: Row(3, 5.0d / 2.0d) :: + Row(2, 2.0d) :: Row(2, 2.0d) :: Nil) + } + + test("reverse sliding range between with aggregation") { + val df = Seq( + (1, "Thin", "Cell Phone", 6000), + (2, "Normal", "Tablet", 1500), + (3, "Mini", "Tablet", 5500), + (4, "Ultra thin", "Cell Phone", 5500), + (5, "Very thin", "Cell Phone", 6000), + (6, "Big", "Tablet", 2500), + (7, "Bendable", "Cell Phone", 3000), + (8, "Foldable", "Cell Phone", 3000), + (9, "Pro", "Tablet", 4500), + (10, "Pro2", "Tablet", 6500)). + toDF("id", "product", "category", "revenue") + val window = Window.partitionBy($"category").orderBy($"revenue".desc). + rangeBetween(-2000L, 1000L) + + checkAnswer( + df.select( + $"id", + avg($"revenue").over(window).cast("int")), + Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: + Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: + Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: + Row(10, 6000) :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 01c988ecc3726..281147835abde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -55,56 +55,6 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) } - test("Window.rowsBetween") { - val df = Seq(("one", 1), ("two", 2)).toDF("key", "value") - // Running (cumulative) sum - checkAnswer( - df.select('key, sum("value").over( - Window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), - Row("one", 1) :: Row("two", 3) :: Nil - ) - } - - test("lead") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - - checkAnswer( - df.select( - lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - Row("1") :: Row(null) :: Row("2") :: Row(null) :: Nil) - } - - test("lag") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - - checkAnswer( - df.select( - lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - Row(null) :: Row("1") :: Row(null) :: Row("2") :: Nil) - } - - test("lead with default value") { - val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), - (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), - Seq(Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"), Row("n/a"), Row("n/a"))) - } - - test("lag with default value") { - val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), - (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), - Seq(Row("n/a"), Row("n/a"), Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"))) - } - test("rank functions in unspecific window") { val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") df.createOrReplaceTempView("window_table") @@ -136,199 +86,6 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("requires window to be ordered")) } - test("aggregation and rows between") { - val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), - Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(3.0d / 2.0d), Row(2.0d), Row(2.0d))) - } - - test("aggregation and range between") { - val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), - Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(7.0d / 4.0d), Row(5.0d / 2.0d), - Row(2.0d), Row(2.0d))) - } - - test("row between should accept integer values as boundary") { - val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), - (3L, "2"), (2L, "1"), (2147483650L, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), - Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) - ) - - val e = intercept[AnalysisException]( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) - assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) - } - - test("range between should accept int/long values as boundary") { - val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), - (3L, "2"), (2L, "1"), (2147483650L, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), - Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) - ) - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), - Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) - ) - - def dt(date: String): Date = Date.valueOf(date) - - val df2 = Seq((dt("2017-08-01"), "1"), (dt("2017-08-01"), "1"), (dt("2020-12-31"), "1"), - (dt("2017-08-03"), "2"), (dt("2017-08-02"), "1"), (dt("2020-12-31"), "2")) - .toDF("key", "value") - checkAnswer( - df2.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(lit(0), lit(2)))), - Seq(Row(dt("2017-08-01"), 3), Row(dt("2017-08-01"), 3), Row(dt("2020-12-31"), 1), - Row(dt("2017-08-03"), 1), Row(dt("2017-08-02"), 1), Row(dt("2020-12-31"), 1)) - ) - } - - test("range between should accept double values as boundary") { - val df = Seq((1.0D, "1"), (1.0D, "1"), (100.001D, "1"), - (3.3D, "2"), (2.02D, "1"), (100.001D, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key") - .rangeBetween(currentRow, lit(2.5D)))), - Seq(Row(1.0, 3), Row(1.0, 3), Row(100.001, 1), Row(3.3, 1), Row(2.02, 1), Row(100.001, 1)) - ) - } - - test("range between should accept interval values as boundary") { - def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000) - - val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"), - (ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key") - .rangeBetween(currentRow, - lit(CalendarInterval.fromString("interval 23 days 4 hours"))))), - Seq(Row(ts(1501545600), 3), Row(ts(1501545600), 3), Row(ts(1609372800), 1), - Row(ts(1503000000), 1), Row(ts(1502000000), 1), Row(ts(1609372800), 1)) - ) - } - - test("aggregation and rows between with unbounded") { - val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - last("key").over( - Window.partitionBy($"value").orderBy($"key") - .rowsBetween(Window.currentRow, Window.unboundedFollowing)), - last("key").over( - Window.partitionBy($"value").orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.currentRow)), - last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), - Seq(Row(1, 1, 1, 1), Row(2, 3, 2, 3), Row(3, 3, 3, 3), Row(1, 4, 1, 2), Row(2, 4, 2, 4), - Row(4, 4, 4, 4))) - } - - test("aggregation and range between with unbounded") { - val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - last("value").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)) - .equalTo("2") - .as("last_v"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) - .as("avg_key1"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue)) - .as("avg_key2"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) - .as("avg_key3") - ), - Seq(Row(3, null, 3.0d, 4.0d, 3.0d), - Row(5, false, 4.0d, 5.0d, 5.0d), - Row(2, null, 2.0d, 17.0d / 4.0d, 2.0d), - Row(4, true, 11.0d / 3.0d, 5.0d, 4.0d), - Row(5, true, 17.0d / 4.0d, 11.0d / 2.0d, 4.5d), - Row(6, true, 17.0d / 4.0d, 6.0d, 11.0d / 2.0d))) - } - - test("reverse sliding range frame") { - val df = Seq( - (1, "Thin", "Cell Phone", 6000), - (2, "Normal", "Tablet", 1500), - (3, "Mini", "Tablet", 5500), - (4, "Ultra thin", "Cell Phone", 5500), - (5, "Very thin", "Cell Phone", 6000), - (6, "Big", "Tablet", 2500), - (7, "Bendable", "Cell Phone", 3000), - (8, "Foldable", "Cell Phone", 3000), - (9, "Pro", "Tablet", 4500), - (10, "Pro2", "Tablet", 6500)). - toDF("id", "product", "category", "revenue") - val window = Window. - partitionBy($"category"). - orderBy($"revenue".desc). - rangeBetween(-2000L, 1000L) - checkAnswer( - df.select( - $"id", - avg($"revenue").over(window).cast("int")), - Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: - Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: - Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: - Row(10, 6000) :: Nil) - } - - // This is here to illustrate the fact that reverse order also reverses offsets. - test("reverse unbounded range frame") { - val df = Seq(1, 2, 4, 3, 2, 1). - map(Tuple1.apply). - toDF("value") - val window = Window.orderBy($"value".desc) - checkAnswer( - df.select( - $"value", - sum($"value").over(window.rangeBetween(Long.MinValue, 1)), - sum($"value").over(window.rangeBetween(1, Long.MaxValue))), - Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: - Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) - } - test("statistical functions") { val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). toDF("key", "value") From 16670578519a7b787b0c63888b7d2873af12d5b9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 16 Jan 2018 18:11:27 -0800 Subject: [PATCH 111/774] [SPARK-22908][SS] Roll forward continuous processing Kafka support with fix to continuous Kafka data reader ## What changes were proposed in this pull request? The Kafka reader is now interruptible and can close itself. ## How was this patch tested? I locally ran one of the ContinuousKafkaSourceSuite tests in a tight loop. Before the fix, my machine ran out of open file descriptors a few iterations in; now it works fine. Author: Jose Torres Closes #20253 from jose-torres/fix-data-reader. --- .../sql/kafka010/KafkaContinuousReader.scala | 260 +++++++++ .../sql/kafka010/KafkaContinuousWriter.scala | 119 ++++ .../sql/kafka010/KafkaOffsetReader.scala | 21 +- .../spark/sql/kafka010/KafkaSource.scala | 17 +- .../sql/kafka010/KafkaSourceOffset.scala | 7 +- .../sql/kafka010/KafkaSourceProvider.scala | 105 +++- .../spark/sql/kafka010/KafkaWriteTask.scala | 71 ++- .../spark/sql/kafka010/KafkaWriter.scala | 5 +- .../kafka010/KafkaContinuousSinkSuite.scala | 476 ++++++++++++++++ .../kafka010/KafkaContinuousSourceSuite.scala | 96 ++++ .../sql/kafka010/KafkaContinuousTest.scala | 94 +++ .../spark/sql/kafka010/KafkaSourceSuite.scala | 539 +++++++++--------- .../apache/spark/sql/DataFrameReader.scala | 32 +- .../apache/spark/sql/DataFrameWriter.scala | 25 +- .../datasources/v2/WriteToDataSourceV2.scala | 8 +- .../execution/streaming/StreamExecution.scala | 15 +- .../ContinuousDataSourceRDDIter.scala | 4 +- .../continuous/ContinuousExecution.scala | 67 ++- .../continuous/EpochCoordinator.scala | 21 +- .../sql/streaming/DataStreamWriter.scala | 26 +- .../spark/sql/streaming/StreamTest.scala | 36 +- 21 files changed, 1628 insertions(+), 416 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala new file mode 100644 index 0000000000000..fc977977504f7 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.util.concurrent.TimeoutException + +import org.apache.kafka.clients.consumer.{ConsumerRecord, OffsetOutOfRangeException} +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +/** + * A [[ContinuousReader]] for data from kafka. + * + * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be + * read by per-task consumers generated later. + * @param kafkaParams String params for per-task Kafka consumers. + * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which + * are not Kafka consumer params. + * @param metadataPath Path to a directory this reader can use for writing metadata. + * @param initialOffsets The Kafka offsets to start reading data at. + * @param failOnDataLoss Flag indicating whether reading should fail in data loss + * scenarios, where some offsets after the specified initial ones can't be + * properly read. + */ +class KafkaContinuousReader( + offsetReader: KafkaOffsetReader, + kafkaParams: ju.Map[String, Object], + sourceOptions: Map[String, String], + metadataPath: String, + initialOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) + extends ContinuousReader with SupportsScanUnsafeRow with Logging { + + private lazy val session = SparkSession.getActiveSession.get + private lazy val sc = session.sparkContext + + private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong + + // Initialized when creating read tasks. If this diverges from the partitions at the latest + // offsets, we need to reconfigure. + // Exposed outside this object only for unit tests. + private[sql] var knownPartitions: Set[TopicPartition] = _ + + override def readSchema: StructType = KafkaOffsetReader.kafkaSchema + + private var offset: Offset = _ + override def setOffset(start: ju.Optional[Offset]): Unit = { + offset = start.orElse { + val offsets = initialOffsets match { + case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + logInfo(s"Initial offsets: $offsets") + offsets + } + } + + override def getStartOffset(): Offset = offset + + override def deserializeOffset(json: String): Offset = { + KafkaSourceOffset(JsonUtils.partitionOffsets(json)) + } + + override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { + import scala.collection.JavaConverters._ + + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) + + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) + + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + knownPartitions = startOffsets.keySet + + startOffsets.toSeq.map { + case (topicPartition, start) => + KafkaContinuousReadTask( + topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) + .asInstanceOf[ReadTask[UnsafeRow]] + }.asJava + } + + /** Stop this source and free any resources it has allocated. */ + def stop(): Unit = synchronized { + offsetReader.close() + } + + override def commit(end: Offset): Unit = {} + + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { + val mergedMap = offsets.map { + case KafkaSourcePartitionOffset(p, o) => Map(p -> o) + }.reduce(_ ++ _) + KafkaSourceOffset(mergedMap) + } + + override def needsReconfiguration(): Boolean = { + knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions + } + + override def toString(): String = s"KafkaSource[$offsetReader]" + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") + } else { + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") + } + } +} + +/** + * A read task for continuous Kafka processing. This will be serialized and transformed into a + * full reader on executors. + * + * @param topicPartition The (topic, partition) pair this task is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param pollTimeoutMs The timeout for Kafka consumer polling. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +case class KafkaContinuousReadTask( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { + override def createDataReader(): KafkaContinuousDataReader = { + new KafkaContinuousDataReader( + topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + } +} + +/** + * A per-task data reader for continuous Kafka processing. + * + * @param topicPartition The (topic, partition) pair this data reader is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param pollTimeoutMs The timeout for Kafka consumer polling. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +class KafkaContinuousDataReader( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { + private val topic = topicPartition.topic + private val kafkaPartition = topicPartition.partition + private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) + + private val sharedRow = new UnsafeRow(7) + private val bufferHolder = new BufferHolder(sharedRow) + private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + + private var nextKafkaOffset = startOffset + private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ + + override def next(): Boolean = { + var r: ConsumerRecord[Array[Byte], Array[Byte]] = null + while (r == null) { + if (TaskContext.get().isInterrupted() || TaskContext.get().isCompleted()) return false + // Our consumer.get is not interruptible, so we have to set a low poll timeout, leaving + // interrupt points to end the query rather than waiting for new data that might never come. + try { + r = consumer.get( + nextKafkaOffset, + untilOffset = Long.MaxValue, + pollTimeoutMs, + failOnDataLoss) + } catch { + // We didn't read within the timeout. We're supposed to block indefinitely for new data, so + // swallow and ignore this. + case _: TimeoutException => + + // This is a failOnDataLoss exception. Retry if nextKafkaOffset is within the data range, + // or if it's the endpoint of the data range (i.e. the "true" next offset). + case e: IllegalStateException if e.getCause.isInstanceOf[OffsetOutOfRangeException] => + val range = consumer.getAvailableOffsetRange() + if (range.latest >= nextKafkaOffset && range.earliest <= nextKafkaOffset) { + // retry + } else { + throw e + } + } + } + nextKafkaOffset = r.offset + 1 + currentRecord = r + true + } + + override def get(): UnsafeRow = { + bufferHolder.reset() + + if (currentRecord.key == null) { + rowWriter.setNullAt(0) + } else { + rowWriter.write(0, currentRecord.key) + } + rowWriter.write(1, currentRecord.value) + rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) + rowWriter.write(3, currentRecord.partition) + rowWriter.write(4, currentRecord.offset) + rowWriter.write(5, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) + rowWriter.write(6, currentRecord.timestampType.id) + sharedRow.setTotalSize(bufferHolder.totalSize) + sharedRow + } + + override def getOffset(): KafkaSourcePartitionOffset = { + KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset) + } + + override def close(): Unit = { + consumer.close() + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala new file mode 100644 index 0000000000000..9843f469c5b25 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} +import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{BinaryType, StringType, StructType} + +/** + * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we + * don't need to really send one. + */ +case object KafkaWriterCommitMessage extends WriterCommitMessage + +/** + * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. + * @param topic The topic this writer is responsible for. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +class KafkaContinuousWriter( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends ContinuousWriter with SupportsWriteInternalRow { + + validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) + + override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = + KafkaContinuousWriterFactory(topic, producerParams, schema) + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(messages: Array[WriterCommitMessage]): Unit = {} +} + +/** + * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate + * the per-task data writers. + * @param topic The topic that should be written to. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +case class KafkaContinuousWriterFactory( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends DataWriterFactory[InternalRow] { + + override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) + } +} + +/** + * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to + * process incoming rows. + * + * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred + * from a `topic` field in the incoming data. + * @param producerParams Parameters to use for the Kafka producer. + * @param inputSchema The attributes in the input data. + */ +class KafkaContinuousDataWriter( + targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) + extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { + import scala.collection.JavaConverters._ + + private lazy val producer = CachedKafkaProducer.getOrCreate( + new java.util.HashMap[String, Object](producerParams.asJava)) + + def write(row: InternalRow): Unit = { + checkForErrors() + sendRow(row, producer) + } + + def commit(): WriterCommitMessage = { + // Send is asynchronous, but we can't commit until all rows are actually in Kafka. + // This requires flushing and then checking that no callbacks produced errors. + // We also check for errors before to fail as soon as possible - the check is cheap. + checkForErrors() + producer.flush() + checkForErrors() + KafkaWriterCommitMessage + } + + def abort(): Unit = {} + + def close(): Unit = { + checkForErrors() + if (producer != null) { + producer.flush() + checkForErrors() + CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava)) + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 3e65949a6fd1b..551641cfdbca8 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -117,10 +117,14 @@ private[kafka010] class KafkaOffsetReader( * Resolves the specific offsets based on Kafka seek positions. * This method resolves offset value -1 to the latest and -2 to the * earliest Kafka seek position. + * + * @param partitionOffsets the specific offsets to resolve + * @param reportDataLoss callback to either report or log data loss depending on setting */ def fetchSpecificOffsets( - partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = - runUninterruptibly { + partitionOffsets: Map[TopicPartition, Long], + reportDataLoss: String => Unit): KafkaSourceOffset = { + val fetched = runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) @@ -145,6 +149,19 @@ private[kafka010] class KafkaOffsetReader( } } + partitionOffsets.foreach { + case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && + off != KafkaOffsetRangeLimit.EARLIEST => + if (fetched(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + KafkaSourceOffset(fetched) + } + /** * Fetch the earliest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 864a92b8f813f..169a5d006fb04 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -130,7 +130,7 @@ private[kafka010] class KafkaSource( val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => fetchAndVerify(p) + case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") @@ -138,21 +138,6 @@ private[kafka010] class KafkaSource( }.partitionToOffsets } - private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = { - val result = kafkaReader.fetchSpecificOffsets(specificOffsets) - specificOffsets.foreach { - case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && - off != KafkaOffsetRangeLimit.EARLIEST => - if (result(tp) != off) { - reportDataLoss( - s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") - } - case _ => - // no real way to check that beginning or end is reasonable - } - KafkaSourceOffset(result) - } - private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index b5da415b3097e..c82154cfbad7f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,17 +20,22 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} +import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. */ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } +private[kafka010] +case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long) + extends PartitionOffset + /** Companion object of the [[KafkaSourceOffset]] */ private[kafka010] object KafkaSourceOffset { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3cb4d8cad12cc..3914370a96595 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, UUID} +import java.util.{Locale, Optional, UUID} import scala.collection.JavaConverters._ @@ -27,9 +27,12 @@ import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.streaming.{Sink, Source} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -43,6 +46,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider + with ContinuousWriteSupport + with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -101,6 +106,43 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } + override def createContinuousReader( + schema: Optional[StructType], + metadataPath: String, + options: DataSourceV2Options): KafkaContinuousReader = { + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaContinuousReader( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + parameters, + metadataPath, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + /** * Returns a new base relation with the given parameters. * @@ -181,26 +223,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " - + "are serialized with ByteArraySerializer.") - } + override def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + import scala.collection.JavaConverters._ - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) - { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are serialized with ByteArraySerializer.") - } - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + val spark = SparkSession.getActiveSession.get + val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) + // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. + val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) + + KafkaWriter.validateQuery( + schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) + + Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) } private def strategy(caseInsensitiveParams: Map[String, String]) = @@ -450,4 +488,27 @@ private[kafka010] object KafkaSourceProvider extends Logging { def build(): ju.Map[String, Object] = map } + + private[kafka010] def kafkaParamsForProducer( + parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 6fd333e2f43ba..baa60febf661d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -33,10 +33,8 @@ import org.apache.spark.sql.types.{BinaryType, StringType} private[kafka010] class KafkaWriteTask( producerConfiguration: ju.Map[String, Object], inputSchema: Seq[Attribute], - topic: Option[String]) { + topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) { // used to synchronize with Kafka callbacks - @volatile private var failedWrite: Exception = null - private val projection = createProjection private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ /** @@ -46,23 +44,7 @@ private[kafka010] class KafkaWriteTask( producer = CachedKafkaProducer.getOrCreate(producerConfiguration) while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() - val projectedRow = projection(currentRow) - val topic = projectedRow.getUTF8String(0) - val key = projectedRow.getBinary(1) - val value = projectedRow.getBinary(2) - if (topic == null) { - throw new NullPointerException(s"null topic present in the data. Use the " + - s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") - } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) - val callback = new Callback() { - override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { - if (failedWrite == null && e != null) { - failedWrite = e - } - } - } - producer.send(record, callback) + sendRow(currentRow, producer) } } @@ -74,8 +56,49 @@ private[kafka010] class KafkaWriteTask( producer = null } } +} + +private[kafka010] abstract class KafkaRowWriter( + inputSchema: Seq[Attribute], topic: Option[String]) { + + // used to synchronize with Kafka callbacks + @volatile protected var failedWrite: Exception = _ + protected val projection = createProjection + + private val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } - private def createProjection: UnsafeProjection = { + /** + * Send the specified row to the producer, with a callback that will save any exception + * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before + * assuming the row is in Kafka. + */ + protected def sendRow( + row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = { + val projectedRow = projection(row) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + producer.send(record, callback) + } + + protected def checkForErrors(): Unit = { + if (failedWrite != null) { + throw failedWrite + } + } + + private def createProjection = { val topicExpression = topic.map(Literal(_)).orElse { inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) }.getOrElse { @@ -112,11 +135,5 @@ private[kafka010] class KafkaWriteTask( Seq(topicExpression, Cast(keyExpression, BinaryType), Cast(valueExpression, BinaryType)), inputSchema) } - - private def checkForErrors(): Unit = { - if (failedWrite != null) { - throw failedWrite - } - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 5e9ae35b3f008..15cd44812cb0c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -43,10 +43,9 @@ private[kafka010] object KafkaWriter extends Logging { override def toString: String = "KafkaWriter" def validateQuery( - queryExecution: QueryExecution, + schema: Seq[Attribute], kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic.isEmpty) { throw new AnalysisException(s"topic option required when no " + @@ -84,7 +83,7 @@ private[kafka010] object KafkaWriter extends Logging { kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output - validateQuery(queryExecution, kafkaParameters, topic) + validateQuery(schema, kafkaParameters, topic) queryExecution.toRdd.foreachPartition { iter => val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) Utils.tryWithSafeFinally(block = writeTask.execute(iter))( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala new file mode 100644 index 0000000000000..8487a69851237 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.scalatest.time.SpanSugar._ +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.types.{BinaryType, DataType} +import org.apache.spark.util.Utils + +/** + * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream. + * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have + * to duplicate all the code. + */ +class KafkaContinuousSinkSuite extends KafkaContinuousTest { + import testImplicits._ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils( + withBrokerProps = Map("auto.create.topics.enable" -> "false")) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + test("streaming - write to kafka with topic field") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = None, + withOutputMode = Some(OutputMode.Append))( + withSelectExpr = s"'$topic' as topic", "value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("streaming - write w/o topic field, with topic option") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))() + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("streaming - topic field and topic option") { + /* The purpose of this test is to ensure that the topic option + * overrides the topic field. We begin by writing some data that + * includes a topic field and value (e.g., 'foo') along with a topic + * option. Then when we read from the topic specified in the option + * we should see the data i.e., the data was written to the topic + * option, and not to the topic in the data e.g., foo + */ + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))( + withSelectExpr = "'foo' as topic", "CAST(value as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("null topic attribute") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "CAST(null as STRING) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getCause.getCause.getMessage + .toLowerCase(Locale.ROOT) + .contains("null topic present in the data.")) + } + + test("streaming - write data with bad schema") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "value as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage + .toLowerCase(Locale.ROOT) + .contains("topic option required when no 'topic' attribute is present")) + + try { + /* No value field */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value as key" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "required attribute 'value' not found")) + } + + test("streaming - write data with valid schema but wrong types") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value as STRING) value") + val topic = newTopic() + testUtils.createTopic(topic) + + var writer: StreamingQuery = null + var ex: Exception = null + try { + /* topic field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"CAST('1' as INT) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) + + try { + /* value field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "value attribute type must be a string or binarytype")) + + try { + /* key field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "key attribute type must be a string or binarytype")) + } + + test("streaming - write to non-existing topic") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + } + throw writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + } + + test("streaming - exception on config serializer") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + testUtils.sendMessages(inputTopic, Array("0")) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .load() + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.key.serializer" -> "foo"))() + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'key.serializer' is not supported")) + } finally { + writer.stop() + } + + try { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.value.serializer" -> "foo"))() + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'value.serializer' is not supported")) + } finally { + writer.stop() + } + } + + test("generic - write big data with small producer buffer") { + /* This test ensures that we understand the semantics of Kafka when + * is comes to blocking on a call to send when the send buffer is full. + * This test will configure the smallest possible producer buffer and + * indicate that we should block when it is full. Thus, no exception should + * be thrown in the case of a full buffer. + */ + val topic = newTopic() + testUtils.createTopic(topic, 1) + val options = new java.util.HashMap[String, String] + options.put("bootstrap.servers", testUtils.brokerAddress) + options.put("buffer.memory", "16384") // min buffer size + options.put("block.on.buffer.full", "true") + options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + val inputSchema = Seq(AttributeReference("value", BinaryType)()) + val data = new Array[Byte](15000) // large value + val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) + try { + val fieldTypes: Array[DataType] = Array(BinaryType) + val converter = UnsafeProjection.create(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) + row.update(0, data) + val iter = Seq.fill(1000)(converter.apply(row)).iterator + iter.foreach(writeTask.write(_)) + writeTask.commit() + } finally { + writeTask.close() + } + } + + private def createKafkaReader(topic: String): DataFrame = { + spark.read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("startingOffsets", "earliest") + .option("endingOffsets", "latest") + .option("subscribe", topic) + .load() + } + + private def createKafkaWriter( + input: DataFrame, + withTopic: Option[String] = None, + withOutputMode: Option[OutputMode] = None, + withOptions: Map[String, String] = Map[String, String]()) + (withSelectExpr: String*): StreamingQuery = { + var stream: DataStreamWriter[Row] = null + val checkpointDir = Utils.createTempDir() + var df = input.toDF() + if (withSelectExpr.length > 0) { + df = df.selectExpr(withSelectExpr: _*) + } + stream = df.writeStream + .format("kafka") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + // We need to reduce blocking time to efficiently test non-existent partition behavior. + .option("kafka.max.block.ms", "1000") + .trigger(Trigger.Continuous(1000)) + .queryName("kafkaStream") + withTopic.foreach(stream.option("topic", _)) + withOutputMode.foreach(stream.outputMode(_)) + withOptions.foreach(opt => stream.option(opt._1, opt._2)) + stream.start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala new file mode 100644 index 0000000000000..b3dade414f625 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.time.SpanSugar._ +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} + +// Run tests in KafkaSourceSuiteBase in continuous execution mode. +class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest + +class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { + import testImplicits._ + + override val brokerProps = Map("auto.create.topics.enable" -> "false") + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Execute { query => + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists { r => + // Ensure the new topic is present and the old topic is gone. + r.knownPartitions.exists(_.topic == topic2) + }, + s"query never reconfigured to new topic $topic2") + } + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } +} + +class KafkaContinuousSourceStressForDontFailOnDataLossSuite + extends KafkaSourceStressForDontFailOnDataLossSuite { + override protected def startStream(ds: Dataset[Int]) = { + ds.writeStream + .format("memory") + .queryName("memory") + .start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala new file mode 100644 index 0000000000000..5a1a14f7a307a --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.SparkContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.test.TestSparkSession + +// Trait to configure StreamTest for kafka continuous execution tests. +trait KafkaContinuousTest extends KafkaSourceTest { + override val defaultTrigger = Trigger.Continuous(1000) + override val defaultUseV2Sink = true + + // We need more than the default local[2] to be able to schedule all partitions simultaneously. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true"))) + + // In addition to setting the partitions in Kafka, we have to wait until the query has + // reconfigured to the new count so the test framework can hook in properly. + override protected def setTopicPartitions( + topic: String, newCount: Int, query: StreamExecution) = { + testUtils.addPartitions(topic, newCount) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists(_.knownPartitions.size == newCount), + s"query never reconfigured to $newCount partitions") + } + } + + // Continuous processing tasks end asynchronously, so test that they actually end. + private val tasksEndedListener = new SparkListener() { + val activeTaskIdCount = new AtomicInteger(0) + + override def onTaskStart(start: SparkListenerTaskStart): Unit = { + activeTaskIdCount.incrementAndGet() + } + + override def onTaskEnd(end: SparkListenerTaskEnd): Unit = { + activeTaskIdCount.decrementAndGet() + } + } + + override def beforeEach(): Unit = { + super.beforeEach() + spark.sparkContext.addSparkListener(tasksEndedListener) + } + + override def afterEach(): Unit = { + eventually(timeout(streamingTimeout)) { + assert(tasksEndedListener.activeTaskIdCount.get() == 0) + } + spark.sparkContext.removeSparkListener(tasksEndedListener) + super.afterEach() + } + + + test("ensure continuous stream is being used") { + val query = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "1") + .load() + + testStream(query)( + Execute(q => assert(q.isInstanceOf[ContinuousExecution])) + ) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index a0f5695fc485c..1acff61e11d2a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -34,11 +34,14 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils @@ -49,9 +52,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override val streamingTimeout = 30.seconds + protected val brokerProps = Map[String, Object]() + override def beforeAll(): Unit = { super.beforeAll() - testUtils = new KafkaTestUtils + testUtils = new KafkaTestUtils(brokerProps) testUtils.setup() } @@ -59,18 +64,25 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { if (testUtils != null) { testUtils.teardown() testUtils = null - super.afterAll() } + super.afterAll() } protected def makeSureGetOffsetCalled = AssertOnQuery { q => // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure - // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, + // its "getOffset" is called before pushing any data. Otherwise, because of the race condition, // we don't know which data should be fetched when `startingOffsets` is latest. - q.processAllAvailable() + q match { + case c: ContinuousExecution => c.awaitEpoch(0) + case m: MicroBatchExecution => m.processAllAvailable() + } true } + protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = { + testUtils.addPartitions(topic, newCount) + } + /** * Add data to Kafka. * @@ -82,10 +94,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { - if (query.get.isActive) { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + query match { // Make sure no Spark job is running when deleting a topic - query.get.processAllAvailable() + case Some(m: MicroBatchExecution) => m.processAllAvailable() + case _ => } val existingTopics = testUtils.getAllTopicsAndPartitionSize().toMap @@ -97,16 +110,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) } - // Read all topics again in case some topics are delete. - val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys require( query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => - source.asInstanceOf[KafkaSource] - } + case StreamingExecutionRelation(source: KafkaSource, _) => source + } ++ (query.get.lastExecution match { + case null => Seq() + case e => e.logical.collect { + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + } + }) if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -137,14 +152,158 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override def toString: String = s"AddKafkaData(topics = $topics, data = $data, message = $message)" } -} + private val topicId = new AtomicInteger(0) + protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" +} -class KafkaSourceSuite extends KafkaSourceTest { +class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { import testImplicits._ - private val topicId = new AtomicInteger(0) + test("(de)serialization of initial offsets") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + testStream(reader.load)( + makeSureGetOffsetCalled, + StopStream, + StartStream(), + StopStream) + } + + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } + + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = ProcessingTime(1)), + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnQuery { query => + val recordsRead = query.recentProgress.map(_.numInputRows).sum + recordsRead == 3 + } + ) + } + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } testWithUninterruptibleThread( "deserialization of initial offset with Spark 2.1.0") { @@ -237,86 +396,94 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("(de)serialization of initial offsets") { + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() val topic = newTopic() - testUtils.createTopic(topic, partitions = 64) + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) - val reader = spark + val kafka = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("startingOffsets", s"earliest") .option("subscribe", topic) + .load() - testStream(reader.load)( - makeSureGetOffsetCalled, - StopStream, - StartStream(), - StopStream) + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") + + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + .queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. So here we just use a low bound to make sure the internal + // conversion works. + assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() } - test("maxOffsetsPerTrigger") { + test("delete a topic when a Spark job is running") { + KafkaSourceSuite.collectedData.clear() + val topic = newTopic() - testUtils.createTopic(topic, partitions = 3) - testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) - testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) - testUtils.sendMessages(topic, Array("1"), Some(2)) + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray) val reader = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("maxOffsetsPerTrigger", 10) .option("subscribe", topic) + // If a topic is deleted and we try to poll data starting from offset 0, + // the Kafka consumer will just block until timeout and return an empty result. + // So set the timeout to 1 second to make this test fast. + .option("kafkaConsumer.pollTimeoutMs", "1000") .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] - val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) - - val clock = new StreamManualClock - - val waitUntilBatchProcessed = AssertOnQuery { q => - eventually(Timeout(streamingTimeout)) { - if (!q.exception.isDefined) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } + KafkaSourceSuite.globalTestUtils = testUtils + // The following ForeachWriter will delete the topic before fetching data from Kafka + // in executors. + val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + override def open(partitionId: Long, version: Long): Boolean = { + KafkaSourceSuite.globalTestUtils.deleteTopic(topic) + true } - if (q.exception.isDefined) { - throw q.exception.get + + override def process(value: Int): Unit = { + KafkaSourceSuite.collectedData.add(value) } - true - } - testStream(mapped)( - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // 1 from smallest, 1 from middle, 8 from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 - ), - StopStream, - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 - ), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, - 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 - ) - ) + override def close(errorOrNull: Throwable): Unit = {} + }).start() + query.processAllAvailable() + query.stop() + // `failOnDataLoss` is `false`, we should not fail the query + assert(query.exception.isEmpty) } +} + +class KafkaSourceSuiteBase extends KafkaSourceTest { + + import testImplicits._ test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") { def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = { @@ -393,7 +560,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"topic-.*") + .option("subscribePattern", s"$topic.*") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -487,65 +654,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") - - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Assert { - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - true - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } - - test("starting offset is latest by default") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("0")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - val kafka = reader.load() - .selectExpr("CAST(value AS STRING)") - .as[String] - val mapped = kafka.map(_.toInt) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(1, 2, 3) // should not have 0 - ) - } - test("bad source options") { def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { @@ -605,77 +713,6 @@ class KafkaSourceSuite extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } - test("input row metrics") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val kafka = spark - .readStream - .format("kafka") - .option("subscribe", topic) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - - val mapped = kafka.map(kv => kv._2.toInt + 1) - testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - AssertOnQuery { query => - val recordsRead = query.recentProgress.map(_.numInputRows).sum - recordsRead == 3 - } - ) - } - - test("delete a topic when a Spark job is running") { - KafkaSourceSuite.collectedData.clear() - - val topic = newTopic() - testUtils.createTopic(topic, partitions = 1) - testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribe", topic) - // If a topic is deleted and we try to poll data starting from offset 0, - // the Kafka consumer will just block until timeout and return an empty result. - // So set the timeout to 1 second to make this test fast. - .option("kafkaConsumer.pollTimeoutMs", "1000") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - KafkaSourceSuite.globalTestUtils = testUtils - // The following ForeachWriter will delete the topic before fetching data from Kafka - // in executors. - val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { - override def open(partitionId: Long, version: Long): Boolean = { - KafkaSourceSuite.globalTestUtils.deleteTopic(topic) - true - } - - override def process(value: Int): Unit = { - KafkaSourceSuite.collectedData.add(value) - } - - override def close(errorOrNull: Throwable): Unit = {} - }).start() - query.processAllAvailable() - query.stop() - // `failOnDataLoss` is `false`, we should not fail the query - assert(query.exception.isEmpty) - } - test("get offsets from case insensitive parameters") { for ((optionKey, optionValue, answer) <- Seq( (STARTING_OFFSETS_OPTION_KEY, "earLiEst", EarliestOffsetRangeLimit), @@ -694,8 +731,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" - private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } @@ -741,6 +776,10 @@ class KafkaSourceSuite extends KafkaSourceTest { testStream(mapped)( makeSureGetOffsetCalled, + Execute { q => + // wait to reach the last offset in every partition + q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) + }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, StartStream(), @@ -771,10 +810,13 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("memory") .outputMode("append") .queryName("kafkaColumnTypes") + .trigger(defaultTrigger) .start() - query.processAllAvailable() - val rows = spark.table("kafkaColumnTypes").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + var rows: Array[Row] = Array() + eventually(timeout(streamingTimeout)) { + rows = spark.table("kafkaColumnTypes").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + } val row = rows(0) assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row") assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row") @@ -788,47 +830,6 @@ class KafkaSourceSuite extends KafkaSourceTest { query.stop() } - test("KafkaSource with watermark") { - val now = System.currentTimeMillis() - val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 1) - testUtils.sendMessages(topic, Array(1).map(_.toString)) - - val kafka = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("startingOffsets", s"earliest") - .option("subscribe", topic) - .load() - - val windowedAggregation = kafka - .withWatermark("timestamp", "10 seconds") - .groupBy(window($"timestamp", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start") as 'window, $"count") - - val query = windowedAggregation - .writeStream - .format("memory") - .outputMode("complete") - .queryName("kafkaWatermark") - .start() - query.processAllAvailable() - val rows = spark.table("kafkaWatermark").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") - val row = rows(0) - // We cannot check the exact window start time as it depands on the time that messages were - // inserted by the producer. So here we just use a low bound to make sure the internal - // conversion works. - assert( - row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, - s"Unexpected results: $row") - assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") - query.stop() - } - private def testFromLatestOffsets( topic: String, addPartitions: Boolean, @@ -865,9 +866,7 @@ class KafkaSourceSuite extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -908,9 +907,7 @@ class KafkaSourceSuite extends KafkaSourceTest { StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -1042,20 +1039,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared } } - test("stress test for failOnDataLoss=false") { - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", "failOnDataLoss.*") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - .option("fetchOffset.retryIntervalMs", "3000") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + protected def startStream(ds: Dataset[Int]) = { + ds.writeStream.foreach(new ForeachWriter[Int] { override def open(partitionId: Long, version: Long): Boolean = { true @@ -1069,6 +1054,22 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared override def close(errorOrNull: Throwable): Unit = { } }).start() + } + + test("stress test for failOnDataLoss=false") { + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", "failOnDataLoss.*") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + .option("fetchOffset.retryIntervalMs", "3000") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val query = startStream(kafka.map(kv => kv._2.toInt)) val testTime = 1.minutes val startTime = System.currentTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e8d683a578f35..b714a46b5f786 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -191,6 +191,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading + // the dataframe as a v1 source. val reader = (ds, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) @@ -208,23 +211,30 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } reader - case _ => - throw new AnalysisException(s"$cls does not support data reading.") + case _ => null // fall back to v1 } - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + if (reader == null) { + loadV1Source(paths: _*) + } else { + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + } } else { - // Code path for data source v1. - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) + loadV1Source(paths: _*) } } + private def loadV1Source(paths: String*) = { + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) + } + /** * Construct a `DataFrame` representing the database table accessible via JDBC URL * url named table and connection properties. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3304f368e1050..97f12ff625c42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,17 +255,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - case _ => throw new AnalysisException(s"$cls does not support data writing.") + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving + // as though it's a V1 source. + case _ => saveToV1Source() } } else { - // Code path for data source v1. - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) - } + saveToV1Source() + } + } + + private def saveToV1Source(): Unit = { + // Code path for data source v1. + runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index f0bdf84bb7a84..a4a857f2d4d9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -81,9 +81,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") + if (!writer.isInstanceOf[ContinuousWriter]) { + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") + } } catch { case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 24a8b000df0c1..cf27e1a70650a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,7 +142,8 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override val runId: UUID = UUID.randomUUID + override def runId: UUID = currentRunId + protected var currentRunId = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is @@ -418,11 +419,17 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { + private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + if (sources == null) { + // sources might not be initialized yet + false + } else { + val source = sources(sourceIndex) + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + } } while (notDone) { @@ -436,7 +443,7 @@ abstract class StreamExecution( awaitProgressLock.unlock() } } - logDebug(s"Unblocked at $newOffset for $source") + logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}") } /** A flag to indicate that a batch has completed with no new data available. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index b3f1a1a1aaab3..66eb42d4658f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -77,7 +77,6 @@ class ContinuousDataSourceRDD( dataReaderThread.start() context.addTaskCompletionListener(_ => { - reader.close() dataReaderThread.interrupt() epochPollExecutor.shutdown() }) @@ -177,6 +176,7 @@ class DataReaderThread( private[continuous] var failureReason: Throwable = _ override def run(): Unit = { + TaskContext.setTaskContext(context) val baseReader = ContinuousDataSourceRDD.getBaseReader(reader) try { while (!context.isInterrupted && !context.isCompleted()) { @@ -201,6 +201,8 @@ class DataReaderThread( failedFlag.set(true) // Don't rethrow the exception in this thread. It's not needed, and the default Spark // exception handler will kill the executor. + } finally { + reader.close() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 9657b5e26d770..667410ef9f1c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution.streaming.continuous +import java.util.UUID import java.util.concurrent.TimeUnit +import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} @@ -52,7 +54,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty + @volatile protected var continuousSources: Seq[ContinuousReader] = _ override protected def sources: Seq[BaseStreamingSource] = continuousSources override lazy val logicalPlan: LogicalPlan = { @@ -78,15 +80,17 @@ class ContinuousExecution( } override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - do { - try { - runContinuous(sparkSessionForStream) - } catch { - case _: InterruptedException if state.get().equals(RECONFIGURING) => - // swallow exception and run again - state.set(ACTIVE) + val stateUpdate = new UnaryOperator[State] { + override def apply(s: State) = s match { + // If we ended the query to reconfigure, reset the state to active. + case RECONFIGURING => ACTIVE + case _ => s } - } while (state.get() == ACTIVE) + } + + do { + runContinuous(sparkSessionForStream) + } while (state.updateAndGet(stateUpdate) == ACTIVE) } /** @@ -120,12 +124,16 @@ class ContinuousExecution( } committedOffsets = nextOffsets.toStreamProgress(sources) - // Forcibly align commit and offset logs by slicing off any spurious offset logs from - // a previous run. We can't allow commits to an epoch that a previous run reached but - // this run has not. - offsetLog.purgeAfter(latestEpochId) + // Get to an epoch ID that has definitely never been sent to a sink before. Since sink + // commit happens between offset log write and commit log write, this means an epoch ID + // which is not in the offset log. + val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { + throw new IllegalStateException( + s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + + s"an element.") + } + currentBatchId = latestOffsetEpoch + 1 - currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets case None => @@ -141,6 +149,7 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { + currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -225,13 +234,11 @@ class ContinuousExecution( triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration()) { - state.set(RECONFIGURING) + if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { stopSources() if (queryExecutionThread.isAlive) { sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() - // No need to join - this thread is about to end anyway. } false } else if (isActive) { @@ -259,6 +266,7 @@ class ContinuousExecution( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } } finally { + epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() @@ -273,17 +281,22 @@ class ContinuousExecution( epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - if (partitionOffsets.contains(null)) { - // If any offset is null, that means the corresponding partition hasn't seen any data yet, so - // there's nothing meaningful to add to the offset log. - } val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) - synchronized { - if (queryExecutionThread.isAlive) { - offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) - } else { - return - } + val oldOffset = synchronized { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + offsetLog.get(epoch - 1) + } + + // If offset hasn't changed since last epoch, there's been no new data. + if (oldOffset.contains(OffsetSeq.fill(globalOffset))) { + noNewData = true + } + + awaitProgressLock.lock() + try { + awaitProgressLockCondition.signalAll() + } finally { + awaitProgressLock.unlock() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 98017c3ac6a33..40dcbecade814 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -39,6 +39,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable */ private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage +/** + * The RpcEndpoint stop() will wait to clear out the message queue before terminating the + * object. This can lead to a race condition where the query restarts at epoch n, a new + * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1. + * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous + * message to stop any writes to the ContinuousExecution object. + */ +private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage + // Init messages /** * Set the reader and writer partition counts. Tasks may not be started until the coordinator @@ -116,6 +125,8 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + private var queryWritesStopped: Boolean = false + private var numReaderPartitions: Int = _ private var numWriterPartitions: Int = _ @@ -147,12 +158,16 @@ private[continuous] class EpochCoordinator( partitionCommits.remove(k) } for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) + partitionOffsets.remove(k) } } } override def receive: PartialFunction[Any, Unit] = { + // If we just drop these messages, we won't do any writes to the query. The lame duck tasks + // won't shed errors or anything. + case _ if queryWritesStopped => () + case CommitPartitionEpoch(partitionId, epoch, message) => logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message") if (!partitionCommits.isDefinedAt((epoch, partitionId))) { @@ -188,5 +203,9 @@ private[continuous] class EpochCoordinator( case SetWriterPartitions(numPartitions) => numWriterPartitions = numPartitions context.reply(()) + + case StopContinuousExecutionWrites => + queryWritesStopped = true + context.reply(()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index db588ae282f38..b5b4a05ab4973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -279,18 +280,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val dataSource = - DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) + val sink = trigger match { + case _: ContinuousTrigger => + val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + ds.newInstance() match { + case w: ContinuousWriteSupport => w + case _ => throw new AnalysisException( + s"Data source $source does not support continuous writing") + } + case _ => + val ds = DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + ds.createSink(outputMode) + } + df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, extraOptions.toMap, - dataSource.createSink(outputMode), + sink, outputMode, useTempCheckpointLocation = source == "console", recoverFromCheckpointLocation = true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d46461fa9bf6d..0762895fdc620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,8 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -80,6 +81,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be StateStore.stop() // stop the state store maintenance thread and unload store providers } + protected val defaultTrigger = Trigger.ProcessingTime(0) + protected val defaultUseV2Sink = false + /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -189,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( - trigger: Trigger = Trigger.ProcessingTime(0), + trigger: Trigger = defaultTrigger, triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty, checkpointLocation: String = null) @@ -276,7 +280,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { + useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -403,18 +407,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { verify(currentStream != null, "stream not running") - // Get the map of source index to the current source objects - val indexToSource = currentStream - .logicalPlan - .collect { case StreamingExecutionRelation(s, _) => s } - .zipWithIndex - .map(_.swap) - .toMap // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(indexToSource(sourceIndex), offset) + currentStream.awaitOffset(sourceIndex, offset) } } @@ -473,6 +470,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be // after starting the query. try { currentStream.awaitInitialization(streamingTimeout.toMillis) + currentStream match { + case s: ContinuousExecution => eventually("IncrementalExecution was not created") { + s.lastExecution.executedPlan // will fail if lastExecution is null + } + case _ => + } } catch { case _: StreamingQueryException => // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. @@ -600,7 +603,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan - .collect { case StreamingExecutionRelation(s, _) => s } + .collect { + case StreamingExecutionRelation(s, _) => s + case DataSourceV2Relation(_, r) => r + } .zipWithIndex .find(_._1 == source) .map(_._2) @@ -613,9 +619,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be findSourceIndex(query.logicalPlan) }.orElse { findSourceIndex(stream.logicalPlan) + }.orElse { + queryToUse.flatMap { q => + findSourceIndex(q.lastExecution.logical) + } }.getOrElse { throw new IllegalArgumentException( - "Could find index of the source to which data was added") + "Could not find index of the source to which data was added") } // Store the expected offset of added data to wait for it later From 50345a2aa59741c511d555edbbad2da9611e7d16 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 16 Jan 2018 22:14:47 -0800 Subject: [PATCH 112/774] Revert "[SPARK-23020][CORE] Fix races in launcher code, test." This reverts commit 66217dac4f8952a9923625908ad3dcb030763c81. --- .../spark/launcher/SparkLauncherSuite.java | 49 +++++++------------ .../spark/launcher/AbstractAppHandle.java | 22 ++------- .../spark/launcher/ChildProcAppHandle.java | 18 +++---- .../spark/launcher/InProcessAppHandle.java | 17 +++---- .../spark/launcher/LauncherConnection.java | 14 +++--- .../apache/spark/launcher/LauncherServer.java | 46 +++-------------- .../org/apache/spark/launcher/BaseSuite.java | 42 +++------------- .../spark/launcher/LauncherServerSuite.java | 20 +++++--- 8 files changed, 72 insertions(+), 156 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index a042375c6ae91..9d2f563b2e367 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.launcher; -import java.time.Duration; import java.util.Arrays; import java.util.ArrayList; import java.util.HashMap; @@ -32,7 +31,6 @@ import static org.mockito.Mockito.*; import org.apache.spark.SparkContext; -import org.apache.spark.SparkContext$; import org.apache.spark.internal.config.package$; import org.apache.spark.util.Utils; @@ -139,9 +137,7 @@ public void testInProcessLauncher() throws Exception { // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. // See SPARK-23019 and SparkContext.stop() for details. - eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { - assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty()); - }); + TimeUnit.MILLISECONDS.sleep(500); } } @@ -150,35 +146,26 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle.Listener listener = mock(SparkAppHandle.Listener.class); doAnswer(invocation -> { SparkAppHandle h = (SparkAppHandle) invocation.getArguments()[0]; - synchronized (transitions) { - transitions.add(h.getState()); - } + transitions.add(h.getState()); return null; }).when(listener).stateChanged(any(SparkAppHandle.class)); - SparkAppHandle handle = null; - try { - handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); - - waitFor(handle); - assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); - - // Matches the behavior of LocalSchedulerBackend. - List expected = Arrays.asList( - SparkAppHandle.State.CONNECTED, - SparkAppHandle.State.RUNNING, - SparkAppHandle.State.FINISHED); - assertEquals(expected, transitions); - } finally { - if (handle != null) { - handle.kill(); - } - } + SparkAppHandle handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + waitFor(handle); + assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); + + // Matches the behavior of LocalSchedulerBackend. + List expected = Arrays.asList( + SparkAppHandle.State.CONNECTED, + SparkAppHandle.State.RUNNING, + SparkAppHandle.State.FINISHED); + assertEquals(expected, transitions); } public static class SparkLauncherTestApp { diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index daf0972f824dd..df1e7316861d4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -33,7 +33,7 @@ abstract class AbstractAppHandle implements SparkAppHandle { private List listeners; private State state; private String appId; - private volatile boolean disposed; + private boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; @@ -70,7 +70,8 @@ public void stop() { @Override public synchronized void disconnect() { - if (!isDisposed()) { + if (!disposed) { + disposed = true; if (connection != null) { try { connection.close(); @@ -78,7 +79,7 @@ public synchronized void disconnect() { // no-op. } } - dispose(); + server.unregister(this); } } @@ -94,21 +95,6 @@ boolean isDisposed() { return disposed; } - /** - * Mark the handle as disposed, and set it as LOST in case the current state is not final. - */ - synchronized void dispose() { - if (!isDisposed()) { - // Unregister first to make sure that the connection with the app has been really - // terminated. - server.unregister(this); - if (!getState().isFinal()) { - setState(State.LOST); - } - this.disposed = true; - } - } - void setState(State s) { setState(s, false); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 2b99461652e1f..8b3f427b7750e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -48,16 +48,14 @@ public synchronized void disconnect() { @Override public synchronized void kill() { - if (!isDisposed()) { - setState(State.KILLED); - disconnect(); - if (childProc != null) { - if (childProc.isAlive()) { - childProc.destroyForcibly(); - } - childProc = null; + disconnect(); + if (childProc != null) { + if (childProc.isAlive()) { + childProc.destroyForcibly(); } + childProc = null; } + setState(State.KILLED); } void setChildProc(Process childProc, String loggerName, InputStream logStream) { @@ -96,6 +94,8 @@ void monitorChild() { return; } + disconnect(); + int ec; try { ec = proc.exitValue(); @@ -118,8 +118,6 @@ void monitorChild() { if (newState != null) { setState(newState, true); } - - disconnect(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index f04263cb74a58..acd64c962604f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -39,16 +39,15 @@ class InProcessAppHandle extends AbstractAppHandle { @Override public synchronized void kill() { - if (!isDisposed()) { - LOG.warning("kill() may leave the underlying app running in in-process mode."); - setState(State.KILLED); - disconnect(); - - // Interrupt the thread. This is not guaranteed to kill the app, though. - if (app != null) { - app.interrupt(); - } + LOG.warning("kill() may leave the underlying app running in in-process mode."); + disconnect(); + + // Interrupt the thread. This is not guaranteed to kill the app, though. + if (app != null) { + app.interrupt(); } + + setState(State.KILLED); } synchronized void start(String appName, Method main, String[] args) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java index fd6f229b2349c..b4a8719e26053 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -95,15 +95,15 @@ protected synchronized void send(Message msg) throws IOException { } @Override - public synchronized void close() throws IOException { + public void close() throws IOException { if (!closed) { - closed = true; - socket.close(); + synchronized (this) { + if (!closed) { + closed = true; + socket.close(); + } + } } } - boolean isOpen() { - return !closed; - } - } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 660c4443b20b9..b8999a1d7a4f4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -217,33 +217,6 @@ void unregister(AbstractAppHandle handle) { break; } } - - // If there is a live connection for this handle, we need to wait for it to finish before - // returning, otherwise there might be a race between the connection thread processing - // buffered data and the handle cleaning up after itself, leading to potentially the wrong - // state being reported for the handle. - ServerConnection conn = null; - synchronized (clients) { - for (ServerConnection c : clients) { - if (c.handle == handle) { - conn = c; - break; - } - } - } - - if (conn != null) { - synchronized (conn) { - if (conn.isOpen()) { - try { - conn.wait(); - } catch (InterruptedException ie) { - // Ignore. - } - } - } - } - unref(); } @@ -315,7 +288,7 @@ private String createSecret() { private class ServerConnection extends LauncherConnection { private TimerTask timeout; - volatile AbstractAppHandle handle; + private AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); @@ -365,21 +338,16 @@ protected void handle(Message msg) throws IOException { @Override public void close() throws IOException { - if (!isOpen()) { - return; - } - synchronized (clients) { clients.remove(this); } - - synchronized (this) { - super.close(); - notifyAll(); - } - + super.close(); if (handle != null) { - handle.dispose(); + if (!handle.getState().isFinal()) { + LOG.log(Level.WARNING, "Lost connection to spark application."); + handle.setState(SparkAppHandle.State.LOST); + } + handle.disconnect(); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3722a59d9438e..3e1a90eae98d4 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.launcher; -import java.time.Duration; import java.util.concurrent.TimeUnit; import org.junit.After; @@ -48,46 +47,19 @@ public void postChecks() { assertNull(server); } - protected void waitFor(final SparkAppHandle handle) throws Exception { + protected void waitFor(SparkAppHandle handle) throws Exception { + long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); try { - eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { - assertTrue("Handle is not in final state.", handle.getState().isFinal()); - }); + while (!handle.getState().isFinal()) { + assertTrue("Timed out waiting for handle to transition to final state.", + System.nanoTime() < deadline); + TimeUnit.MILLISECONDS.sleep(10); + } } finally { if (!handle.getState().isFinal()) { handle.kill(); } } - - // Wait until the handle has been marked as disposed, to make sure all cleanup tasks - // have been performed. - AbstractAppHandle ahandle = (AbstractAppHandle) handle; - eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { - assertTrue("Handle is still not marked as disposed.", ahandle.isDisposed()); - }); - } - - /** - * Call a closure that performs a check every "period" until it succeeds, or the timeout - * elapses. - */ - protected void eventually(Duration timeout, Duration period, Runnable check) throws Exception { - assertTrue("Timeout needs to be larger than period.", timeout.compareTo(period) > 0); - long deadline = System.nanoTime() + timeout.toNanos(); - int count = 0; - while (true) { - try { - count++; - check.run(); - return; - } catch (Throwable t) { - if (System.nanoTime() >= deadline) { - String msg = String.format("Failed check after %d tries: %s.", count, t.getMessage()); - throw new IllegalStateException(msg, t); - } - Thread.sleep(period.toMillis()); - } - } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 75c1af0c71e2a..7e2b09ce25c9b 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -23,14 +23,12 @@ import java.net.InetAddress; import java.net.Socket; import java.net.SocketException; -import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import static org.junit.Assert.*; @@ -199,20 +197,28 @@ private void close(Closeable c) { * server-side close immediately. */ private void waitForError(TestClient client, String secret) throws Exception { - final AtomicBoolean helloSent = new AtomicBoolean(); - eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { + boolean helloSent = false; + int maxTries = 10; + for (int i = 0; i < maxTries; i++) { try { - if (!helloSent.get()) { + if (!helloSent) { client.send(new Hello(secret, "1.4.0")); - helloSent.set(true); + helloSent = true; } else { client.send(new SetAppId("appId")); } fail("Expected error but message went through."); } catch (IllegalStateException | IOException e) { // Expected. + break; + } catch (AssertionError e) { + if (i < maxTries - 1) { + Thread.sleep(100); + } else { + throw new AssertionError("Test failed after " + maxTries + " attempts.", e); + } } - }); + } } private static class TestClient extends LauncherConnection { From a963980a6d2b4bef2c546aa33acf0aa501d2507b Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 16 Jan 2018 22:27:28 -0800 Subject: [PATCH 113/774] Fix merge between 07ae39d0ec and 1667057851 ## What changes were proposed in this pull request? The first commit added a new test, and the second refactored the class the test was in. The automatic merge put the test in the wrong place. ## How was this patch tested? - Author: Jose Torres Closes #20289 from jose-torres/fix. --- .../apache/spark/sql/kafka010/KafkaSourceSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 1acff61e11d2a..62f6a34a6b67a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -479,11 +479,6 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { // `failOnDataLoss` is `false`, we should not fail the query assert(query.exception.isEmpty) } -} - -class KafkaSourceSuiteBase extends KafkaSourceTest { - - import testImplicits._ test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") { def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = { @@ -549,6 +544,11 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { CheckLastBatch(120 to 124: _*) ) } +} + +class KafkaSourceSuiteBase extends KafkaSourceTest { + + import testImplicits._ test("cannot stop Kafka stream") { val topic = newTopic() From a0aedb0ded4183cc33b27e369df1cbf862779e26 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 17 Jan 2018 14:32:18 +0800 Subject: [PATCH 114/774] [SPARK-23072][SQL][TEST] Add a Unicode schema test for file-based data sources ## What changes were proposed in this pull request? After [SPARK-20682](https://github.com/apache/spark/pull/19651), Apache Spark 2.3 is able to read ORC files with Unicode schema. Previously, it raises `org.apache.spark.sql.catalyst.parser.ParseException`. This PR adds a Unicode schema test for CSV/JSON/ORC/Parquet file-based data sources. Note that TEXT data source only has [a single column with a fixed name 'value'](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala#L71). ## How was this patch tested? Pass the newly added test case. Author: Dongjoon Hyun Closes #20266 from dongjoon-hyun/SPARK-23072. --- .../spark/sql/FileBasedDataSourceSuite.scala | 81 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 16 ---- .../sql/hive/MetastoreDataSourcesSuite.scala | 14 ---- .../sql/hive/execution/SQLQuerySuite.scala | 8 -- 4 files changed, 81 insertions(+), 38 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala new file mode 100644 index 0000000000000..22fb496bc838e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + private val allFileBasedDataSources = Seq("orc", "parquet", "csv", "json", "text") + + allFileBasedDataSources.foreach { format => + test(s"Writing empty datasets should not fail - $format") { + withTempPath { dir => + Seq("str").toDS().limit(0).write.format(format).save(dir.getCanonicalPath) + } + } + } + + // `TEXT` data source always has a single column whose name is `value`. + allFileBasedDataSources.filterNot(_ == "text").foreach { format => + test(s"SPARK-23072 Write and read back unicode column names - $format") { + withTempPath { path => + val dir = path.getCanonicalPath + + // scalastyle:off nonascii + val df = Seq("a").toDF("한글") + // scalastyle:on nonascii + + df.write.format(format).option("header", "true").save(dir) + val answerDf = spark.read.format(format).option("header", "true").load(dir) + + assert(df.schema.sameType(answerDf.schema)) + checkAnswer(df, answerDf) + } + } + } + + // Only ORC/Parquet support this. `CSV` and `JSON` returns an empty schema. + // `TEXT` data source always has a single column whose name is `value`. + Seq("orc", "parquet").foreach { format => + test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { + withTempPath { file => + val path = file.getCanonicalPath + val emptyDf = Seq((true, 1, "str")).toDF().limit(0) + emptyDf.write.format(format).save(path) + + val df = spark.read.format(format).load(path) + assert(df.schema.sameType(emptyDf.schema)) + checkAnswer(df, emptyDf) + } + } + } + + allFileBasedDataSources.foreach { format => + test(s"SPARK-22146 read files containing special characters using $format") { + val nameWithSpecialChars = s"sp&cial%chars" + withTempDir { dir => + val tmpFile = s"$dir/$nameWithSpecialChars" + spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) + val fileContent = spark.read.format(format).load(tmpFile) + checkAnswer(fileContent, Seq(Row("a"), Row("b"))) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 96bf65fce9c4a..7c9840a34eaa3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2757,20 +2757,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } - - // Only New OrcFileFormat supports this - Seq(classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName, - "parquet").foreach { format => - test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { - withTempPath { file => - val path = file.getCanonicalPath - val emptyDf = Seq((true, 1, "str")).toDF.limit(0) - emptyDf.write.format(format).save(path) - - val df = spark.read.format(format).load(path) - assert(df.schema.sameType(emptyDf.schema)) - checkAnswer(df, emptyDf) - } - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index c8caba83bf365..fade143a1755e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -23,14 +23,12 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.execution.command.CreateTableCommand import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.hive.HiveExternalCatalog._ -import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf._ @@ -1344,18 +1342,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } - Seq("orc", "parquet", "csv", "json", "text").foreach { format => - test(s"SPARK-22146: read files containing special characters using $format") { - val nameWithSpecialChars = s"sp&cial%chars" - withTempDir { dir => - val tmpFile = s"$dir/$nameWithSpecialChars" - spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) - val fileContent = spark.read.format(format).load(tmpFile) - checkAnswer(fileContent, Seq(Row("a"), Row("b"))) - } - } - } - private def withDebugMode(f: => Unit): Unit = { val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE) try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 47adc77a52d51..33bcae91fdaf4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2159,12 +2159,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } - - Seq("orc", "parquet", "csv", "json", "text").foreach { format => - test(s"Writing empty datasets should not fail - $format") { - withTempDir { dir => - Seq("str").toDS.limit(0).write.format(format).save(dir.getCanonicalPath + "/tmp") - } - } - } } From 1f3d933e0bd2b1e934a233ed699ad39295376e71 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Wed, 17 Jan 2018 16:01:41 +0800 Subject: [PATCH 115/774] [SPARK-23062][SQL] Improve EXCEPT documentation ## What changes were proposed in this pull request? Make the default behavior of EXCEPT (i.e. EXCEPT DISTINCT) more explicit in the documentation, and call out the change in behavior from 1.x. Author: Henry Robinson Closes #20254 from henryr/spark-23062. --- R/pkg/R/DataFrame.R | 2 +- python/pyspark/sql/dataframe.py | 3 ++- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 6caa125e1e14a..29f3e986eaab6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2853,7 +2853,7 @@ setMethod("intersect", #' except #' #' Return a new SparkDataFrame containing rows in this SparkDataFrame -#' but not in another SparkDataFrame. This is equivalent to \code{EXCEPT} in SQL. +#' but not in another SparkDataFrame. This is equivalent to \code{EXCEPT DISTINCT} in SQL. #' #' @param x a SparkDataFrame. #' @param y a SparkDataFrame. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 95eca76fa9888..2d5e9b91468cf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1364,7 +1364,8 @@ def subtract(self, other): """ Return a new :class:`DataFrame` containing rows in this frame but not in another frame. - This is equivalent to `EXCEPT` in SQL. + This is equivalent to `EXCEPT DISTINCT` in SQL. + """ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 34f0ab5aa6699..912f411fa3845 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1903,7 +1903,7 @@ class Dataset[T] private[sql]( /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset. - * This is equivalent to `EXCEPT` in SQL. + * This is equivalent to `EXCEPT DISTINCT` in SQL. * * @note Equality checking is performed directly on the encoded representation of the data * and thus is not affected by a custom `equals` function defined on `T`. From 0f8a28617a0742d5a99debfbae91222c2e3b5cec Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 17 Jan 2018 21:53:36 +0800 Subject: [PATCH 116/774] [SPARK-21783][SQL] Turn on ORC filter push-down by default ## What changes were proposed in this pull request? ORC filter push-down is disabled by default from the beginning, [SPARK-2883](https://github.com/apache/spark/commit/aa31e431fc09f0477f1c2351c6275769a31aca90#diff-41ef65b9ef5b518f77e2a03559893f4dR149 ). Now, Apache Spark starts to depend on Apache ORC 1.4.1. For Apache Spark 2.3, this PR turns on ORC filter push-down by default like Parquet ([SPARK-9207](https://issues.apache.org/jira/browse/SPARK-21783)) as a part of [SPARK-20901](https://issues.apache.org/jira/browse/SPARK-20901), "Feature parity for ORC with Parquet". ## How was this patch tested? Pass the existing tests. Author: Dongjoon Hyun Closes #20265 from dongjoon-hyun/SPARK-21783. --- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../spark/sql/FilterPushdownBenchmark.scala | 243 ++++++++++++++++++ 2 files changed, 244 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6746fbcaf2483..16fbb0c3e9e21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -410,7 +410,7 @@ object SQLConf { val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala new file mode 100644 index 0000000000000..c6dd7dadc9d93 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.File + +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure read performance with Filter pushdown. + */ +object FilterPushdownBenchmark { + val conf = new SparkConf() + conf.set("orc.compression", "snappy") + conf.set("spark.sql.parquet.compression.codec", "snappy") + + private val spark = SparkSession.builder() + .master("local[1]") + .appName("FilterPushdownBenchmark") + .config(conf) + .getOrCreate() + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private def prepareTable(dir: File, numRows: Int, width: Int): Unit = { + import spark.implicits._ + val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") + val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) + .withColumn("id", monotonically_increasing_id()) + + val dirORC = dir.getCanonicalPath + "/orc" + val dirParquet = dir.getCanonicalPath + "/parquet" + + df.write.mode("overwrite").orc(dirORC) + df.write.mode("overwrite").parquet(dirParquet) + + spark.read.orc(dirORC).createOrReplaceTempView("orcTable") + spark.read.parquet(dirParquet).createOrReplaceTempView("parquetTable") + } + + def filterPushDownBenchmark( + values: Int, + title: String, + whereExpr: String, + selectExpr: String = "*"): Unit = { + val benchmark = new Benchmark(title, values, minNumIters = 5) + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() + } + } + } + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() + } + } + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Select 0 row (id IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7882 / 7957 2.0 501.1 1.0X + Parquet Vectorized (Pushdown) 55 / 60 285.2 3.5 142.9X + Native ORC Vectorized 5592 / 5627 2.8 355.5 1.4X + Native ORC Vectorized (Pushdown) 66 / 70 237.2 4.2 118.9X + + Select 0 row (7864320 < id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7884 / 7909 2.0 501.2 1.0X + Parquet Vectorized (Pushdown) 739 / 752 21.3 47.0 10.7X + Native ORC Vectorized 5614 / 5646 2.8 356.9 1.4X + Native ORC Vectorized (Pushdown) 81 / 83 195.2 5.1 97.8X + + Select 1 row (id = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7905 / 8027 2.0 502.6 1.0X + Parquet Vectorized (Pushdown) 740 / 766 21.2 47.1 10.7X + Native ORC Vectorized 5684 / 5738 2.8 361.4 1.4X + Native ORC Vectorized (Pushdown) 78 / 81 202.4 4.9 101.7X + + Select 1 row (id <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7928 / 7993 2.0 504.1 1.0X + Parquet Vectorized (Pushdown) 747 / 772 21.0 47.5 10.6X + Native ORC Vectorized 5728 / 5753 2.7 364.2 1.4X + Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 104.8X + + Select 1 row (7864320 <= id <= 7864320):Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7939 / 8021 2.0 504.8 1.0X + Parquet Vectorized (Pushdown) 746 / 770 21.1 47.4 10.6X + Native ORC Vectorized 5690 / 5734 2.8 361.7 1.4X + Native ORC Vectorized (Pushdown) 76 / 79 206.7 4.8 104.3X + + Select 1 row (7864319 < id < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7972 / 8019 2.0 506.9 1.0X + Parquet Vectorized (Pushdown) 742 / 764 21.2 47.2 10.7X + Native ORC Vectorized 5704 / 5743 2.8 362.6 1.4X + Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 105.4X + + Select 10% rows (id < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 8733 / 8808 1.8 555.2 1.0X + Parquet Vectorized (Pushdown) 2213 / 2267 7.1 140.7 3.9X + Native ORC Vectorized 6420 / 6463 2.4 408.2 1.4X + Native ORC Vectorized (Pushdown) 1313 / 1331 12.0 83.5 6.7X + + Select 50% rows (id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 11518 / 11591 1.4 732.3 1.0X + Parquet Vectorized (Pushdown) 7962 / 7991 2.0 506.2 1.4X + Native ORC Vectorized 8927 / 8985 1.8 567.6 1.3X + Native ORC Vectorized (Pushdown) 6102 / 6160 2.6 387.9 1.9X + + Select 90% rows (id < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14255 / 14389 1.1 906.3 1.0X + Parquet Vectorized (Pushdown) 13564 / 13594 1.2 862.4 1.1X + Native ORC Vectorized 11442 / 11608 1.4 727.5 1.2X + Native ORC Vectorized (Pushdown) 10991 / 11029 1.4 698.8 1.3X + + Select all rows (id IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14917 / 14938 1.1 948.4 1.0X + Parquet Vectorized (Pushdown) 14910 / 14964 1.1 948.0 1.0X + Native ORC Vectorized 11986 / 12069 1.3 762.0 1.2X + Native ORC Vectorized (Pushdown) 12037 / 12123 1.3 765.3 1.2X + + Select all rows (id > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14951 / 14976 1.1 950.6 1.0X + Parquet Vectorized (Pushdown) 14934 / 15016 1.1 949.5 1.0X + Native ORC Vectorized 12000 / 12156 1.3 763.0 1.2X + Native ORC Vectorized (Pushdown) 12079 / 12113 1.3 767.9 1.2X + + Select all rows (id != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14930 / 14972 1.1 949.3 1.0X + Parquet Vectorized (Pushdown) 15015 / 15047 1.0 954.6 1.0X + Native ORC Vectorized 12090 / 12259 1.3 768.7 1.2X + Native ORC Vectorized (Pushdown) 12021 / 12096 1.3 764.2 1.2X + */ + benchmark.run() + } + + def main(args: Array[String]): Unit = { + val numRows = 1024 * 1024 * 15 + val width = 5 + val mid = numRows / 2 + + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + prepareTable(dir, numRows, width) + + Seq("id IS NULL", s"$mid < id AND id < $mid").foreach { whereExpr => + val title = s"Select 0 row ($whereExpr)".replace("id AND id", "id") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"id = $mid", + s"id <=> $mid", + s"$mid <= id AND id <= $mid", + s"${mid - 1} < id AND id < ${mid + 1}" + ).foreach { whereExpr => + val title = s"Select 1 row ($whereExpr)".replace("id AND id", "id") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(id)") + + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% rows (id < ${numRows * percent / 100})", + s"id < ${numRows * percent / 100}", + selectExpr + ) + } + + Seq("id IS NOT NULL", "id > -1", "id != -1").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + } + } +} From 8598a982b4147abe5f1aae005fea0fd5ae395ac4 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 18 Jan 2018 00:05:26 +0800 Subject: [PATCH 117/774] [SPARK-23079][SQL] Fix query constraints propagation with aliases ## What changes were proposed in this pull request? Previously, PR #19201 fix the problem of non-converging constraints. After that PR #19149 improve the loop and constraints is inferred only once. So the problem of non-converging constraints is gone. However, the case below will fail. ``` spark.range(5).write.saveAsTable("t") val t = spark.read.table("t") val left = t.withColumn("xid", $"id" + lit(1)).as("x") val right = t.withColumnRenamed("id", "xid").as("y") val df = left.join(right, "xid").filter("id = 3").toDF() checkAnswer(df, Row(4, 3)) ``` Because `aliasMap` replace all the aliased child. See the test case in PR for details. This PR is to fix this bug by removing useless code for preventing non-converging constraints. It can be also fixed with #20270, but this is much simpler and clean up the code. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20278 from gengliangwang/FixConstraintSimple. --- .../catalyst/plans/logical/LogicalPlan.scala | 1 + .../plans/logical/QueryPlanConstraints.scala | 37 +----------- .../InferFiltersFromConstraintsSuite.scala | 59 +------------------ .../plans/ConstraintPropagationSuite.scala | 2 + .../org/apache/spark/sql/SQLQuerySuite.scala | 11 ++++ 5 files changed, 17 insertions(+), 93 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index ff2a0ec588567..c8ccd9bd03994 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -255,6 +255,7 @@ abstract class UnaryNode extends LogicalPlan { case expr: Expression if expr.semanticEquals(e) => a.toAttribute }) + allConstraints += EqualNullSafe(e, a.toAttribute) case _ => // Don't change. } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 9c0a30a47f839..5c7b8e5b97883 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -94,25 +94,16 @@ trait QueryPlanConstraints { self: LogicalPlan => case _ => Seq.empty[Attribute] } - // Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so - // we may avoid producing recursive constraints. - private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( - expressions.collect { - case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child) - } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)) - // Note: the explicit cast is necessary, since Scala compiler fails to infer the type. - /** * Infers an additional set of constraints from a given set of equality constraints. * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an * additional constraint of the form `b = 5`. */ private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - val aliasedConstraints = eliminateAliasedExpressionInConstraints(constraints) var inferredConstraints = Set.empty[Expression] - aliasedConstraints.foreach { + constraints.foreach { case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = aliasedConstraints - eq + val candidateConstraints = constraints - eq inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) case _ => // No inference @@ -120,30 +111,6 @@ trait QueryPlanConstraints { self: LogicalPlan => inferredConstraints -- constraints } - /** - * Replace the aliased expression in [[Alias]] with the alias name if both exist in constraints. - * Thus non-converging inference can be prevented. - * E.g. `Alias(b, f(a)), a = b` infers `f(a) = f(f(a))` without eliminating aliased expressions. - * Also, the size of constraints is reduced without losing any information. - * When the inferred filters are pushed down the operators that generate the alias, - * the alias names used in filters are replaced by the aliased expressions. - */ - private def eliminateAliasedExpressionInConstraints(constraints: Set[Expression]) - : Set[Expression] = { - val attributesInEqualTo = constraints.flatMap { - case EqualTo(l: Attribute, r: Attribute) => l :: r :: Nil - case _ => Nil - } - var aliasedConstraints = constraints - attributesInEqualTo.foreach { a => - if (aliasMap.contains(a)) { - val child = aliasMap.get(a).get - aliasedConstraints = replaceConstraints(aliasedConstraints, child, a) - } - } - aliasedConstraints - } - private def replaceConstraints( constraints: Set[Expression], source: Expression, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index a0708bf7eee9a..178c4b8c270a0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -34,6 +34,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { PushDownPredicate, InferFiltersFromConstraints, CombineFilters, + SimplifyBinaryComparison, BooleanSimplification) :: Nil } @@ -160,64 +161,6 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("inner join with alias: don't generate constraints for recursive functions") { - val t1 = testRelation.subquery('t1) - val t2 = testRelation.subquery('t2) - - // We should prevent `Coalese(a, b)` from recursively creating complicated constraints through - // the constraint inference procedure. - val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)) - // We hide an `Alias` inside the child's child's expressions, to cover the situation reported - // in [SPARK-20700]. - .select('int_col, 'd, 'a).as("t") - .join(t2, Inner, - Some("t.a".attr === "t2.a".attr - && "t.d".attr === "t2.a".attr - && "t.int_col".attr === "t2.a".attr)) - .analyze - val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && IsNotNull(Coalesce(Seq('b, 'a))) - && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) && IsNotNull(Coalesce(Seq('a, 'b))) - && 'a === 'b && 'a === Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'b)) - && 'a === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('a, 'b)) - && 'b === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('b, 'b))) - .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)) - .select('int_col, 'd, 'a).as("t") - .join( - t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && - 'a === Coalesce(Seq('a, 'a))), - Inner, - Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr - && "t.int_col".attr === "t2.a".attr)) - .analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) - } - - test("inner join with EqualTo expressions containing part of each other: don't generate " + - "constraints for recursive functions") { - val t1 = testRelation.subquery('t1) - val t2 = testRelation.subquery('t2) - - // We should prevent `c = Coalese(a, b)` and `a = Coalese(b, c)` from recursively creating - // complicated constraints through the constraint inference procedure. - val originalQuery = t1 - .select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e)) - .where('a === 'd && 'c === 'e) - .join(t2, Inner, Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr)) - .analyze - val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull('c) && 'a === Coalesce(Seq('b, 'c)) && - 'c === Coalesce(Seq('a, 'b))) - .select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e)) - .join(t2.where(IsNotNull('a) && IsNotNull('c)), - Inner, - Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr)) - .analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) - } - test("generate correct filters for alias that don't produce recursive constraints") { val t1 = testRelation.subquery('t1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 866ff0d33cbb2..a37e06d922642 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -134,6 +134,8 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { verifyConstraints(aliasedRelation.analyze.constraints, ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), + resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), + resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"), resolveColumn(aliasedRelation.analyze, "z") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 7c9840a34eaa3..d4d0aa4f5f5eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2717,6 +2717,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23079: constraints should be inferred correctly with aliases") { + withTable("t") { + spark.range(5).write.saveAsTable("t") + val t = spark.read.table("t") + val left = t.withColumn("xid", $"id" + lit(1)).as("x") + val right = t.withColumnRenamed("id", "xid").as("y") + val df = left.join(right, "xid").filter("id = 3").toDF() + checkAnswer(df, Row(4, 3)) + } + } + test("SRARK-22266: the same aggregate function was calculated multiple times") { val query = "SELECT a, max(b+1), max(b+1) + 1 FROM testData2 GROUP BY a" val df = sql(query) From c132538a164cd8b55dbd7e8ffdc0c0782a0b588c Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 17 Jan 2018 09:27:49 -0800 Subject: [PATCH 118/774] [SPARK-23020] Ignore Flaky Test: SparkLauncherSuite.testInProcessLauncher ## What changes were proposed in this pull request? Temporarily ignoring flaky test `SparkLauncherSuite.testInProcessLauncher` to de-flake the builds. This should be re-enabled when SPARK-23020 is merged. ## How was this patch tested? N/A (Test Only Change) Author: Sameer Agarwal Closes #20291 from sameeragarwal/disable-test-2. --- .../java/org/apache/spark/launcher/SparkLauncherSuite.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 9d2f563b2e367..dffa609f1cbdf 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -25,6 +25,7 @@ import java.util.Properties; import java.util.concurrent.TimeUnit; +import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.*; import static org.junit.Assume.*; @@ -120,7 +121,8 @@ public void testChildProcLauncher() throws Exception { assertEquals(0, app.waitFor()); } - @Test + // TODO: [SPARK-23020] Re-enable this + @Ignore public void testInProcessLauncher() throws Exception { // Because this test runs SparkLauncher in process and in client mode, it pollutes the system // properties, and that can cause test failures down the test pipeline. So restore the original From 86a845031824a5334db6a5299c6f5dcc982bc5b8 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 17 Jan 2018 13:52:51 -0800 Subject: [PATCH 119/774] [SPARK-23033][SS] Don't use task level retry for continuous processing ## What changes were proposed in this pull request? Continuous processing tasks will fail on any attempt number greater than 0. ContinuousExecution will catch these failures and restart globally from the last recorded checkpoints. ## How was this patch tested? unit test Author: Jose Torres Closes #20225 from jose-torres/no-retry. --- .../spark/sql/kafka010/KafkaSourceSuite.scala | 8 +-- .../ContinuousDataSourceRDDIter.scala | 5 ++ .../continuous/ContinuousExecution.scala | 2 +- .../ContinuousTaskRetryException.scala | 26 +++++++ .../spark/sql/streaming/StreamTest.scala | 9 ++- .../continuous/ContinuousSuite.scala | 71 +++++++++++-------- 6 files changed, 84 insertions(+), 37 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 62f6a34a6b67a..27dbb3f7a8f31 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -808,16 +808,14 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { val query = kafka .writeStream .format("memory") - .outputMode("append") .queryName("kafkaColumnTypes") .trigger(defaultTrigger) .start() - var rows: Array[Row] = Array() eventually(timeout(streamingTimeout)) { - rows = spark.table("kafkaColumnTypes").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + assert(spark.table("kafkaColumnTypes").count == 1, + s"Unexpected results: ${spark.table("kafkaColumnTypes").collectAsList()}") } - val row = rows(0) + val row = spark.table("kafkaColumnTypes").head() assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row") assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row") assert(row.getAs[String]("topic") === topic, s"Unexpected results: $row") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index 66eb42d4658f6..dcb3b54c4e160 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -52,6 +52,11 @@ class ContinuousDataSourceRDD( } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + // If attempt number isn't 0, this is a task retry, which we don't support. + if (context.attemptNumber() != 0) { + throw new ContinuousTaskRetryException() + } + val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 667410ef9f1c6..45b794c70a50a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -24,7 +24,7 @@ import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala new file mode 100644 index 0000000000000..e0a6f6dd50bb3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark.SparkException + +/** + * An exception thrown when a continuous processing task runs with a nonzero attempt ID. + */ +class ContinuousTaskRetryException + extends SparkException("Continuous execution does not support task retry", null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 0762895fdc620..c75247e0f6ed8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -472,8 +472,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be currentStream.awaitInitialization(streamingTimeout.toMillis) currentStream match { case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - s.lastExecution.executedPlan // will fail if lastExecution is null - } + s.lastExecution.executedPlan // will fail if lastExecution is null + } case _ => } } catch { @@ -645,7 +645,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } case CheckAnswerRowsContains(expectedAnswer, lastOnly) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + val sparkAnswer = currentStream match { + case null => fetchStreamAnswer(lastStream, lastOnly) + case s => fetchStreamAnswer(s, lastOnly) + } QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { error => failTest(error) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 9562c10feafe9..4b4ed82dc6520 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,36 +17,18 @@ package org.apache.spark.sql.streaming.continuous -import java.io.{File, InterruptedIOException, IOException, UncheckedIOException} -import java.nio.channels.ClosedByInterruptException -import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit} +import java.util.UUID -import scala.reflect.ClassTag -import scala.util.control.ControlThrowable - -import com.google.common.util.concurrent.UncheckedExecutionException -import org.apache.commons.io.FileUtils -import org.apache.hadoop.conf.Configuration - -import org.apache.spark.{SparkContext, SparkEnv} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.{SparkContext, SparkEnv, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.plans.logical.Range -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes -import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.TestSparkSession -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils class ContinuousSuiteBase extends StreamTest { // We need more than the default local[2] to be able to schedule all partitions simultaneously. @@ -219,6 +201,41 @@ class ContinuousSuite extends ContinuousSuiteBase { StopStream) } + test("task failure kills the query") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + + // Get an arbitrary task from this query to kill. It doesn't matter which one. + var taskId: Long = -1 + val listener = new SparkListener() { + override def onTaskStart(start: SparkListenerTaskStart): Unit = { + taskId = start.taskInfo.taskId + } + } + spark.sparkContext.addSparkListener(listener) + try { + testStream(df, useV2Sink = true)( + StartStream(Trigger.Continuous(100)), + Execute(waitForRateSourceTriggers(_, 2)), + Execute { _ => + // Wait until a task is started, then kill its first attempt. + eventually(timeout(streamingTimeout)) { + assert(taskId != -1) + } + spark.sparkContext.killTaskAttempt(taskId) + }, + ExpectFailure[SparkException] { e => + e.getCause != null && e.getCause.getCause.isInstanceOf[ContinuousTaskRetryException] + }) + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } + test("query without test harness") { val df = spark.readStream .format("rate") @@ -258,13 +275,9 @@ class ContinuousStressSuite extends ContinuousSuiteBase { AwaitEpoch(0), Execute(waitForRateSourceTriggers(_, 201)), IncrementEpoch(), - Execute { query => - val data = query.sink.asInstanceOf[MemorySinkV2].allData - val vals = data.map(_.getLong(0)).toSet - assert(scala.Range(0, 25000).forall { i => - vals.contains(i) - }) - }) + StopStream, + CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_))) + ) } test("automatic epoch advancement") { @@ -280,6 +293,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { AwaitEpoch(0), Execute(waitForRateSourceTriggers(_, 201)), IncrementEpoch(), + StopStream, CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) } @@ -311,6 +325,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { StopStream, StartStream(Trigger.Continuous(2012)), AwaitEpoch(50), + StopStream, CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) } } From e946c63dd56d121cf898084ed7e9b5b0868b226e Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 17 Jan 2018 13:58:44 -0800 Subject: [PATCH 120/774] [SPARK-23093][SS] Don't change run id when reconfiguring a continuous processing query. ## What changes were proposed in this pull request? Keep the run ID static, using a different ID for the epoch coordinator to avoid cross-execution message contamination. ## How was this patch tested? new and existing unit tests Author: Jose Torres Closes #20282 from jose-torres/fix-runid. --- .../datasources/v2/DataSourceV2ScanExec.scala | 3 ++- .../datasources/v2/WriteToDataSourceV2.scala | 5 ++-- .../execution/streaming/StreamExecution.scala | 3 +-- .../ContinuousDataSourceRDDIter.scala | 10 ++++---- .../continuous/ContinuousExecution.scala | 18 ++++++++----- .../continuous/EpochCoordinator.scala | 9 ++++--- .../spark/sql/streaming/StreamTest.scala | 2 +- .../StreamingQueryListenerSuite.scala | 25 +++++++++++++++++++ 8 files changed, 54 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 8c64df080242f..beb66738732be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -58,7 +58,8 @@ case class DataSourceV2ScanExec( case _: ContinuousReader => EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) .askSync[Unit](SetReaderPartitions(readTasks.size())) new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) .asInstanceOf[RDD[InternalRow]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index a4a857f2d4d9b..3dbdae7b4df9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -64,7 +64,8 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) val runTask = writer match { case w: ContinuousWriter => EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) (context: TaskContext, iter: Iterator[InternalRow]) => @@ -135,7 +136,7 @@ object DataWritingSparkTask extends Logging { iter: Iterator[InternalRow]): WriterCommitMessage = { val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) val currentMsg: WriterCommitMessage = null var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index cf27e1a70650a..e7982d7880ceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,8 +142,7 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override def runId: UUID = currentRunId - protected var currentRunId = UUID.randomUUID + override val runId: UUID = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index dcb3b54c4e160..cd7065f5e6601 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -59,7 +59,7 @@ class ContinuousDataSourceRDD( val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() - val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) + val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) // This queue contains two types of messages: // * (null, null) representing an epoch boundary. @@ -68,7 +68,7 @@ class ContinuousDataSourceRDD( val epochPollFailed = new AtomicBoolean(false) val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( - s"epoch-poll--${runId}--${context.partitionId()}") + s"epoch-poll--$coordinatorId--${context.partitionId()}") val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed) epochPollExecutor.scheduleWithFixedDelay( epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) @@ -86,7 +86,7 @@ class ContinuousDataSourceRDD( epochPollExecutor.shutdown() }) - val epochEndpoint = EpochCoordinatorRef.get(runId, SparkEnv.get) + val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get) new Iterator[UnsafeRow] { private val POLL_TIMEOUT_MS = 1000 @@ -150,7 +150,7 @@ class EpochPollRunnable( private[continuous] var failureReason: Throwable = _ private val epochEndpoint = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), SparkEnv.get) + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong override def run(): Unit = { @@ -177,7 +177,7 @@ class DataReaderThread( failedFlag: AtomicBoolean) extends Thread( s"continuous-reader--${context.partitionId()}--" + - s"${context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)}") { + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") { private[continuous] var failureReason: Throwable = _ override def run(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 45b794c70a50a..c0507224f9be8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -57,6 +57,9 @@ class ContinuousExecution( @volatile protected var continuousSources: Seq[ContinuousReader] = _ override protected def sources: Seq[BaseStreamingSource] = continuousSources + // For use only in test harnesses. + private[sql] var currentEpochCoordinatorId: String = _ + override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, "logicalPlan must be initialized in StreamExecutionThread " + @@ -149,7 +152,6 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { - currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -219,15 +221,19 @@ class ContinuousExecution( lastExecution.executedPlan // Force the lazy generation of execution plan } - sparkSession.sparkContext.setLocalProperty( + sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) - sparkSession.sparkContext.setLocalProperty( - ContinuousExecution.RUN_ID_KEY, runId.toString) + // Add another random ID on top of the run ID, to distinguish epoch coordinators across + // reconfigurations. + val epochCoordinatorId = s"$runId--${UUID.randomUUID}" + currentEpochCoordinatorId = epochCoordinatorId + sparkSessionForQuery.sparkContext.setLocalProperty( + ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId) // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - writer.get(), reader, this, currentBatchId, sparkSession, SparkEnv.get) + writer.get(), reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { @@ -359,5 +365,5 @@ class ContinuousExecution( object ContinuousExecution { val START_EPOCH_KEY = "__continuous_start_epoch" - val RUN_ID_KEY = "__run_id" + val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 40dcbecade814..90b3584aa0436 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -79,7 +79,7 @@ private[sql] case class ReportPartitionOffset( /** Helper object used to create reference to [[EpochCoordinator]]. */ private[sql] object EpochCoordinatorRef extends Logging { - private def endpointName(runId: String) = s"EpochCoordinator-$runId" + private def endpointName(id: String) = s"EpochCoordinator-$id" /** * Create a reference to a new [[EpochCoordinator]]. @@ -88,18 +88,19 @@ private[sql] object EpochCoordinatorRef extends Logging { writer: ContinuousWriter, reader: ContinuousReader, query: ContinuousExecution, + epochCoordinatorId: String, startEpoch: Long, session: SparkSession, env: SparkEnv): RpcEndpointRef = synchronized { val coordinator = new EpochCoordinator( writer, reader, query, startEpoch, session, env.rpcEnv) - val ref = env.rpcEnv.setupEndpoint(endpointName(query.runId.toString()), coordinator) + val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator) logInfo("Registered EpochCoordinator endpoint") ref } - def get(runId: String, env: SparkEnv): RpcEndpointRef = synchronized { - val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(runId), env.conf, env.rpcEnv) + def get(id: String, env: SparkEnv): RpcEndpointRef = synchronized { + val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(id), env.conf, env.rpcEnv) logDebug("Retrieved existing EpochCoordinator endpoint") rpcEndpointRef } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index c75247e0f6ed8..efdb0e0e7cf1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -263,7 +263,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def apply(): AssertOnQuery = Execute { case s: ContinuousExecution => - val newEpoch = EpochCoordinatorRef.get(s.runId.toString, SparkEnv.get) + val newEpoch = EpochCoordinatorRef.get(s.currentEpochCoordinatorId, SparkEnv.get) .askSync[Long](IncrementAndGetEpoch) s.awaitEpoch(newEpoch - 1) case _ => throw new IllegalStateException("microbatch cannot increment epoch") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 9ff02dee288fb..79d65192a14aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -174,6 +174,31 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } } + test("continuous processing listeners should receive QueryTerminatedEvent") { + val df = spark.readStream.format("rate").load() + val listeners = (1 to 5).map(_ => new EventCollector) + try { + listeners.foreach(listener => spark.streams.addListener(listener)) + testStream(df, OutputMode.Append, useV2Sink = true)( + StartStream(Trigger.Continuous(1000)), + StopStream, + AssertOnQuery { query => + eventually(Timeout(streamingTimeout)) { + listeners.foreach(listener => assert(listener.terminationEvent !== null)) + listeners.foreach(listener => assert(listener.terminationEvent.id === query.id)) + listeners.foreach(listener => assert(listener.terminationEvent.runId === query.runId)) + listeners.foreach(listener => assert(listener.terminationEvent.exception === None)) + } + listeners.foreach(listener => listener.checkAsyncErrors()) + listeners.foreach(listener => listener.reset()) + true + } + ) + } finally { + listeners.foreach(spark.streams.removeListener) + } + } + test("adding and removing listener") { def isListenerActive(listener: EventCollector): Boolean = { listener.reset() From 4e6f8fb150ae09c7d1de6beecb2b98e5afa5da19 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 18 Jan 2018 07:26:43 +0900 Subject: [PATCH 121/774] [SPARK-23047][PYTHON][SQL] Change MapVector to NullableMapVector in ArrowColumnVector ## What changes were proposed in this pull request? This PR changes usage of `MapVector` in Spark codebase to use `NullableMapVector`. `MapVector` is an internal Arrow class that is not supposed to be used directly. We should use `NullableMapVector` instead. ## How was this patch tested? Existing test. Author: Li Jin Closes #20239 from icexelloss/arrow-map-vector. --- .../sql/vectorized/ArrowColumnVector.java | 13 +++++-- .../vectorized/ArrowColumnVectorSuite.scala | 36 +++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 708333213f3f1..eb69001fe677e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -247,8 +247,8 @@ public ArrowColumnVector(ValueVector vector) { childColumns = new ArrowColumnVector[1]; childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); - } else if (vector instanceof MapVector) { - MapVector mapVector = (MapVector) vector; + } else if (vector instanceof NullableMapVector) { + NullableMapVector mapVector = (NullableMapVector) vector; accessor = new StructAccessor(mapVector); childColumns = new ArrowColumnVector[mapVector.size()]; @@ -553,9 +553,16 @@ final int getArrayOffset(int rowId) { } } + /** + * Any call to "get" method will throw UnsupportedOperationException. + * + * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses getStruct() method defined + * in the parent class. Any call to "get" method in this class is a bug in the code. + * + */ private static class StructAccessor extends ArrowVectorAccessor { - StructAccessor(MapVector vector) { + StructAccessor(NullableMapVector vector) { super(vector); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 7304803a092c0..53432669e215d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -322,6 +322,42 @@ class ArrowColumnVectorSuite extends SparkFunSuite { allocator.close() } + test("non nullable struct") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) + val schema = new StructType().add("int", IntegerType).add("long", LongType) + val vector = ArrowUtils.toArrowField("struct", schema, nullable = false, null) + .createVector(allocator).asInstanceOf[NullableMapVector] + + vector.allocateNew() + val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector] + val longVector = vector.getChildByOrdinal(1).asInstanceOf[BigIntVector] + + vector.setIndexDefined(0) + intVector.setSafe(0, 1) + longVector.setSafe(0, 1L) + + vector.setIndexDefined(1) + intVector.setSafe(1, 2) + longVector.setNull(1) + + vector.setValueCount(2) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === schema) + assert(columnVector.numNulls === 0) + + val row0 = columnVector.getStruct(0, 2) + assert(row0.getInt(0) === 1) + assert(row0.getLong(1) === 1L) + + val row1 = columnVector.getStruct(1, 2) + assert(row1.getInt(0) === 2) + assert(row1.isNullAt(1)) + + columnVector.close() + allocator.close() + } + test("struct") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) val schema = new StructType().add("int", IntegerType).add("long", LongType) From 45ad97df87c89cb94ce9564e5773897b6d9326f5 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 18 Jan 2018 07:30:54 +0900 Subject: [PATCH 122/774] [SPARK-23132][PYTHON][ML] Run doctests in ml.image when testing ## What changes were proposed in this pull request? This PR proposes to actually run the doctests in `ml/image.py`. ## How was this patch tested? doctests in `python/pyspark/ml/image.py`. Author: hyukjinkwon Closes #20294 from HyukjinKwon/trigger-image. --- python/pyspark/ml/image.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index c9b840276f675..2d86c7f03860c 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -194,9 +194,9 @@ def readImages(self, path, recursive=False, numPartitions=-1, :return: a :class:`DataFrame` with a single column of "images", see ImageSchema for details. - >>> df = ImageSchema.readImages('python/test_support/image/kittens', recursive=True) + >>> df = ImageSchema.readImages('data/mllib/images/kittens', recursive=True) >>> df.count() - 4 + 5 .. versionadded:: 2.3.0 """ @@ -216,3 +216,25 @@ def readImages(self, path, recursive=False, numPartitions=-1, def _disallow_instance(_): raise RuntimeError("Creating instance of _ImageSchema class is disallowed.") _ImageSchema.__init__ = _disallow_instance + + +def _test(): + import doctest + import pyspark.ml.image + globs = pyspark.ml.image.__dict__.copy() + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("ml.image tests")\ + .getOrCreate() + globs['spark'] = spark + + (failure_count, test_count) = doctest.testmod( + pyspark.ml.image, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + spark.stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() From 7823d43ec0e9c4b8284bb4529b0e624c43bc9bb7 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 17 Jan 2018 17:16:57 -0600 Subject: [PATCH 123/774] [MINOR] Fix typos in ML scaladocs ## What changes were proposed in this pull request? Fixed some typos found in ML scaladocs ## How was this patch tested? NA Author: Bryan Cutler Closes #20300 from BryanCutler/ml-doc-typos-MINOR. --- .../src/main/scala/org/apache/spark/ml/stat/Summarizer.scala | 2 +- .../org/apache/spark/ml/tuning/TrainValidationSplit.scala | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 9bed74a9f2c05..d40827edb6d64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -75,7 +75,7 @@ sealed abstract class SummaryBuilder { * val Row(meanVec) = meanDF.first() * }}} * - * Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD + * Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD * interface. */ @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 8826ef3271bc1..88ff0dfd75e96 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -93,7 +93,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St def setSeed(value: Long): this.type = set(seed, value) /** - * Set the mamixum level of parallelism to evaluate models in parallel. + * Set the maximum level of parallelism to evaluate models in parallel. * Default is 1 for serial evaluation * * @group expertSetParam @@ -112,7 +112,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St * for more information. * * @group expertSetParam - */@Since("2.3.0") + */ + @Since("2.3.0") def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @Since("2.0.0") From bac0d661af6092dd26638223156827aceb901229 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 17 Jan 2018 16:40:02 -0800 Subject: [PATCH 124/774] [SPARK-23119][SS] Minor fixes to V2 streaming APIs ## What changes were proposed in this pull request? - Added `InterfaceStability.Evolving` annotations - Improved docs. ## How was this patch tested? Existing tests. Author: Tathagata Das Closes #20286 from tdas/SPARK-23119. --- .../v2/streaming/ContinuousReadSupport.java | 2 ++ .../streaming/reader/ContinuousDataReader.java | 2 ++ .../v2/streaming/reader/ContinuousReader.java | 9 +++++++-- .../v2/streaming/reader/MicroBatchReader.java | 5 +++++ .../sources/v2/streaming/reader/Offset.java | 18 +++++++++++++----- .../v2/streaming/reader/PartitionOffset.java | 3 +++ .../sources/v2/writer/DataSourceV2Writer.java | 5 ++++- 7 files changed, 36 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java index 3136cee1f655f..9a93a806b0efc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java @@ -19,6 +19,7 @@ import java.util.Optional; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; @@ -28,6 +29,7 @@ * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * provide data reading ability for continuous stream processing. */ +@InterfaceStability.Evolving public interface ContinuousReadSupport extends DataSourceV2 { /** * Creates a {@link ContinuousReader} to scan the data from this data source. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java index ca9a290e97a02..3f13a4dbf5793 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java @@ -17,11 +17,13 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataReader; /** * A variation on {@link DataReader} for use with streaming in continuous processing mode. */ +@InterfaceStability.Evolving public interface ContinuousDataReader extends DataReader { /** * Get the offset of the current record, or the start offset if no records have been read. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java index f0b205869ed6c..745f1ce502443 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; @@ -27,11 +28,15 @@ * interface to allow reading in a continuous processing mode stream. * * Implementations must ensure each read task output is a {@link ContinuousDataReader}. + * + * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with + * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ +@InterfaceStability.Evolving public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reader { /** - * Merge offsets coming from {@link ContinuousDataReader} instances in each partition to - * a single global offset. + * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each + * partition to a single global offset. */ Offset mergeOffsets(PartitionOffset[] offsets); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java index 70ff756806032..02f37cebc7484 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; @@ -25,7 +26,11 @@ /** * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this * interface to indicate they allow micro-batch streaming reads. + * + * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with + * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ +@InterfaceStability.Evolving public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSource { /** * Set the desired offset range for read tasks created from this reader. Read tasks will diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java index 60b87f2ac0756..abba3e7188b13 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java @@ -17,12 +17,20 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; + /** - * An abstract representation of progress through a [[MicroBatchReader]] or [[ContinuousReader]]. - * During execution, Offsets provided by the data source implementation will be logged and used as - * restart checkpoints. Sources should provide an Offset implementation which they can use to - * reconstruct the stream position where the offset was taken. + * An abstract representation of progress through a {@link MicroBatchReader} or + * {@link ContinuousReader}. + * During execution, offsets provided by the data source implementation will be logged and used as + * restart checkpoints. Each source should provide an offset implementation which the source can use + * to reconstruct a position in the stream up to which data has been seen/processed. + * + * Note: This class currently extends {@link org.apache.spark.sql.execution.streaming.Offset} to + * maintain compatibility with DataSource V1 APIs. This extension will be removed once we + * get rid of V1 completely. */ +@InterfaceStability.Evolving public abstract class Offset extends org.apache.spark.sql.execution.streaming.Offset { /** * A JSON-serialized representation of an Offset that is @@ -37,7 +45,7 @@ public abstract class Offset extends org.apache.spark.sql.execution.streaming.Of /** * Equality based on JSON string representation. We leverage the * JSON representation for normalization between the Offset's - * in memory and on disk representations. + * in deserialized and serialized representations. */ @Override public boolean equals(Object obj) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java index eca0085c8a8ce..4688b85f49f5f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java @@ -19,11 +19,14 @@ import java.io.Serializable; +import org.apache.spark.annotation.InterfaceStability; + /** * Used for per-partition offsets in continuous processing. ContinuousReader implementations will * provide a method to merge these into a global Offset. * * These offsets must be serializable. */ +@InterfaceStability.Evolving public interface PartitionOffset extends Serializable { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index fc37b9a516f82..317ac45bcfd74 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -22,11 +22,14 @@ import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.WriteSupport; +import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; /** * A data source writer that is returned by - * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}. + * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}/ + * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter(String, long, StructType, OutputMode, DataSourceV2Options)}/ + * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter(String, StructType, OutputMode, DataSourceV2Options)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * From 1002bd6b23ff78a010ca259ea76988ef4c478c6e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 17 Jan 2018 16:41:43 -0800 Subject: [PATCH 125/774] [SPARK-23064][DOCS][SS] Added documentation for stream-stream joins ## What changes were proposed in this pull request? Added documentation for stream-stream joins ![image](https://user-images.githubusercontent.com/663212/35018744-e999895a-fad7-11e7-9d6a-8c7a73e6eb9c.png) ![image](https://user-images.githubusercontent.com/663212/35018775-157eb464-fad8-11e7-879e-47a2fcbd8690.png) ![image](https://user-images.githubusercontent.com/663212/35018784-27791a24-fad8-11e7-98f4-7ff246f62a74.png) ![image](https://user-images.githubusercontent.com/663212/35018791-36a80334-fad8-11e7-9791-f85efa7c6ba2.png) ## How was this patch tested? N/a Author: Tathagata Das Closes #20255 from tdas/join-docs. --- .../structured-streaming-programming-guide.md | 338 +++++++++++++++++- 1 file changed, 326 insertions(+), 12 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index de13e281916db..1779a4215e085 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1051,7 +1051,19 @@ output mode. ### Join Operations -Streaming DataFrames can be joined with static DataFrames to create new streaming DataFrames. Here are a few examples. +Structured Streaming supports joining a streaming Dataset/DataFrame with a static Dataset/DataFrame +as well as another streaming Dataset/DataFrame. The result of the streaming join is generated +incrementally, similar to the results of streaming aggregations in the previous section. In this +section we will explore what type of joins (i.e. inner, outer, etc.) are supported in the above +cases. Note that in all the supported join types, the result of the join with a streaming +Dataset/DataFrame will be the exactly the same as if it was with a static Dataset/DataFrame +containing the same data in the stream. + + +#### Stream-static joins + +Since the introduction in Spark 2.0, Structured Streaming has supported joins (inner join and some +type of outer joins) between a streaming and a static DataFrame/Dataset. Here is a simple example.
    @@ -1089,6 +1101,300 @@ streamingDf.join(staticDf, "type", "right_join") # right outer join with a stat
    +Note that stream-static joins are not stateful, so no state management is necessary. +However, a few types of stream-static outer joins are not yet supported. +These are listed at the [end of this Join section](#support-matrix-for-joins-in-streaming-queries). + +#### Stream-stream Joins +In Spark 2.3, we have added support for stream-stream joins, that is, you can join two streaming +Datasets/DataFrames. The challenge of generating join results between two data streams is that, +at any point of time, the view of the dataset is incomplete for both sides of the join making +it much harder to find matches between inputs. Any row received from one input stream can match +with any future, yet-to-be-received row from the other input stream. Hence, for both the input +streams, we buffer past input as streaming state, so that we can match every future input with +past input and accordingly generate joined results. Furthermore, similar to streaming aggregations, +we automatically handle late, out-of-order data and can limit the state using watermarks. +Let’s discuss the different types of supported stream-stream joins and how to use them. + +##### Inner Joins with optional Watermarking +Inner joins on any kind of columns along with any kind of join conditions are supported. +However, as the stream runs, the size of streaming state will keep growing indefinitely as +*all* past input must be saved as the any new input can match with any input from the past. +To avoid unbounded state, you have to define additional join conditions such that indefinitely +old inputs cannot match with future inputs and therefore can be cleared from the state. +In other words, you will have to do the following additional steps in the join. + +1. Define watermark delays on both inputs such that the engine knows how delayed the input can be +(similar to streaming aggregations) + +1. Define a constraint on event-time across the two inputs such that the engine can figure out when +old rows of one input is not going to be required (i.e. will not satisfy the time constraint) for +matches with the other input. This constraint can be defined in one of the two ways. + + 1. Time range join conditions (e.g. `...JOIN ON leftTime BETWEN rightTime AND rightTime + INTERVAL 1 HOUR`), + + 1. Join on event-time windows (e.g. `...JOIN ON leftTimeWindow = rightTimeWindow`). + +Let’s understand this with an example. + +Let’s say we want to join a stream of advertisement impressions (when an ad was shown) with +another stream of user clicks on advertisements to correlate when impressions led to +monetizable clicks. To allow the state cleanup in this stream-stream join, you will have to +specify the watermarking delays and the time constraints as follows. + +1. Watermark delays: Say, the impressions and the corresponding clicks can be late/out-of-order +in event-time by at most 2 and 3 hours, respectively. + +1. Event-time range condition: Say, a click can occur within a time range of 0 seconds to 1 hour +after the corresponding impression. + +The code would look like this. + +
    +
    + +{% highlight scala %} +import org.apache.spark.sql.functions.expr + +val impressions = spark.readStream. ... +val clicks = spark.readStream. ... + +// Apply watermarks on event-time columns +val impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours") +val clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours") + +// Join with event-time constraints +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """) +) + +{% endhighlight %} + +
    +
    + +{% highlight java %} +import static org.apache.spark.sql.functions.expr + +Dataset impressions = spark.readStream(). ... +Dataset clicks = spark.readStream(). ... + +// Apply watermarks on event-time columns +Dataset impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours"); +Dataset clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours"); + +// Join with event-time constraints +impressionsWithWatermark.join( + clicksWithWatermark, + expr( + "clickAdId = impressionAdId AND " + + "clickTime >= impressionTime AND " + + "clickTime <= impressionTime + interval 1 hour ") +); + +{% endhighlight %} + + +
    +
    + +{% highlight python %} +from pyspark.sql.functions import expr + +impressions = spark.readStream. ... +clicks = spark.readStream. ... + +# Apply watermarks on event-time columns +impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours") +clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours") + +# Join with event-time constraints +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """) +) + +{% endhighlight %} + +
    +
    + +##### Outer Joins with Watermarking +While the watermark + event-time constraints is optional for inner joins, for left and right outer +joins they must be specified. This is because for generating the NULL results in outer join, the +engine must know when an input row is not going to match with anything in future. Hence, the +watermark + event-time constraints must be specified for generating correct results. Therefore, +a query with outer-join will look quite like the ad-monetization example earlier, except that +there will be an additional parameter specifying it to be an outer-join. + +
    +
    + +{% highlight scala %} + +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """), + joinType = "leftOuter" // can be "inner", "leftOuter", "rightOuter" + ) + +{% endhighlight %} + +
    +
    + +{% highlight java %} +impressionsWithWatermark.join( + clicksWithWatermark, + expr( + "clickAdId = impressionAdId AND " + + "clickTime >= impressionTime AND " + + "clickTime <= impressionTime + interval 1 hour "), + "leftOuter" // can be "inner", "leftOuter", "rightOuter" +); + +{% endhighlight %} + + +
    +
    + +{% highlight python %} +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """), + "leftOuter" # can be "inner", "leftOuter", "rightOuter" +) + +{% endhighlight %} + +
    +
    + +However, note that the outer NULL results will be generated with a delay (depends on the specified +watermark delay and the time range condition) because the engine has to wait for that long to ensure +there were no matches and there will be no more matches in future. + +##### Support matrix for joins in streaming queries + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Left InputRight InputJoin Type
    StaticStaticAll types + Supported, since its not on streaming data even though it + can be present in a streaming query +
    StreamStaticInnerSupported, not stateful
    Left OuterSupported, not stateful
    Right OuterNot supported
    Full OuterNot supported
    StaticStreamInnerSupported, not stateful
    Left OuterNot supported
    Right OuterSupported, not stateful
    Full OuterNot supported
    StreamStreamInner + Supported, optionally specify watermark on both sides + + time constraints for state cleanup +
    Left Outer + Conditionally supported, must specify watermark on right + time constraints for correct + results, optionally specify watermark on left for all state cleanup +
    Right Outer + Conditionally supported, must specify watermark on left + time constraints for correct + results, optionally specify watermark on right for all state cleanup +
    Full OuterNot supported
    + +Additional details on supported joins: + +- Joins can be cascaded, that is, you can do `df1.join(df2, ...).join(df3, ...).join(df4, ....)`. + +- As of Spark 2.3, you can use joins only when the query is in Append output mode. Other output modes are not yet supported. + +- As of Spark 2.3, you cannot use other non-map-like operations before joins. Here are a few examples of + what cannot be used. + + - Cannot use streaming aggregations before joins. + + - Cannot use mapGroupsWithState and flatMapGroupsWithState in Update mode before joins. + + ### Streaming Deduplication You can deduplicate records in data streams using a unique identifier in the events. This is exactly same as deduplication on static using a unique identifier column. The query will store the necessary amount of data from previous records such that it can filter duplicate records. Similar to aggregations, you can use deduplication with or without watermarking. @@ -1160,15 +1466,9 @@ Some of them are as follows. - Sorting operations are supported on streaming Datasets only after an aggregation and in Complete Output Mode. -- Outer joins between a streaming and a static Datasets are conditionally supported. - - + Full outer join with a streaming Dataset is not supported - - + Left outer join with a streaming Dataset on the right is not supported - - + Right outer join with a streaming Dataset on the left is not supported - -- Any kind of joins between two streaming Datasets is not yet supported. +- Few types of outer joins on streaming Datasets are not supported. See the + support matrix in the Join Operations section + for more details. In addition, there are some Dataset methods that will not work on streaming Datasets. They are actions that will immediately run queries and return results, which does not make sense on a streaming Dataset. Rather, those functionalities can be done by explicitly starting a streaming query (see the next section regarding that). @@ -1276,6 +1576,15 @@ Here is the compatibility matrix. Aggregations not allowed after flatMapGroupsWithState. + + Queries with joins + Append + + Update and Complete mode not supported yet. See the + support matrix in the Join Operations section + for more details on what types of joins are supported. + + Other queries Append, Update @@ -2142,6 +2451,11 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat **Talks** -- Spark Summit 2017 Talk - [Easy, Scalable, Fault-tolerant Stream Processing with Structured Streaming in Apache Spark](https://spark-summit.org/2017/events/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark/) -- Spark Summit 2016 Talk - [A Deep Dive into Structured Streaming](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) +- Spark Summit Europe 2017 + - Easy, Scalable, Fault-tolerant Stream Processing with Structured Streaming in Apache Spark - + [Part 1 slides/video](https://databricks.com/session/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark), [Part 2 slides/video](https://databricks.com/session/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark-continues) + - Deep Dive into Stateful Stream Processing in Structured Streaming - [slides/video](https://databricks.com/session/deep-dive-into-stateful-stream-processing-in-structured-streaming) +- Spark Summit 2016 + - A Deep Dive into Structured Streaming - [slides/video](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) + From 02194702068291b3af77486d01029fb848c36d7b Mon Sep 17 00:00:00 2001 From: Xiayun Sun Date: Wed, 17 Jan 2018 16:42:38 -0800 Subject: [PATCH 126/774] [SPARK-21996][SQL] read files with space in name for streaming ## What changes were proposed in this pull request? Structured streaming is now able to read files with space in file name (previously it would skip the file and output a warning) ## How was this patch tested? Added new unit test. Author: Xiayun Sun Closes #19247 from xysun/SPARK-21996. --- .../streaming/FileStreamSource.scala | 2 +- .../sql/streaming/FileStreamSourceSuite.scala | 50 ++++++++++++++++++- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 0debd7db84757..8c016abc5b643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -166,7 +166,7 @@ class FileStreamSource( val newDataSource = DataSource( sparkSession, - paths = files.map(_.path), + paths = files.map(f => new Path(new URI(f.path)).toString), userSpecifiedSchema = Some(schema), partitionColumns = partitionColumns, className = fileFormatClassName, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 39bb572740617..5bb0f4d643bbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -74,11 +74,11 @@ abstract class FileStreamSourceTest protected def addData(source: FileStreamSource): Unit } - case class AddTextFileData(content: String, src: File, tmp: File) + case class AddTextFileData(content: String, src: File, tmp: File, tmpFilePrefix: String = "text") extends AddFileData { override def addData(source: FileStreamSource): Unit = { - val tempFile = Utils.tempFileWith(new File(tmp, "text")) + val tempFile = Utils.tempFileWith(new File(tmp, tmpFilePrefix)) val finalFile = new File(src, tempFile.getName) src.mkdirs() require(stringToFile(tempFile, content).renameTo(finalFile)) @@ -408,6 +408,52 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("SPARK-21996 read from text files -- file name has space") { + withTempDirs { case (src, tmp) => + val textStream = createFileStream("text", src.getCanonicalPath) + val filtered = textStream.filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData("drop1\nkeep2\nkeep3", src, tmp, "text text"), + CheckAnswer("keep2", "keep3") + ) + } + } + + test("SPARK-21996 read from text files generated by file sink -- file name has space") { + val testTableName = "FileStreamSourceTest" + withTable(testTableName) { + withTempDirs { case (src, checkpoint) => + val output = new File(src, "text text") + val inputData = MemoryStream[String] + val ds = inputData.toDS() + + val query = ds.writeStream + .option("checkpointLocation", checkpoint.getCanonicalPath) + .format("text") + .start(output.getCanonicalPath) + + try { + inputData.addData("foo") + failAfter(streamingTimeout) { + query.processAllAvailable() + } + } finally { + query.stop() + } + + val df2 = spark.readStream.format("text").load(output.getCanonicalPath) + val query2 = df2.writeStream.format("memory").queryName(testTableName).start() + try { + query2.processAllAvailable() + checkDatasetUnorderly(spark.table(testTableName).as[String], "foo") + } finally { + query2.stop() + } + } + } + } + test("read from textfile") { withTempDirs { case (src, tmp) => val textStream = spark.readStream.textFile(src.getCanonicalPath) From 39d244d921d8d2d3ed741e8e8f1175515a74bdbd Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 18 Jan 2018 14:51:05 +0900 Subject: [PATCH 127/774] [SPARK-23122][PYTHON][SQL] Deprecate register* for UDFs in SQLContext and Catalog in PySpark ## What changes were proposed in this pull request? This PR proposes to deprecate `register*` for UDFs in `SQLContext` and `Catalog` in Spark 2.3.0. These are inconsistent with Scala / Java APIs and also these basically do the same things with `spark.udf.register*`. Also, this PR moves the logcis from `[sqlContext|spark.catalog].register*` to `spark.udf.register*` and reuse the docstring. This PR also handles minor doc corrections. It also includes https://github.com/apache/spark/pull/20158 ## How was this patch tested? Manually tested, manually checked the API documentation and tests added to check if deprecated APIs call the aliases correctly. Author: hyukjinkwon Closes #20288 from HyukjinKwon/deprecate-udf. --- dev/sparktestsupport/modules.py | 1 + python/pyspark/sql/catalog.py | 91 ++-------------- python/pyspark/sql/context.py | 137 ++++-------------------- python/pyspark/sql/functions.py | 4 +- python/pyspark/sql/group.py | 3 +- python/pyspark/sql/session.py | 6 +- python/pyspark/sql/tests.py | 20 ++++ python/pyspark/sql/udf.py | 182 +++++++++++++++++++++++++++++++- 8 files changed, 234 insertions(+), 210 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 7164180a6a7b0..b900f0bd913c3 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -400,6 +400,7 @@ def __hash__(self): "pyspark.sql.functions", "pyspark.sql.readwriter", "pyspark.sql.streaming", + "pyspark.sql.udf", "pyspark.sql.window", "pyspark.sql.tests", ] diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 35fbe9e669adb..6aef0f22340be 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -224,92 +224,17 @@ def dropGlobalTempView(self, viewName): """ self._jcatalog.dropGlobalTempView(viewName) - @ignore_unicode_prefix @since(2.0) def registerFunction(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statements. - - :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`. - - In addition to a name and the function itself, `returnType` can be optionally specified. - 1) When f is a Python function, `returnType` defaults to a string. The produced object must - match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return - type of the given UDF as the return type of the registered UDF. The input parameter - `returnType` is None by default. If given by users, the value must be None. - - :param name: name of the UDF in SQL statements. - :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either - row-at-a-time or vectorized. - :param returnType: the return type of the registered UDF. - :return: a wrapped/native :class:`UserDefinedFunction` - - >>> strlen = spark.catalog.registerFunction("stringLengthString", len) - >>> spark.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> from pyspark.sql.functions import udf - >>> slen = udf(lambda s: len(s), IntegerType()) - >>> _ = spark.udf.register("slen", slen) - >>> spark.sql("SELECT slen('test')").collect() - [Row(slen(test)=4)] - - >>> import random - >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType - >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf) - >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=82)] - >>> spark.range(1).select(new_random_udf()).collect() # doctest: +SKIP - [Row(()=26)] - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP - >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP - [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] - """ + """An alias for :func:`spark.udf.register`. + See :meth:`pyspark.sql.UDFRegistration.register`. - # This is to check whether the input function is a wrapped/native UserDefinedFunction - if hasattr(f, 'asNondeterministic'): - if returnType is not None: - raise TypeError( - "Invalid returnType: None is expected when f is a UserDefinedFunction, " - "but got %s." % returnType) - if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_PANDAS_SCALAR_UDF]: - raise ValueError( - "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") - register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, - evalType=f.evalType, - deterministic=f.deterministic) - return_udf = f - else: - if returnType is None: - returnType = StringType() - register_udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) - return_udf = register_udf._wrapped() - self._jsparkSession.udf().registerPython(name, register_udf._judf) - return return_udf + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. + """ + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.register instead.", + DeprecationWarning) + return self._sparkSession.udf.register(name, f, returnType) @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 85479095af594..cc1cd1a5842d9 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -29,9 +29,10 @@ from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader from pyspark.sql.types import IntegerType, Row, StringType +from pyspark.sql.udf import UDFRegistration from pyspark.sql.utils import install_exception_handler -__all__ = ["SQLContext", "HiveContext", "UDFRegistration"] +__all__ = ["SQLContext", "HiveContext"] class SQLContext(object): @@ -147,7 +148,7 @@ def udf(self): :return: :class:`UDFRegistration` """ - return UDFRegistration(self) + return self.sparkSession.udf @since(1.4) def range(self, start, end=None, step=1, numPartitions=None): @@ -172,113 +173,29 @@ def range(self, start, end=None, step=1, numPartitions=None): """ return self.sparkSession.range(start, end, step, numPartitions) - @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statements. - - :func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`. - - In addition to a name and the function itself, `returnType` can be optionally specified. - 1) When f is a Python function, `returnType` defaults to a string. The produced object must - match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return - type of the given UDF as the return type of the registered UDF. The input parameter - `returnType` is None by default. If given by users, the value must be None. - - :param name: name of the UDF in SQL statements. - :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either - row-at-a-time or vectorized. - :param returnType: the return type of the registered UDF. - :return: a wrapped/native :class:`UserDefinedFunction` - - >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlContext.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] - - >>> from pyspark.sql.types import IntegerType - >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> from pyspark.sql.functions import udf - >>> slen = udf(lambda s: len(s), IntegerType()) - >>> _ = sqlContext.udf.register("slen", slen) - >>> sqlContext.sql("SELECT slen('test')").collect() - [Row(slen(test)=4)] - - >>> import random - >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType - >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf) - >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=82)] - >>> sqlContext.range(1).select(new_random_udf()).collect() # doctest: +SKIP - [Row(()=26)] - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> _ = sqlContext.udf.register("add_one", add_one) # doctest: +SKIP - >>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP - [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] + """An alias for :func:`spark.udf.register`. + See :meth:`pyspark.sql.UDFRegistration.register`. + + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. """ - return self.sparkSession.catalog.registerFunction(name, f, returnType) + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.register instead.", + DeprecationWarning) + return self.sparkSession.udf.register(name, f, returnType) - @ignore_unicode_prefix @since(2.1) def registerJavaFunction(self, name, javaClassName, returnType=None): - """Register a java UDF so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not specified we would infer it via reflection. - :param name: name of the UDF - :param javaClassName: fully qualified name of java class - :param returnType: a :class:`pyspark.sql.types.DataType` object - - >>> sqlContext.registerJavaFunction("javaStringLength", - ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) - >>> sqlContext.sql("SELECT javaStringLength('test')").collect() - [Row(UDF:javaStringLength(test)=4)] - >>> sqlContext.registerJavaFunction("javaStringLength2", - ... "test.org.apache.spark.sql.JavaStringLength") - >>> sqlContext.sql("SELECT javaStringLength2('test')").collect() - [Row(UDF:javaStringLength2(test)=4)] + """An alias for :func:`spark.udf.registerJavaFunction`. + See :meth:`pyspark.sql.UDFRegistration.registerJavaFunction`. + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead. """ - jdt = None - if returnType is not None: - jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) - self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) - - @ignore_unicode_prefix - @since(2.3) - def registerJavaUDAF(self, name, javaClassName): - """Register a java UDAF so it can be used in SQL statements. - - :param name: name of the UDAF - :param javaClassName: fully qualified name of java class - - >>> sqlContext.registerJavaUDAF("javaUDAF", - ... "test.org.apache.spark.sql.MyDoubleAvg") - >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) - >>> df.registerTempTable("df") - >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() - [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] - """ - self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.", + DeprecationWarning) + return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType) # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): @@ -590,24 +507,6 @@ def refreshTable(self, tableName): self._ssql_ctx.refreshTable(tableName) -class UDFRegistration(object): - """Wrapper for user-defined function registration.""" - - def __init__(self, sqlContext): - self.sqlContext = sqlContext - - def register(self, name, f, returnType=None): - return self.sqlContext.registerFunction(name, f, returnType) - - def registerJavaFunction(self, name, javaClassName, returnType=None): - self.sqlContext.registerJavaFunction(name, javaClassName, returnType) - - def registerJavaUDAF(self, name, javaClassName): - self.sqlContext.registerJavaUDAF(name, javaClassName) - - register.__doc__ = SQLContext.registerFunction.__doc__ - - def _test(): import os import doctest diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f7b3f29764040..988c1d25259bc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2103,7 +2103,7 @@ def udf(f=None, returnType=StringType()): >>> import random >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() - .. note:: The user-defined functions do not support conditional expressions or short curcuiting + .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. @@ -2231,7 +2231,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): ... return pd.Series(np.random.randn(len(v)) >>> random = random.asNondeterministic() # doctest: +SKIP - .. note:: The user-defined functions do not support conditional expressions or short curcuiting + .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. """ diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 09fae46adf014..22061b83eb78c 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -212,7 +212,8 @@ def apply(self, udf): This function does not support partial aggregation, and requires shuffling all the data in the :class:`DataFrame`. - :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` + :param udf: a group map user-defined function returned by + :meth:`pyspark.sql.functions.pandas_udf`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 604021c1f45cc..6c84023c43fb6 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -29,7 +29,6 @@ from pyspark import since from pyspark.rdd import RDD, ignore_unicode_prefix -from pyspark.sql.catalog import Catalog from pyspark.sql.conf import RuntimeConfig from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader @@ -280,6 +279,7 @@ def catalog(self): :return: :class:`Catalog` """ + from pyspark.sql.catalog import Catalog if not hasattr(self, "_catalog"): self._catalog = Catalog(self) return self._catalog @@ -291,8 +291,8 @@ def udf(self): :return: :class:`UDFRegistration` """ - from pyspark.sql.context import UDFRegistration - return UDFRegistration(self._wrapped) + from pyspark.sql.udf import UDFRegistration + return UDFRegistration(self) @since(2.0) def range(self, start, end=None, step=1, numPartitions=None): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8906618666b14..f84aa3d68b808 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -372,6 +372,12 @@ def test_udf(self): [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias. + sqlContext = self.spark._wrapped + sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType()) + [row] = sqlContext.sql("SELECT oneArg('test')").collect() + self.assertEqual(row[0], 4) + def test_udf2(self): self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType()) self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\ @@ -577,11 +583,25 @@ def test_udf_registration_returns_udf(self): df.select(add_three("id").alias("plus_three")).collect() ) + # This is to check if a 'SQLContext.udf' can call its alias. + sqlContext = self.spark._wrapped + add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_four(id) AS plus_four").collect(), + df.select(add_four("id").alias("plus_four")).collect() + ) + def test_non_existed_udf(self): spark = self.spark self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) + # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias. + sqlContext = spark._wrapped + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf")) + def test_non_existed_udaf(self): spark = self.spark self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 5e80ab9165867..1943bb73f9ac2 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -19,11 +19,13 @@ """ import functools -from pyspark import SparkContext -from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType +from pyspark import SparkContext, since +from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string +__all__ = ["UDFRegistration"] + def _wrap_function(sc, func, returnType): command = (func, returnType) @@ -181,3 +183,179 @@ def asNondeterministic(self): """ self.deterministic = False return self + + +class UDFRegistration(object): + """ + Wrapper for user-defined function registration. This instance can be accessed by + :attr:`spark.udf` or :attr:`sqlContext.udf`. + + .. versionadded:: 1.3.1 + """ + + def __init__(self, sparkSession): + self.sparkSession = sparkSession + + @ignore_unicode_prefix + @since("1.3.1") + def register(self, name, f, returnType=None): + """Registers a Python function (including lambda function) or a user-defined function + in SQL statements. + + :param name: name of the user-defined function in SQL statements. + :param f: a Python function, or a user-defined function. The user-defined function can + be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and + :meth:`pyspark.sql.functions.pandas_udf`. + :param returnType: the return type of the registered user-defined function. + :return: a user-defined function. + + `returnType` can be optionally specified when `f` is a Python function but not + when `f` is a user-defined function. Please see below. + + 1. When `f` is a Python function: + + `returnType` defaults to string type and can be optionally specified. The produced + object must match the specified type. In this case, this API works as if + `register(name, f, returnType=StringType())`. + + >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x)) + >>> spark.sql("SELECT stringLengthString('test')").collect() + [Row(stringLengthString(test)=u'4')] + + >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + 2. When `f` is a user-defined function: + + Spark uses the return type of the given user-defined function as the return type of + the registered user-defined function. `returnType` should not be specified. + In this case, this API works as if `register(name, f)`. + + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = spark.udf.register("slen", slen) + >>> spark.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> new_random_udf = spark.udf.register("random_udf", random_udf) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=82)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP + >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] + + .. note:: Registration for a user-defined function (case 2.) was added from + Spark 2.3.0. + """ + + # This is to check whether the input function is from a user-defined function or + # Python function. + if hasattr(f, 'asNondeterministic'): + if returnType is not None: + raise TypeError( + "Invalid returnType: data type can not be specified when f is" + "a user-defined function, but got %s." % returnType) + if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF]: + raise ValueError( + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") + register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, + evalType=f.evalType, + deterministic=f.deterministic) + return_udf = f + else: + if returnType is None: + returnType = StringType() + register_udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) + return_udf = register_udf._wrapped() + self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf) + return return_udf + + @ignore_unicode_prefix + @since(2.3) + def registerJavaFunction(self, name, javaClassName, returnType=None): + """Register a Java user-defined function so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not specified we would infer it via reflection. + + :param name: name of the user-defined function + :param javaClassName: fully qualified name of java class + :param returnType: a :class:`pyspark.sql.types.DataType` object + + >>> from pyspark.sql.types import IntegerType + >>> spark.udf.registerJavaFunction( + ... "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType()) + >>> spark.sql("SELECT javaStringLength('test')").collect() + [Row(UDF:javaStringLength(test)=4)] + >>> spark.udf.registerJavaFunction( + ... "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength") + >>> spark.sql("SELECT javaStringLength2('test')").collect() + [Row(UDF:javaStringLength2(test)=4)] + """ + + jdt = None + if returnType is not None: + jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) + self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + + @ignore_unicode_prefix + @since(2.3) + def registerJavaUDAF(self, name, javaClassName): + """Register a Java user-defined aggregate function so it can be used in SQL statements. + + :param name: name of the user-defined aggregate function + :param javaClassName: fully qualified name of java class + + >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg") + >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) + >>> df.registerTempTable("df") + >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() + [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] + """ + + self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + + +def _test(): + import doctest + from pyspark.sql import SparkSession + import pyspark.sql.udf + globs = pyspark.sql.udf.__dict__.copy() + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.udf tests")\ + .getOrCreate() + globs['spark'] = spark + (failure_count, test_count) = doctest.testmod( + pyspark.sql.udf, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + spark.stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() From 1c76a91e5fae11dcb66c453889e587b48039fdc9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 17 Jan 2018 22:36:29 -0800 Subject: [PATCH 128/774] [SPARK-23052][SS] Migrate ConsoleSink to data source V2 api. ## What changes were proposed in this pull request? Migrate ConsoleSink to data source V2 api. Note that this includes a missing piece in DataStreamWriter required to specify a data source V2 writer. Note also that I've removed the "Rerun batch" part of the sink, because as far as I can tell this would never have actually happened. A MicroBatchExecution object will only commit each batch once for its lifetime, and a new MicroBatchExecution object would have a new ConsoleSink object which doesn't know it's retrying a batch. So I think this represents an anti-feature rather than a weakness in the V2 API. ## How was this patch tested? new unit test Author: Jose Torres Closes #20243 from jose-torres/console-sink. --- .../streaming/MicroBatchExecution.scala | 7 +- .../sql/execution/streaming/console.scala | 62 ++--- .../continuous/ContinuousExecution.scala | 9 +- .../streaming/sources/ConsoleWriter.scala | 64 +++++ .../sources/PackedRowWriterFactory.scala | 60 +++++ .../sql/streaming/DataStreamWriter.scala | 16 +- ...pache.spark.sql.sources.DataSourceRegister | 8 + .../sources/ConsoleWriterSuite.scala | 135 ++++++++++ .../sources/StreamingDataSourceV2Suite.scala | 249 ++++++++++++++++++ .../test/DataStreamReaderWriterSuite.scala | 25 -- 10 files changed, 551 insertions(+), 84 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 70407f0580f97..7c3804547b736 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -91,11 +91,14 @@ class MicroBatchExecution( nextSourceId += 1 StreamingExecutionRelation(reader, output)(sparkSession) }) - case s @ StreamingRelationV2(_, _, _, output, v1Relation) => + case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - assert(v1Relation.isDefined, "v2 execution didn't match but v1 was unavailable") + if (v1Relation.isEmpty) { + throw new UnsupportedOperationException( + s"Data source $sourceName does not support microbatch processing.") + } val source = v1Relation.get.dataSource.createSource(metadataPath) nextSourceId += 1 StreamingExecutionRelation(source, output)(sparkSession) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 71eaabe273fea..94820376ff7e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -17,58 +17,36 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.StructType - -class ConsoleSink(options: Map[String, String]) extends Sink with Logging { - // Number of rows to display, by default 20 rows - private val numRowsToShow = options.get("numRows").map(_.toInt).getOrElse(20) - - // Truncate the displayed data if it is too long, by default it is true - private val isTruncated = options.get("truncate").map(_.toBoolean).getOrElse(true) +import java.util.Optional - // Track the batch id - private var lastBatchId = -1L - - override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { - val batchIdStr = if (batchId <= lastBatchId) { - s"Rerun batch: $batchId" - } else { - lastBatchId = batchId - s"Batch: $batchId" - } - - // scalastyle:off println - println("-------------------------------------------") - println(batchIdStr) - println("-------------------------------------------") - // scalastyle:off println - data.sparkSession.createDataFrame( - data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) - .show(numRowsToShow, isTruncated) - } +import scala.collection.JavaConverters._ - override def toString(): String = s"ConsoleSink[numRows=$numRowsToShow, truncate=$isTruncated]" -} +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) extends BaseRelation { override def schema: StructType = data.schema } -class ConsoleSinkProvider extends StreamSinkProvider +class ConsoleSinkProvider extends DataSourceV2 + with MicroBatchWriteSupport with DataSourceRegister with CreatableRelationProvider { - def createSink( - sqlContext: SQLContext, - parameters: Map[String, String], - partitionColumns: Seq[String], - outputMode: OutputMode): Sink = { - new ConsoleSink(parameters) + + override def createMicroBatchWriter( + queryId: String, + epochId: Long, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + Optional.of(new ConsoleWriter(epochId, schema, options)) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index c0507224f9be8..462e7d9721d28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -54,16 +54,13 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = _ + @volatile protected var continuousSources: Seq[ContinuousReader] = Seq() override protected def sources: Seq[BaseStreamingSource] = continuousSources // For use only in test harnesses. private[sql] var currentEpochCoordinatorId: String = _ - override lazy val logicalPlan: LogicalPlan = { - assert(queryExecutionThread eq Thread.currentThread, - "logicalPlan must be initialized in StreamExecutionThread " + - s"but the current thread was ${Thread.currentThread}") + override val logicalPlan: LogicalPlan = { val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() analyzedPlan.transform { case r @ StreamingRelationV2( @@ -72,7 +69,7 @@ class ContinuousExecution( ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) }) case StreamingRelationV2(_, sourceName, _, _, _) => - throw new AnalysisException( + throw new UnsupportedOperationException( s"Data source $sourceName does not support continuous processing.") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala new file mode 100644 index 0000000000000..361979984bbec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.types.StructType + +/** + * A [[DataSourceV2Writer]] that collects results to the driver and prints them in the console. + * Generated by [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. + * + * This sink should not be used for production, as it requires sending all rows to the driver + * and does not support recovery. + */ +class ConsoleWriter(batchId: Long, schema: StructType, options: DataSourceV2Options) + extends DataSourceV2Writer with Logging { + // Number of rows to display, by default 20 rows + private val numRowsToShow = options.getInt("numRows", 20) + + // Truncate the displayed data if it is too long, by default it is true + private val isTruncated = options.getBoolean("truncate", true) + + assert(SparkSession.getActiveSession.isDefined) + private val spark = SparkSession.getActiveSession.get + + override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory + + override def commit(messages: Array[WriterCommitMessage]): Unit = synchronized { + val batch = messages.collect { + case PackedRowCommitMessage(rows) => rows + }.flatten + + // scalastyle:off println + println("-------------------------------------------") + println(s"Batch: $batchId") + println("-------------------------------------------") + // scalastyle:off println + spark.createDataFrame( + spark.sparkContext.parallelize(batch), schema) + .show(numRowsToShow, isTruncated) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = {} + + override def toString(): String = s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala new file mode 100644 index 0000000000000..9282ba05bdb7b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} + +/** + * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery + * to a [[org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer]] on the driver. + * + * Note that, because it sends all rows to the driver, this factory will generally be unsuitable + * for production-quality sinks. It's intended for use in tests. + */ +case object PackedRowWriterFactory extends DataWriterFactory[Row] { + def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + new PackedRowDataWriter() + } +} + +/** + * Commit message for a [[PackedRowDataWriter]], containing all the rows written in the most + * recent interval. + */ +case class PackedRowCommitMessage(rows: Array[Row]) extends WriterCommitMessage + +/** + * A simple [[DataWriter]] that just sends all the rows it's received as a commit message. + */ +class PackedRowDataWriter() extends DataWriter[Row] with Logging { + private val data = mutable.Buffer[Row]() + + override def write(row: Row): Unit = data.append(row) + + override def commit(): PackedRowCommitMessage = { + val msg = PackedRowCommitMessage(data.toArray) + data.clear() + msg + } + + override def abort(): Unit = data.clear() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index b5b4a05ab4973..d24f0ddeab4de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -280,14 +280,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val sink = trigger match { - case _: ContinuousTrigger => - val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) - ds.newInstance() match { - case w: ContinuousWriteSupport => w - case _ => throw new AnalysisException( - s"Data source $source does not support continuous writing") - } + val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + val sink = (ds.newInstance(), trigger) match { + case (w: ContinuousWriteSupport, _: ContinuousTrigger) => w + case (_, _: ContinuousTrigger) => throw new UnsupportedOperationException( + s"Data source $source does not support continuous writing") + case (w: MicroBatchWriteSupport, _) => w case _ => val ds = DataSource( df.sparkSession, diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index c6973bf41d34b..a0b25b4e82364 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,3 +5,11 @@ org.apache.spark.sql.sources.FakeSourceFour org.apache.fakesource.FakeExternalSourceOne org.apache.fakesource.FakeExternalSourceTwo org.apache.fakesource.FakeExternalSourceThree +org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly +org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly +org.apache.spark.sql.streaming.sources.FakeReadBothModes +org.apache.spark.sql.streaming.sources.FakeReadNeitherMode +org.apache.spark.sql.streaming.sources.FakeWriteMicroBatchOnly +org.apache.spark.sql.streaming.sources.FakeWriteContinuousOnly +org.apache.spark.sql.streaming.sources.FakeWriteBothModes +org.apache.spark.sql.streaming.sources.FakeWriteNeitherMode diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala new file mode 100644 index 0000000000000..60ffee9b9b42c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io.ByteArrayOutputStream + +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming.StreamTest + +class ConsoleWriterSuite extends StreamTest { + import testImplicits._ + + test("console") { + val input = MemoryStream[Int] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").start() + try { + input.addData(1, 2, 3) + query.processAllAvailable() + input.addData(4, 5, 6) + query.processAllAvailable() + input.addData() + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 1| + || 2| + || 3| + |+-----+ + | + |------------------------------------------- + |Batch: 1 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 4| + || 5| + || 6| + |+-----+ + | + |------------------------------------------- + |Batch: 2 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + |+-----+ + | + |""".stripMargin) + } + + test("console with numRows") { + val input = MemoryStream[Int] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").option("NUMROWS", 2).start() + try { + input.addData(1, 2, 3) + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 1| + || 2| + |+-----+ + |only showing top 2 rows + | + |""".stripMargin) + } + + test("console with truncation") { + val input = MemoryStream[String] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").option("TRUNCATE", true).start() + try { + input.addData("123456789012345678901234567890") + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+--------------------+ + || value| + |+--------------------+ + ||12345678901234567...| + |+--------------------+ + | + |""".stripMargin) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala new file mode 100644 index 0000000000000..f152174b0a7f0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.sources + +import java.util.Optional + +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.{LongOffset, RateStreamOffset} +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.reader.ReadTask +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport, MicroBatchReadSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer +import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest, Trigger} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +case class FakeReader() extends MicroBatchReader with ContinuousReader { + def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {} + def getStartOffset: Offset = RateStreamOffset(Map()) + def getEndOffset: Offset = RateStreamOffset(Map()) + def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) + def commit(end: Offset): Unit = {} + def readSchema(): StructType = StructType(Seq()) + def stop(): Unit = {} + def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) + def setOffset(start: Optional[Offset]): Unit = {} + + def createReadTasks(): java.util.ArrayList[ReadTask[Row]] = { + throw new IllegalStateException("fake source - cannot actually read") + } +} + +trait FakeMicroBatchReadSupport extends MicroBatchReadSupport { + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): MicroBatchReader = FakeReader() +} + +trait FakeContinuousReadSupport extends ContinuousReadSupport { + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): ContinuousReader = FakeReader() +} + +trait FakeMicroBatchWriteSupport extends MicroBatchWriteSupport { + def createMicroBatchWriter( + queryId: String, + epochId: Long, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + throw new IllegalStateException("fake sink - cannot actually write") + } +} + +trait FakeContinuousWriteSupport extends ContinuousWriteSupport { + def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + throw new IllegalStateException("fake sink - cannot actually write") + } +} + +class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupport { + override def shortName(): String = "fake-read-microbatch-only" +} + +class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupport { + override def shortName(): String = "fake-read-continuous-only" +} + +class FakeReadBothModes extends DataSourceRegister + with FakeMicroBatchReadSupport with FakeContinuousReadSupport { + override def shortName(): String = "fake-read-microbatch-continuous" +} + +class FakeReadNeitherMode extends DataSourceRegister { + override def shortName(): String = "fake-read-neither-mode" +} + +class FakeWriteMicroBatchOnly extends DataSourceRegister with FakeMicroBatchWriteSupport { + override def shortName(): String = "fake-write-microbatch-only" +} + +class FakeWriteContinuousOnly extends DataSourceRegister with FakeContinuousWriteSupport { + override def shortName(): String = "fake-write-continuous-only" +} + +class FakeWriteBothModes extends DataSourceRegister + with FakeMicroBatchWriteSupport with FakeContinuousWriteSupport { + override def shortName(): String = "fake-write-microbatch-continuous" +} + +class FakeWriteNeitherMode extends DataSourceRegister { + override def shortName(): String = "fake-write-neither-mode" +} + +class StreamingDataSourceV2Suite extends StreamTest { + + override def beforeAll(): Unit = { + super.beforeAll() + val fakeCheckpoint = Utils.createTempDir() + spark.conf.set("spark.sql.streaming.checkpointLocation", fakeCheckpoint.getCanonicalPath) + } + + val readFormats = Seq( + "fake-read-microbatch-only", + "fake-read-continuous-only", + "fake-read-microbatch-continuous", + "fake-read-neither-mode") + val writeFormats = Seq( + "fake-write-microbatch-only", + "fake-write-continuous-only", + "fake-write-microbatch-continuous", + "fake-write-neither-mode") + val triggers = Seq( + Trigger.Once(), + Trigger.ProcessingTime(1000), + Trigger.Continuous(1000)) + + private def testPositiveCase(readFormat: String, writeFormat: String, trigger: Trigger) = { + val query = spark.readStream + .format(readFormat) + .load() + .writeStream + .format(writeFormat) + .trigger(trigger) + .start() + query.stop() + } + + private def testNegativeCase( + readFormat: String, + writeFormat: String, + trigger: Trigger, + errorMsg: String) = { + val ex = intercept[UnsupportedOperationException] { + testPositiveCase(readFormat, writeFormat, trigger) + } + assert(ex.getMessage.contains(errorMsg)) + } + + private def testPostCreationNegativeCase( + readFormat: String, + writeFormat: String, + trigger: Trigger, + errorMsg: String) = { + val query = spark.readStream + .format(readFormat) + .load() + .writeStream + .format(writeFormat) + .trigger(trigger) + .start() + + eventually(timeout(streamingTimeout)) { + assert(query.exception.isDefined) + assert(query.exception.get.cause != null) + assert(query.exception.get.cause.getMessage.contains(errorMsg)) + } + } + + // Get a list of (read, write, trigger) tuples for test cases. + val cases = readFormats.flatMap { read => + writeFormats.flatMap { write => + triggers.map(t => (write, t)) + }.map { + case (write, t) => (read, write, t) + } + } + + for ((read, write, trigger) <- cases) { + testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { + val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf).newInstance() + val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() + (readSource, writeSource, trigger) match { + // Valid microbatch queries. + case (_: MicroBatchReadSupport, _: MicroBatchWriteSupport, t) + if !t.isInstanceOf[ContinuousTrigger] => + testPositiveCase(read, write, trigger) + + // Valid continuous queries. + case (_: ContinuousReadSupport, _: ContinuousWriteSupport, _: ContinuousTrigger) => + testPositiveCase(read, write, trigger) + + // Invalid - can't read at all + case (r, _, _) + if !r.isInstanceOf[MicroBatchReadSupport] + && !r.isInstanceOf[ContinuousReadSupport] => + testNegativeCase(read, write, trigger, + s"Data source $read does not support streamed reading") + + // Invalid - trigger is continuous but writer is not + case (_, w, _: ContinuousTrigger) if !w.isInstanceOf[ContinuousWriteSupport] => + testNegativeCase(read, write, trigger, + s"Data source $write does not support continuous writing") + + // Invalid - can't write at all + case (_, w, _) + if !w.isInstanceOf[MicroBatchWriteSupport] + && !w.isInstanceOf[ContinuousWriteSupport] => + testNegativeCase(read, write, trigger, + s"Data source $write does not support streamed writing") + + // Invalid - trigger and writer are continuous but reader is not + case (r, _: ContinuousWriteSupport, _: ContinuousTrigger) + if !r.isInstanceOf[ContinuousReadSupport] => + testNegativeCase(read, write, trigger, + s"Data source $read does not support continuous processing") + + // Invalid - trigger is microbatch but writer is not + case (_, w, t) + if !w.isInstanceOf[MicroBatchWriteSupport] && !t.isInstanceOf[ContinuousTrigger] => + testNegativeCase(read, write, trigger, + s"Data source $write does not support streamed writing") + + // Invalid - trigger and writer are microbatch but reader is not + case (r, _, t) + if !r.isInstanceOf[MicroBatchReadSupport] && !t.isInstanceOf[ContinuousTrigger] => + testPostCreationNegativeCase(read, write, trigger, + s"Data source $read does not support microbatch processing") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index aa163d2211c38..8212fb912ec57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -422,21 +422,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } } - test("ConsoleSink can be correctly loaded") { - LastOptions.clear() - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - - val sq = df.writeStream - .format("console") - .option("checkpointLocation", newMetadataDir) - .trigger(ProcessingTime(2.seconds)) - .start() - - sq.awaitTermination(2000L) - } - test("prevent all column partitioning") { withTempDir { dir => val path = dir.getCanonicalPath @@ -450,16 +435,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } } - test("ConsoleSink should not require checkpointLocation") { - LastOptions.clear() - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - - val sq = df.writeStream.format("console").start() - sq.stop() - } - private def testMemorySinkCheckpointRecovery(chkLoc: String, provideInWriter: Boolean): Unit = { import testImplicits._ val ms = new MemoryStream[Int](0, sqlContext) From 7a2248341396840628eef398aa512cac3e3bd55f Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 18 Jan 2018 19:18:55 +0800 Subject: [PATCH 129/774] [SPARK-23140][SQL] Add DataSourceV2Strategy to Hive Session state's planner ## What changes were proposed in this pull request? `DataSourceV2Strategy` is missing in `HiveSessionStateBuilder`'s planner, which will throw exception as described in [SPARK-23140](https://issues.apache.org/jira/browse/SPARK-23140). ## How was this patch tested? Manual test. Author: jerryshao Closes #20305 from jerryshao/SPARK-23140. --- .../sql/hive/HiveSessionStateBuilder.scala | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index dc92ad3b0c1ac..12c74368dd184 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -96,22 +96,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session override val sparkSession: SparkSession = session override def extraPlanningStrategies: Seq[Strategy] = - super.extraPlanningStrategies ++ customPlanningStrategies - - override def strategies: Seq[Strategy] = { - experimentalMethods.extraStrategies ++ - extraPlanningStrategies ++ Seq( - FileSourceStrategy, - DataSourceStrategy(conf), - SpecialLimits, - InMemoryScans, - HiveTableScans, - Scripts, - Aggregation, - JoinSelection, - BasicOperators - ) - } + super.extraPlanningStrategies ++ customPlanningStrategies ++ Seq(HiveTableScans, Scripts) } } From e28eb431146bcdcaf02a6f6c406ca30920592a6a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 18 Jan 2018 21:24:39 +0800 Subject: [PATCH 130/774] [SPARK-22036][SQL] Decimal multiplication with high precision/scale often returns NULL ## What changes were proposed in this pull request? When there is an operation between Decimals and the result is a number which is not representable exactly with the result's precision and scale, Spark is returning `NULL`. This was done to reflect Hive's behavior, but it is against SQL ANSI 2011, which states that "If the result cannot be represented exactly in the result type, then whether it is rounded or truncated is implementation-defined". Moreover, Hive now changed its behavior in order to respect the standard, thanks to HIVE-15331. Therefore, the PR propose to: - update the rules to determine the result precision and scale according to the new Hive's ones introduces in HIVE-15331; - round the result of the operations, when it is not representable exactly with the result's precision and scale, instead of returning `NULL` - introduce a new config `spark.sql.decimalOperations.allowPrecisionLoss` which default to `true` (ie. the new behavior) in order to allow users to switch back to the previous one. Hive behavior reflects SQLServer's one. The only difference is that the precision and scale are adjusted for all the arithmetic operations in Hive, while SQL Server is said to do so only for multiplications and divisions in the documentation. This PR follows Hive's behavior. A more detailed explanation is available here: https://mail-archives.apache.org/mod_mbox/spark-dev/201712.mbox/%3CCAEorWNAJ4TxJR9NBcgSFMD_VxTg8qVxusjP%2BAJP-x%2BJV9zH-yA%40mail.gmail.com%3E. ## How was this patch tested? modified and added UTs. Comparisons with results of Hive and SQLServer. Author: Marco Gaido Closes #20023 from mgaido91/SPARK-22036. --- docs/sql-programming-guide.md | 5 + .../catalyst/analysis/DecimalPrecision.scala | 114 +++++--- .../sql/catalyst/expressions/literals.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 12 + .../apache/spark/sql/types/DecimalType.scala | 45 +++- .../sql/catalyst/analysis/AnalysisSuite.scala | 4 +- .../analysis/DecimalPrecisionSuite.scala | 20 +- .../native/decimalArithmeticOperations.sql | 47 ++++ .../decimalArithmeticOperations.sql.out | 245 ++++++++++++++++-- .../native/decimalPrecision.sql.out | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 18 -- 11 files changed, 434 insertions(+), 82 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 258c769ff593b..3e2e48a0ef249 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1793,6 +1793,11 @@ options. - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant to SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes + - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). + - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. + - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. + ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index a8100b9b24aac..ab63131b07573 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -42,8 +43,10 @@ import org.apache.spark.sql.types._ * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) - * sum(e1) p1 + 10 s1 - * avg(e1) p1 + 4 s1 + 4 + * + * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale + * needed are out of the range of available values, the scale is reduced up to 6, in order to + * prevent the truncation of the integer part of the decimals. * * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited * precision, do the math on unlimited-precision numbers, then introduce casts back to the @@ -56,6 +59,7 @@ import org.apache.spark.sql.types._ * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE + * - Literals INT and LONG get turned into DECIMAL with the precision strictly needed by the value */ // scalastyle:on object DecimalPrecision extends TypeCoercionRule { @@ -93,41 +97,76 @@ object DecimalPrecision extends TypeCoercionRule { case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + val resultScale = max(s1, s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) + } + CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), + resultType) case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + val resultScale = max(s1, s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) + } + CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), + resultType) case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) + } else { + DecimalType.bounded(p1 + p2 + 1, s1 + s2) + } val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) - var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) - val diff = (intDig + decDig) - DecimalType.MAX_SCALE - if (diff > 0) { - decDig -= diff / 2 + 1 - intDig = DecimalType.MAX_SCALE - decDig + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) + // Scale: max(6, s1 + p2 + 1) + val intDig = p1 - s1 + s2 + val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1) + val prec = intDig + scale + DecimalType.adjustPrecisionScale(prec, scale) + } else { + var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) + var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) + val diff = (intDig + decDig) - DecimalType.MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = DecimalType.MAX_SCALE - decDig + } + DecimalType.bounded(intDig + decDig, decDig) } - val resultType = DecimalType.bounded(intDig + decDig, decDig) val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), @@ -137,9 +176,6 @@ object DecimalPrecision extends TypeCoercionRule { e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => val resultType = widerDecimalType(p1, s1, p2, s2) b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) - - // TODO: MaxOf, MinOf, etc might want other rules - // SUM and AVERAGE are handled by the implementations of those expressions } /** @@ -243,17 +279,35 @@ object DecimalPrecision extends TypeCoercionRule { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - (left.dataType, right.dataType) match { - case (t: IntegralType, DecimalType.Fixed(p, s)) => - b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right)) - case (DecimalType.Fixed(p, s), t: IntegralType) => - b.makeCopy(Array(left, Cast(right, DecimalType.forType(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(left, Cast(right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(left, DoubleType), right)) - case _ => - b + (left, right) match { + // Promote literal integers inside a binary expression with fixed-precision decimals to + // decimals. The precision and scale are the ones strictly needed by the integer value. + // Requiring more precision than necessary may lead to a useless loss of precision. + // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2. + // If we use the default precision and scale for the integer type, 2 is considered a + // DECIMAL(10, 0). According to the rules, the result would be DECIMAL(38 + 10 + 1, 18), + // which is out of range and therefore it will becomes DECIMAL(38, 7), leading to + // potentially loosing 11 digits of the fractional part. Using only the precision needed + // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would + // become DECIMAL(38, 16), safely having a much lower precision loss. + case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] + && l.dataType.isInstanceOf[IntegralType] => + b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r)) + case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] + && r.dataType.isInstanceOf[IntegralType] => + b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r)))) + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) => + b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r)) + case (l @ DecimalType.Expression(_, _), r @ IntegralType()) => + b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType)))) + case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) => + b.makeCopy(Array(l, Cast(r, DoubleType))) + case (l @ DecimalType.Expression(_, _), r) if isFloat(r.dataType) => + b.makeCopy(Array(Cast(l, DoubleType), r)) + case _ => b } } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 383203a209833..cd176d941819f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -58,7 +58,7 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) - case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale)) + case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d)) case d: JavaBigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale())) case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 16fbb0c3e9e21..cc4f4bf332459 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1064,6 +1064,16 @@ object SQLConf { .booleanConf .createWithDefault(true) + val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = + buildConf("spark.sql.decimalOperations.allowPrecisionLoss") + .internal() + .doc("When true (default), establishing the result type of an arithmetic operation " + + "happens according to Hive behavior and SQL ANSI 2011 specification, ie. rounding the " + + "decimal part of the result if an exact representation is not possible. Otherwise, NULL " + + "is returned in those cases, as previously.") + .booleanConf + .createWithDefault(true) + val SQL_STRING_REDACTION_PATTERN = ConfigBuilder("spark.sql.redaction.string.regex") .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + @@ -1441,6 +1451,8 @@ class SQLConf extends Serializable with Logging { def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) + def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) + def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) def continuousStreamingExecutorPollIntervalMs: Long = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 6e050c18b8acb..ef3b67c0d48d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} /** @@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType { val MAX_SCALE = 38 val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18) val USER_DEFAULT: DecimalType = DecimalType(10, 0) + val MINIMUM_ADJUSTED_SCALE = 6 // The decimal types compatible with other numeric types private[sql] val ByteDecimal = DecimalType(3, 0) @@ -136,10 +137,52 @@ object DecimalType extends AbstractDataType { case DoubleType => DoubleDecimal } + private[sql] def fromLiteral(literal: Literal): DecimalType = literal.value match { + case v: Short => fromBigDecimal(BigDecimal(v)) + case v: Int => fromBigDecimal(BigDecimal(v)) + case v: Long => fromBigDecimal(BigDecimal(v)) + case _ => forType(literal.dataType) + } + + private[sql] def fromBigDecimal(d: BigDecimal): DecimalType = { + DecimalType(Math.max(d.precision, d.scale), d.scale) + } + private[sql] def bounded(precision: Int, scale: Int): DecimalType = { DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) } + /** + * Scale adjustment implementation is based on Hive's one, which is itself inspired to + * SQLServer's one. In particular, when a result precision is greater than + * {@link #MAX_PRECISION}, the corresponding scale is reduced to prevent the integral part of a + * result from being truncated. + * + * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. + */ + private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = { + // Assumptions: + assert(precision >= scale) + assert(scale >= 0) + + if (precision <= MAX_PRECISION) { + // Adjustment only needed when we exceed max precision + DecimalType(precision, scale) + } else { + // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. + val intDigits = precision - scale + // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise + // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits + val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE) + // The resulting scale is the maximum between what is available without causing a loss of + // digits for the integer part of the decimal and the minimum guaranteed scale, which is + // computed above + val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue) + + DecimalType(MAX_PRECISION, adjustedScale) + } + } + override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT override private[sql] def acceptsType(other: DataType): Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f4514205d3ae0..cd8579584eada 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -408,8 +408,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType) assertExpressionType(sum(Divide(1, 2.0f)), DoubleType) assertExpressionType(sum(Divide(1.0f, 2)), DoubleType) - assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11)) - assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11)) + assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(22, 11)) + assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(26, 6)) assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType) assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 60e46a9910a8b..c86dc18dfa680 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -136,19 +136,19 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { test("maximum decimals") { for (expr <- Seq(d1, d2, i, u)) { - checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT) - checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT) + checkType(Add(expr, u), DecimalType(38, 17)) + checkType(Subtract(expr, u), DecimalType(38, 17)) } - checkType(Multiply(d1, u), DecimalType(38, 19)) - checkType(Multiply(d2, u), DecimalType(38, 20)) - checkType(Multiply(i, u), DecimalType(38, 18)) - checkType(Multiply(u, u), DecimalType(38, 36)) + checkType(Multiply(d1, u), DecimalType(38, 16)) + checkType(Multiply(d2, u), DecimalType(38, 14)) + checkType(Multiply(i, u), DecimalType(38, 7)) + checkType(Multiply(u, u), DecimalType(38, 6)) - checkType(Divide(u, d1), DecimalType(38, 18)) - checkType(Divide(u, d2), DecimalType(38, 19)) - checkType(Divide(u, i), DecimalType(38, 23)) - checkType(Divide(u, u), DecimalType(38, 18)) + checkType(Divide(u, d1), DecimalType(38, 17)) + checkType(Divide(u, d2), DecimalType(38, 16)) + checkType(Divide(u, i), DecimalType(38, 18)) + checkType(Divide(u, u), DecimalType(38, 6)) checkType(Remainder(d1, u), DecimalType(19, 18)) checkType(Remainder(d2, u), DecimalType(21, 18)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql index c8e108ac2c45e..c6d8a49d4b93a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -22,6 +22,51 @@ select a / b from t; select a % b from t; select pmod(a, b) from t; +-- tests for decimals handling in operations +create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet; + +insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), + (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789); + +-- test decimal operations +select id, a+b, a-b, a*b, a/b from decimals_test order by id; + +-- test operations between decimals and constants +select id, a*10, b/10 from decimals_test order by id; + +-- test operations on constants +select 10.3 * 3.0; +select 10.3000 * 3.0; +select 10.30000 * 30.0; +select 10.300000000000000000 * 3.000000000000000000; +select 10.300000000000000000 * 3.0000000000000000000; + +-- arithmetic operations causing an overflow return NULL +select (5e36 + 0.1) + 5e36; +select (-4e36 - 0.1) - 7e36; +select 12345678901234567890.0 * 12345678901234567890.0; +select 1e35 / 0.1; + +-- arithmetic operations causing a precision loss are truncated +select 123456789123456789.1234567890 * 1.123456789123456789; +select 0.001 / 9876543210987654321098765432109876543.2 + +-- return NULL instead of rounding, according to old Spark versions' behavior +set spark.sql.decimalOperations.allowPrecisionLoss=false; + +-- test decimal operations +select id, a+b, a-b, a*b, a/b from decimals_test order by id; + +-- test operations between decimals and constants +select id, a*10, b/10 from decimals_test order by id; + +-- test operations on constants +select 10.3 * 3.0; +select 10.3000 * 3.0; +select 10.30000 * 30.0; +select 10.300000000000000000 * 3.000000000000000000; +select 10.300000000000000000 * 3.0000000000000000000; + -- arithmetic operations causing an overflow return NULL select (5e36 + 0.1) + 5e36; select (-4e36 - 0.1) - 7e36; @@ -31,3 +76,5 @@ select 1e35 / 0.1; -- arithmetic operations causing a precision loss return NULL select 123456789123456789.1234567890 * 1.123456789123456789; select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out index ce02f6adc456c..4d70fe19d539f 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 32 -- !query 0 @@ -35,48 +35,257 @@ NULL -- !query 4 -select (5e36 + 0.1) + 5e36 +create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet -- !query 4 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<> -- !query 4 output -NULL + -- !query 5 -select (-4e36 - 0.1) - 7e36 +insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), + (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789) -- !query 5 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<> -- !query 5 output -NULL + -- !query 6 -select 12345678901234567890.0 * 12345678901234567890.0 +select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 6 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct -- !query 6 output -NULL +1 1099 -899 99900 0.1001 +2 24690.246 0 152402061.885129 1 +3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 +4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 -- !query 7 -select 1e35 / 0.1 +select id, a*10, b/10 from decimals_test order by id -- !query 7 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> +struct -- !query 7 output -NULL +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.112345678912345679 -- !query 8 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 10.3 * 3.0 -- !query 8 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> -- !query 8 output -NULL +30.9 -- !query 9 -select 0.001 / 9876543210987654321098765432109876543.2 +select 10.3000 * 3.0 -- !query 9 schema -struct<(CAST(0.001 AS DECIMAL(38,3)) / CAST(9876543210987654321098765432109876543.2 AS DECIMAL(38,3))):decimal(38,37)> +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> -- !query 9 output +30.9 + + +-- !query 10 +select 10.30000 * 30.0 +-- !query 10 schema +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +-- !query 10 output +309 + + +-- !query 11 +select 10.300000000000000000 * 3.000000000000000000 +-- !query 11 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +-- !query 11 output +30.9 + + +-- !query 12 +select 10.300000000000000000 * 3.0000000000000000000 +-- !query 12 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +-- !query 12 output +30.9 + + +-- !query 13 +select (5e36 + 0.1) + 5e36 +-- !query 13 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 13 output +NULL + + +-- !query 14 +select (-4e36 - 0.1) - 7e36 +-- !query 14 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 14 output +NULL + + +-- !query 15 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 15 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- !query 15 output NULL + + +-- !query 16 +select 1e35 / 0.1 +-- !query 16 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +-- !query 16 output +NULL + + +-- !query 17 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 17 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +-- !query 17 output +138698367904130467.654320988515622621 + + +-- !query 18 +select 0.001 / 9876543210987654321098765432109876543.2 + +set spark.sql.decimalOperations.allowPrecisionLoss=false +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'spark' expecting (line 3, pos 4) + +== SQL == +select 0.001 / 9876543210987654321098765432109876543.2 + +set spark.sql.decimalOperations.allowPrecisionLoss=false +----^^^ + + +-- !query 19 +select id, a+b, a-b, a*b, a/b from decimals_test order by id +-- !query 19 schema +struct +-- !query 19 output +1 1099 -899 99900 0.1001 +2 24690.246 0 152402061.885129 1 +3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 +4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 + + +-- !query 20 +select id, a*10, b/10 from decimals_test order by id +-- !query 20 schema +struct +-- !query 20 output +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.112345678912345679 + + +-- !query 21 +select 10.3 * 3.0 +-- !query 21 schema +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> +-- !query 21 output +30.9 + + +-- !query 22 +select 10.3000 * 3.0 +-- !query 22 schema +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> +-- !query 22 output +30.9 + + +-- !query 23 +select 10.30000 * 30.0 +-- !query 23 schema +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +-- !query 23 output +309 + + +-- !query 24 +select 10.300000000000000000 * 3.000000000000000000 +-- !query 24 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +-- !query 24 output +30.9 + + +-- !query 25 +select 10.300000000000000000 * 3.0000000000000000000 +-- !query 25 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +-- !query 25 output +30.9 + + +-- !query 26 +select (5e36 + 0.1) + 5e36 +-- !query 26 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 26 output +NULL + + +-- !query 27 +select (-4e36 - 0.1) - 7e36 +-- !query 27 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 27 output +NULL + + +-- !query 28 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 28 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- !query 28 output +NULL + + +-- !query 29 +select 1e35 / 0.1 +-- !query 29 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +-- !query 29 output +NULL + + +-- !query 30 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 30 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +-- !query 30 output +138698367904130467.654320988515622621 + + +-- !query 31 +select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test +-- !query 31 schema +struct<> +-- !query 31 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'table' expecting (line 3, pos 5) + +== SQL == +select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test +-----^^^ diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out index ebc8201ed5a1d..6ee7f59d69877 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out @@ -2329,7 +2329,7 @@ struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(C -- !query 280 SELECT cast(1 as bigint) / cast(1 as decimal(20, 0)) FROM t -- !query 280 schema -struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) / CAST(1 AS DECIMAL(20,0))):decimal(38,19)> +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) / CAST(1 AS DECIMAL(20,0))):decimal(38,18)> -- !query 280 output 1 @@ -2661,7 +2661,7 @@ struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS BI -- !query 320 SELECT cast(1 as decimal(20, 0)) / cast(1 as bigint) FROM t -- !query 320 schema -struct<(CAST(1 AS DECIMAL(20,0)) / CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,19)> +struct<(CAST(1 AS DECIMAL(20,0)) / CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,18)> -- !query 320 output 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d4d0aa4f5f5eb..083a0c0b1b9a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1517,24 +1517,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("decimal precision with multiply/division") { - checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) - checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) - checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) - checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), - Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) - checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), - Row(null)) - - checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) - checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) - checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) - checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) - checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) - } - test("SPARK-10215 Div of Decimal returns null") { val d = Decimal(1.12321).toBigDecimal val df = Seq((d, 1)).toDF("a", "b") From 5063b7481173ad72bd0dc941b5cf3c9b26a591e4 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 18 Jan 2018 22:33:04 +0900 Subject: [PATCH 131/774] [SPARK-23141][SQL][PYSPARK] Support data type string as a returnType for registerJavaFunction. ## What changes were proposed in this pull request? Currently `UDFRegistration.registerJavaFunction` doesn't support data type string as a `returnType` whereas `UDFRegistration.register`, `udf`, or `pandas_udf` does. We can support it for `UDFRegistration.registerJavaFunction` as well. ## How was this patch tested? Added a doctest and existing tests. Author: Takuya UESHIN Closes #20307 from ueshin/issues/SPARK-23141. --- python/pyspark/sql/functions.py | 6 ++++-- python/pyspark/sql/udf.py | 14 ++++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 988c1d25259bc..961b3267b44cf 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2108,7 +2108,8 @@ def udf(f=None, returnType=StringType()): can fail on special rows, the workaround is to incorporate the condition into the functions. :param f: python function if used as a standalone function - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: the return type of the user-defined function. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. >>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) @@ -2148,7 +2149,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): Creates a vectorized user defined function (UDF). :param f: user-defined function. A python function if used as a standalone function - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: the return type of the user-defined function. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. Default: SCALAR. diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 1943bb73f9ac2..c77f19f89a442 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -206,7 +206,8 @@ def register(self, name, f, returnType=None): :param f: a Python function, or a user-defined function. The user-defined function can be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and :meth:`pyspark.sql.functions.pandas_udf`. - :param returnType: the return type of the registered user-defined function. + :param returnType: the return type of the registered user-defined function. The value can + be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. :return: a user-defined function. `returnType` can be optionally specified when `f` is a Python function but not @@ -303,21 +304,30 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): :param name: name of the user-defined function :param javaClassName: fully qualified name of java class - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: the return type of the registered Java function. The value can be either + a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. >>> from pyspark.sql.types import IntegerType >>> spark.udf.registerJavaFunction( ... "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType()) >>> spark.sql("SELECT javaStringLength('test')").collect() [Row(UDF:javaStringLength(test)=4)] + >>> spark.udf.registerJavaFunction( ... "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength") >>> spark.sql("SELECT javaStringLength2('test')").collect() [Row(UDF:javaStringLength2(test)=4)] + + >>> spark.udf.registerJavaFunction( + ... "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer") + >>> spark.sql("SELECT javaStringLength3('test')").collect() + [Row(UDF:javaStringLength3(test)=4)] """ jdt = None if returnType is not None: + if not isinstance(returnType, DataType): + returnType = _parse_datatype_string(returnType) jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) From cf7ee1767ddadce08dce050fc3b40c77cdd187da Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 18 Jan 2018 10:19:36 -0800 Subject: [PATCH 132/774] [SPARK-23147][UI] Fix task page table IndexOutOfBound Exception ## What changes were proposed in this pull request? Stage's task page table will throw an exception when there's no complete tasks. Furthermore, because the `dataSize` doesn't take running tasks into account, so sometimes UI cannot show the running tasks. Besides table will only be displayed when first task is finished according to the default sortColumn("index"). ![screen shot 2018-01-18 at 8 50 08 pm](https://user-images.githubusercontent.com/850797/35100052-470b4cae-fc95-11e7-96a2-ad9636e732b3.png) To reproduce this issue, user could try `sc.parallelize(1 to 20, 20).map { i => Thread.sleep(10000); i }.collect()` or `sc.parallelize(1 to 20, 20).map { i => Thread.sleep((20 - i) * 1000); i }.collect` to reproduce the above issue. Here propose a solution to fix it. Not sure if it is a right fix, please help to review. ## How was this patch tested? Manual test. Author: jerryshao Closes #20315 from jerryshao/SPARK-23147. --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 7c6e06cf183ba..af78373ddb4b2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -676,7 +676,7 @@ private[ui] class TaskDataSource( private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = stage.numCompleteTasks + stage.numFailedTasks + stage.numKilledTasks + override def dataSize: Int = stage.numTasks override def sliceData(from: Int, to: Int): Seq[TaskData] = { if (_tasksToShow == null) { From 9678941f54ebc5db935ed8d694e502086e2a31c0 Mon Sep 17 00:00:00 2001 From: Fernando Pereira Date: Thu, 18 Jan 2018 13:02:03 -0600 Subject: [PATCH 133/774] [SPARK-23029][DOCS] Specifying default units of configuration entries ## What changes were proposed in this pull request? This PR completes the docs, specifying the default units assumed in configuration entries of type size. This is crucial since unit-less values are accepted and the user might assume the base unit is bytes, which in most cases it is not, leading to hard-to-debug problems. ## How was this patch tested? This patch updates only documentation only. Author: Fernando Pereira Closes #20269 from ferdonline/docs_units. --- .../scala/org/apache/spark/SparkConf.scala | 6 +- .../spark/internal/config/package.scala | 47 ++++---- docs/configuration.md | 100 ++++++++++-------- 3 files changed, 85 insertions(+), 68 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index d77303e6fdf8b..f53b2bed74c6e 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -640,9 +640,9 @@ private[spark] object SparkConf extends Logging { translation = s => s"${s.toLong * 10}s")), "spark.reducer.maxSizeInFlight" -> Seq( AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), - "spark.kryoserializer.buffer" -> - Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", - translation = s => s"${(s.toDouble * 1000).toInt}k")), + "spark.kryoserializer.buffer" -> Seq( + AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${(s.toDouble * 1000).toInt}k")), "spark.kryoserializer.buffer.max" -> Seq( AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), "spark.shuffle.file.buffer" -> Seq( diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index eb12ddf961314..bbfcfbaa7363c 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -38,10 +38,13 @@ package object config { ConfigBuilder("spark.driver.userClassPathFirst").booleanConf.createWithDefault(false) private[spark] val DRIVER_MEMORY = ConfigBuilder("spark.driver.memory") + .doc("Amount of memory to use for the driver process, in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") private[spark] val DRIVER_MEMORY_OVERHEAD = ConfigBuilder("spark.driver.memoryOverhead") + .doc("The amount of off-heap memory to be allocated per driver in cluster mode, " + + "in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createOptional @@ -62,6 +65,7 @@ package object config { .createWithDefault(false) private[spark] val EVENT_LOG_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.eventLog.buffer.kb") + .doc("Buffer size to use when writing to output streams, in KiB unless otherwise specified.") .bytesConf(ByteUnit.KiB) .createWithDefaultString("100k") @@ -81,10 +85,13 @@ package object config { ConfigBuilder("spark.executor.userClassPathFirst").booleanConf.createWithDefault(false) private[spark] val EXECUTOR_MEMORY = ConfigBuilder("spark.executor.memory") + .doc("Amount of memory to use per executor process, in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") private[spark] val EXECUTOR_MEMORY_OVERHEAD = ConfigBuilder("spark.executor.memoryOverhead") + .doc("The amount of off-heap memory to be allocated per executor in cluster mode, " + + "in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createOptional @@ -353,7 +360,7 @@ package object config { private[spark] val BUFFER_WRITE_CHUNK_SIZE = ConfigBuilder("spark.buffer.write.chunkSize") .internal() - .doc("The chunk size during writing out the bytes of ChunkedByteBuffer.") + .doc("The chunk size in bytes during writing out the bytes of ChunkedByteBuffer.") .bytesConf(ByteUnit.BYTE) .checkValue(_ <= Int.MaxValue, "The chunk size during writing out the bytes of" + " ChunkedByteBuffer should not larger than Int.MaxValue.") @@ -368,9 +375,9 @@ package object config { private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD = ConfigBuilder("spark.shuffle.accurateBlockThreshold") - .doc("When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will " + - "record the size accurately if it's above this config. This helps to prevent OOM by " + - "avoiding underestimating shuffle block size when fetch shuffle blocks.") + .doc("Threshold in bytes above which the size of shuffle blocks in " + + "HighlyCompressedMapStatus is accurately recorded. This helps to prevent OOM " + + "by avoiding underestimating shuffle block size when fetch shuffle blocks.") .bytesConf(ByteUnit.BYTE) .createWithDefault(100 * 1024 * 1024) @@ -389,23 +396,23 @@ package object config { private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS = ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress") - .doc("This configuration limits the number of remote blocks being fetched per reduce task" + - " from a given host port. When a large number of blocks are being requested from a given" + - " address in a single fetch or simultaneously, this could crash the serving executor or" + - " Node Manager. This is especially useful to reduce the load on the Node Manager when" + - " external shuffle is enabled. You can mitigate the issue by setting it to a lower value.") + .doc("This configuration limits the number of remote blocks being fetched per reduce task " + + "from a given host port. When a large number of blocks are being requested from a given " + + "address in a single fetch or simultaneously, this could crash the serving executor or " + + "Node Manager. This is especially useful to reduce the load on the Node Manager when " + + "external shuffle is enabled. You can mitigate the issue by setting it to a lower value.") .intConf .checkValue(_ > 0, "The max no. of blocks in flight cannot be non-positive.") .createWithDefault(Int.MaxValue) private[spark] val MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM = ConfigBuilder("spark.maxRemoteBlockSizeFetchToMem") - .doc("Remote block will be fetched to disk when size of the block is " + - "above this threshold. This is to avoid a giant request takes too much memory. We can " + - "enable this config by setting a specific value(e.g. 200m). Note this configuration will " + - "affect both shuffle fetch and block manager remote block fetch. For users who " + - "enabled external shuffle service, this feature can only be worked when external shuffle" + - " service is newer than Spark 2.2.") + .doc("Remote block will be fetched to disk when size of the block is above this threshold " + + "in bytes. This is to avoid a giant request takes too much memory. We can enable this " + + "config by setting a specific value(e.g. 200m). Note this configuration will affect " + + "both shuffle fetch and block manager remote block fetch. For users who enabled " + + "external shuffle service, this feature can only be worked when external shuffle" + + "service is newer than Spark 2.2.") .bytesConf(ByteUnit.BYTE) .createWithDefault(Long.MaxValue) @@ -419,9 +426,9 @@ package object config { private[spark] val SHUFFLE_FILE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.file.buffer") - .doc("Size of the in-memory buffer for each shuffle file output stream. " + - "These buffers reduce the number of disk seeks and system calls made " + - "in creating intermediate shuffle files.") + .doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " + + "otherwise specified. These buffers reduce the number of disk seeks and system calls " + + "made in creating intermediate shuffle files.") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, s"The file buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") @@ -430,7 +437,7 @@ package object config { private[spark] val SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.shuffle.unsafe.file.output.buffer") .doc("The file system for this buffer size after each partition " + - "is written in unsafe shuffle writer.") + "is written in unsafe shuffle writer. In KiB unless otherwise specified.") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, s"The buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") @@ -438,7 +445,7 @@ package object config { private[spark] val SHUFFLE_DISK_WRITE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.spill.diskWriteBufferSize") - .doc("The buffer size to use when writing the sorted records to an on-disk file.") + .doc("The buffer size, in bytes, to use when writing the sorted records to an on-disk file.") .bytesConf(ByteUnit.BYTE) .checkValue(v => v > 0 && v <= Int.MaxValue, s"The buffer size must be greater than 0 and less than ${Int.MaxValue}.") diff --git a/docs/configuration.md b/docs/configuration.md index 1189aea2aa71f..eecb39dcafc9e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -58,6 +58,10 @@ The following format is accepted: 1t or 1tb (tebibytes = 1024 gibibytes) 1p or 1pb (pebibytes = 1024 tebibytes) +While numbers without units are generally interpreted as bytes, a few are interpreted as KiB or MiB. +See documentation of individual configuration properties. Specifying units is desirable where +possible. + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For @@ -136,9 +140,9 @@ of the most common options to set are: spark.driver.maxResultSize 1g - Limit of total size of serialized results of all partitions for each Spark action (e.g. collect). - Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size - is above this limit. + Limit of total size of serialized results of all partitions for each Spark action (e.g. + collect) in bytes. Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total + size is above this limit. Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory and memory overhead of objects in JVM). Setting a proper limit can protect the driver from out-of-memory errors. @@ -148,10 +152,10 @@ of the most common options to set are: spark.driver.memory 1g - Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 1g, 2g). - -
    Note: In client mode, this config must not be set through the SparkConf + Amount of memory to use for the driver process, i.e. where SparkContext is initialized, in MiB + unless otherwise specified (e.g. 1g, 2g). +
    + Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-memory command line option or in your default properties file. @@ -161,27 +165,28 @@ of the most common options to set are: spark.driver.memoryOverhead driverMemory * 0.10, with minimum of 384 - The amount of off-heap memory (in megabytes) to be allocated per driver in cluster mode. This is - memory that accounts for things like VM overheads, interned strings, other native overheads, etc. - This tends to grow with the container size (typically 6-10%). This option is currently supported - on YARN and Kubernetes. + The amount of off-heap memory to be allocated per driver in cluster mode, in MiB unless + otherwise specified. This is memory that accounts for things like VM overheads, interned strings, + other native overheads, etc. This tends to grow with the container size (typically 6-10%). + This option is currently supported on YARN and Kubernetes. spark.executor.memory 1g - Amount of memory to use per executor process (e.g. 2g, 8g). + Amount of memory to use per executor process, in MiB unless otherwise specified. + (e.g. 2g, 8g). spark.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 - The amount of off-heap memory (in megabytes) to be allocated per executor. This is memory that - accounts for things like VM overheads, interned strings, other native overheads, etc. This tends - to grow with the executor size (typically 6-10%). This option is currently supported on YARN and - Kubernetes. + The amount of off-heap memory to be allocated per executor, in MiB unless otherwise specified. + This is memory that accounts for things like VM overheads, interned strings, other native + overheads, etc. This tends to grow with the executor size (typically 6-10%). + This option is currently supported on YARN and Kubernetes. @@ -431,8 +436,9 @@ Apart from these, the following properties are also available, and may be useful 512m Amount of memory to use per python worker process during aggregation, in the same - format as JVM memory strings (e.g. 512m, 2g). If the memory - used during aggregation goes above this amount, it will spill the data into disks. + format as JVM memory strings with a size unit suffix ("k", "m", "g" or "t") + (e.g. 512m, 2g). + If the memory used during aggregation goes above this amount, it will spill the data into disks. @@ -540,9 +546,10 @@ Apart from these, the following properties are also available, and may be useful spark.reducer.maxSizeInFlight 48m - Maximum size of map outputs to fetch simultaneously from each reduce task. Since - each output requires us to create a buffer to receive it, this represents a fixed memory - overhead per reduce task, so keep it small unless you have a large amount of memory. + Maximum size of map outputs to fetch simultaneously from each reduce task, in MiB unless + otherwise specified. Since each output requires us to create a buffer to receive it, this + represents a fixed memory overhead per reduce task, so keep it small unless you have a + large amount of memory. @@ -570,9 +577,9 @@ Apart from these, the following properties are also available, and may be useful spark.maxRemoteBlockSizeFetchToMem Long.MaxValue - The remote block will be fetched to disk when size of the block is above this threshold. + The remote block will be fetched to disk when size of the block is above this threshold in bytes. This is to avoid a giant request takes too much memory. We can enable this config by setting - a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch + a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch and block manager remote block fetch. For users who enabled external shuffle service, this feature can only be worked when external shuffle service is newer than Spark 2.2. @@ -589,8 +596,9 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.file.buffer 32k - Size of the in-memory buffer for each shuffle file output stream. These buffers - reduce the number of disk seeks and system calls made in creating intermediate shuffle files. + Size of the in-memory buffer for each shuffle file output stream, in KiB unless otherwise + specified. These buffers reduce the number of disk seeks and system calls made in creating + intermediate shuffle files. @@ -651,7 +659,7 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.service.index.cache.size 100m - Cache entries limited to the specified memory footprint. + Cache entries limited to the specified memory footprint in bytes. @@ -685,9 +693,9 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.accurateBlockThreshold 100 * 1024 * 1024 - When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will record the - size accurately if it's above this config. This helps to prevent OOM by avoiding - underestimating shuffle block size when fetch shuffle blocks. + Threshold in bytes above which the size of shuffle blocks in HighlyCompressedMapStatus is + accurately recorded. This helps to prevent OOM by avoiding underestimating shuffle + block size when fetch shuffle blocks. @@ -779,7 +787,7 @@ Apart from these, the following properties are also available, and may be useful spark.eventLog.buffer.kb 100k - Buffer size in KB to use when writing to output streams. + Buffer size to use when writing to output streams, in KiB unless otherwise specified. @@ -917,7 +925,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.lz4.blockSize 32k - Block size used in LZ4 compression, in the case when LZ4 compression codec + Block size in bytes used in LZ4 compression, in the case when LZ4 compression codec is used. Lowering this block size will also lower shuffle memory usage when LZ4 is used. @@ -925,7 +933,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.snappy.blockSize 32k - Block size used in Snappy compression, in the case when Snappy compression codec + Block size in bytes used in Snappy compression, in the case when Snappy compression codec is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. @@ -941,7 +949,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.zstd.bufferSize 32k - Buffer size used in Zstd compression, in the case when Zstd compression codec + Buffer size in bytes used in Zstd compression, in the case when Zstd compression codec is used. Lowering this size will lower the shuffle memory usage when Zstd is used, but it might increase the compression cost because of excessive JNI call overhead. @@ -1001,8 +1009,8 @@ Apart from these, the following properties are also available, and may be useful spark.kryoserializer.buffer.max 64m - Maximum allowable size of Kryo serialization buffer. This must be larger than any - object you attempt to serialize and must be less than 2048m. + Maximum allowable size of Kryo serialization buffer, in MiB unless otherwise specified. + This must be larger than any object you attempt to serialize and must be less than 2048m. Increase this if you get a "buffer limit exceeded" exception inside Kryo. @@ -1010,9 +1018,9 @@ Apart from these, the following properties are also available, and may be useful spark.kryoserializer.buffer 64k - Initial size of Kryo's serialization buffer. Note that there will be one buffer - per core on each worker. This buffer will grow up to - spark.kryoserializer.buffer.max if needed. + Initial size of Kryo's serialization buffer, in KiB unless otherwise specified. + Note that there will be one buffer per core on each worker. This buffer will grow up to + spark.kryoserializer.buffer.max if needed. @@ -1086,7 +1094,8 @@ Apart from these, the following properties are also available, and may be useful spark.memory.offHeap.enabled false - If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. + If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory + use is enabled, then spark.memory.offHeap.size must be positive. @@ -1094,7 +1103,8 @@ Apart from these, the following properties are also available, and may be useful 0 The absolute amount of memory in bytes which can be used for off-heap allocation. - This setting has no impact on heap memory usage, so if your executors' total memory consumption must fit within some hard limit then be sure to shrink your JVM heap size accordingly. + This setting has no impact on heap memory usage, so if your executors' total memory consumption + must fit within some hard limit then be sure to shrink your JVM heap size accordingly. This must be set to a positive value when spark.memory.offHeap.enabled=true. @@ -1202,9 +1212,9 @@ Apart from these, the following properties are also available, and may be useful spark.broadcast.blockSize 4m - Size of each piece of a block for TorrentBroadcastFactory. - Too large a value decreases parallelism during broadcast (makes it slower); however, if it is - too small, BlockManager might take a performance hit. + Size of each piece of a block for TorrentBroadcastFactory, in KiB unless otherwise + specified. Too large a value decreases parallelism during broadcast (makes it slower); however, + if it is too small, BlockManager might take a performance hit. @@ -1312,7 +1322,7 @@ Apart from these, the following properties are also available, and may be useful spark.storage.memoryMapThreshold 2m - Size of a block above which Spark memory maps when reading a block from disk. + Size in bytes of a block above which Spark memory maps when reading a block from disk. This prevents Spark from memory mapping very small blocks. In general, memory mapping has high overhead for blocks close to or below the page size of the operating system. @@ -2490,4 +2500,4 @@ Also, you can modify or add configurations at runtime: --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" \ --conf spark.hadoop.abc.def=xyz \ myApp.jar -{% endhighlight %} \ No newline at end of file +{% endhighlight %} From 2d41f040a34d6483919fd5d491cf90eee5429290 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 18 Jan 2018 12:25:52 -0800 Subject: [PATCH 134/774] [SPARK-23143][SS][PYTHON] Added python API for setting continuous trigger ## What changes were proposed in this pull request? Self-explanatory. ## How was this patch tested? New python tests. Author: Tathagata Das Closes #20309 from tdas/SPARK-23143. --- python/pyspark/sql/streaming.py | 23 +++++++++++++++++++---- python/pyspark/sql/tests.py | 6 ++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 24ae3776a217b..e2a97acb5e2a7 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -786,7 +786,7 @@ def queryName(self, queryName): @keyword_only @since(2.0) - def trigger(self, processingTime=None, once=None): + def trigger(self, processingTime=None, once=None, continuous=None): """Set the trigger for the stream query. If this is not set it will run the query as fast as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. @@ -802,23 +802,38 @@ def trigger(self, processingTime=None, once=None): >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') >>> # trigger the query for just once batch of data >>> writer = sdf.writeStream.trigger(once=True) + >>> # trigger the query for execution every 5 seconds + >>> writer = sdf.writeStream.trigger(continuous='5 seconds') """ + params = [processingTime, once, continuous] + + if params.count(None) == 3: + raise ValueError('No trigger provided') + elif params.count(None) < 2: + raise ValueError('Multiple triggers not allowed.') + jTrigger = None if processingTime is not None: - if once is not None: - raise ValueError('Multiple triggers not allowed.') if type(processingTime) != str or len(processingTime.strip()) == 0: raise ValueError('Value for processingTime must be a non empty string. Got: %s' % processingTime) interval = processingTime.strip() jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.ProcessingTime( interval) + elif once is not None: if once is not True: raise ValueError('Value for once must be True. Got: %s' % once) jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Once() + else: - raise ValueError('No trigger provided') + if type(continuous) != str or len(continuous.strip()) == 0: + raise ValueError('Value for continuous must be a non empty string. Got: %s' % + continuous) + interval = continuous.strip() + jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Continuous( + interval) + self._jwrite = self._jwrite.trigger(jTrigger) return self diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f84aa3d68b808..25483594f2725 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1538,6 +1538,12 @@ def test_stream_trigger(self): except ValueError: pass + # Should not take multiple args + try: + df.writeStream.trigger(processingTime='5 seconds', continuous='1 second') + except ValueError: + pass + # Should take only keyword args try: df.writeStream.trigger('5 seconds') From bf34d665b9c865e00fac7001500bf6d521c2dff9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 18 Jan 2018 12:33:39 -0800 Subject: [PATCH 135/774] [SPARK-23144][SS] Added console sink for continuous processing ## What changes were proposed in this pull request? Refactored ConsoleWriter into ConsoleMicrobatchWriter and ConsoleContinuousWriter. ## How was this patch tested? new unit test Author: Tathagata Das Closes #20311 from tdas/SPARK-23144. --- .../sql/execution/streaming/console.scala | 20 +++-- .../streaming/sources/ConsoleWriter.scala | 80 ++++++++++++++----- .../sources/ConsoleWriterSuite.scala | 26 +++++- 3 files changed, 96 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 94820376ff7e7..f2aa3259731d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql.execution.streaming import java.util.Optional -import scala.collection.JavaConverters._ - import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.execution.streaming.sources.{ConsoleContinuousWriter, ConsoleMicroBatchWriter, ConsoleWriter} import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -37,16 +36,25 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) class ConsoleSinkProvider extends DataSourceV2 with MicroBatchWriteSupport + with ContinuousWriteSupport with DataSourceRegister with CreatableRelationProvider { override def createMicroBatchWriter( queryId: String, - epochId: Long, + batchId: Long, schema: StructType, mode: OutputMode, options: DataSourceV2Options): Optional[DataSourceV2Writer] = { - Optional.of(new ConsoleWriter(epochId, schema, options)) + Optional.of(new ConsoleMicroBatchWriter(batchId, schema, options)) + } + + override def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + Optional.of(new ConsoleContinuousWriter(schema, options)) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index 361979984bbec..6fb61dff60045 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -20,45 +20,85 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.types.StructType -/** - * A [[DataSourceV2Writer]] that collects results to the driver and prints them in the console. - * Generated by [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. - * - * This sink should not be used for production, as it requires sending all rows to the driver - * and does not support recovery. - */ -class ConsoleWriter(batchId: Long, schema: StructType, options: DataSourceV2Options) - extends DataSourceV2Writer with Logging { +/** Common methods used to create writes for the the console sink */ +trait ConsoleWriter extends Logging { + + def options: DataSourceV2Options + // Number of rows to display, by default 20 rows - private val numRowsToShow = options.getInt("numRows", 20) + protected val numRowsToShow = options.getInt("numRows", 20) // Truncate the displayed data if it is too long, by default it is true - private val isTruncated = options.getBoolean("truncate", true) + protected val isTruncated = options.getBoolean("truncate", true) assert(SparkSession.getActiveSession.isDefined) - private val spark = SparkSession.getActiveSession.get + protected val spark = SparkSession.getActiveSession.get + + def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory - override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory + def abort(messages: Array[WriterCommitMessage]): Unit = {} - override def commit(messages: Array[WriterCommitMessage]): Unit = synchronized { - val batch = messages.collect { + protected def printRows( + commitMessages: Array[WriterCommitMessage], + schema: StructType, + printMessage: String): Unit = { + val rows = commitMessages.collect { case PackedRowCommitMessage(rows) => rows }.flatten // scalastyle:off println println("-------------------------------------------") - println(s"Batch: $batchId") + println(printMessage) println("-------------------------------------------") // scalastyle:off println - spark.createDataFrame( - spark.sparkContext.parallelize(batch), schema) + spark + .createDataFrame(spark.sparkContext.parallelize(rows), schema) .show(numRowsToShow, isTruncated) } +} + + +/** + * A [[DataSourceV2Writer]] that collects results from a micro-batch query to the driver and + * prints them in the console. Created by + * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. + * + * This sink should not be used for production, as it requires sending all rows to the driver + * and does not support recovery. + */ +class ConsoleMicroBatchWriter(batchId: Long, schema: StructType, val options: DataSourceV2Options) + extends DataSourceV2Writer with ConsoleWriter { + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + printRows(messages, schema, s"Batch: $batchId") + } + + override def toString(): String = { + s"ConsoleMicroBatchWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + } +} - override def abort(messages: Array[WriterCommitMessage]): Unit = {} - override def toString(): String = s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" +/** + * A [[DataSourceV2Writer]] that collects results from a continuous query to the driver and + * prints them in the console. Created by + * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. + * + * This sink should not be used for production, as it requires sending all rows to the driver + * and does not support recovery. + */ +class ConsoleContinuousWriter(schema: StructType, val options: DataSourceV2Options) + extends ContinuousWriter with ConsoleWriter { + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + printRows(messages, schema, s"Continuous processing epoch $epochId") + } + + override def toString(): String = { + s"ConsoleContinuousWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala index 60ffee9b9b42c..55acf2ba28d2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.ByteArrayOutputStream +import org.scalatest.time.SpanSugar._ + import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.streaming.{StreamTest, Trigger} class ConsoleWriterSuite extends StreamTest { import testImplicits._ - test("console") { + test("microbatch - default") { val input = MemoryStream[Int] val captured = new ByteArrayOutputStream() @@ -77,7 +79,7 @@ class ConsoleWriterSuite extends StreamTest { |""".stripMargin) } - test("console with numRows") { + test("microbatch - with numRows") { val input = MemoryStream[Int] val captured = new ByteArrayOutputStream() @@ -106,7 +108,7 @@ class ConsoleWriterSuite extends StreamTest { |""".stripMargin) } - test("console with truncation") { + test("microbatch - truncation") { val input = MemoryStream[String] val captured = new ByteArrayOutputStream() @@ -132,4 +134,20 @@ class ConsoleWriterSuite extends StreamTest { | |""".stripMargin) } + + test("continuous - default") { + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val input = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "5") + .load() + .select('value) + + val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start() + assert(query.isActive) + query.stop() + } + } } From f568e9cf76f657d094f1d036ab5a95f2531f5761 Mon Sep 17 00:00:00 2001 From: Andrew Korzhuev Date: Thu, 18 Jan 2018 14:00:12 -0800 Subject: [PATCH 136/774] [SPARK-23133][K8S] Fix passing java options to Executor Pass through spark java options to the executor in context of docker image. Closes #20296 andrusha: Deployed two version of containers to local k8s, checked that java options were present in the updated image on the running executor. Manual test Author: Andrew Korzhuev Closes #20322 from foxish/patch-1. --- .../docker/src/main/dockerfiles/spark/entrypoint.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 0c28c75857871..b9090dc2852a5 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -42,7 +42,7 @@ shift 1 SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt -readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt +readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" fi @@ -54,7 +54,7 @@ case "$SPARK_K8S_CMD" in driver) CMD=( ${JAVA_HOME}/bin/java - "${SPARK_DRIVER_JAVA_OPTS[@]}" + "${SPARK_JAVA_OPTS[@]}" -cp "$SPARK_CLASSPATH" -Xms$SPARK_DRIVER_MEMORY -Xmx$SPARK_DRIVER_MEMORY @@ -67,7 +67,7 @@ case "$SPARK_K8S_CMD" in executor) CMD=( ${JAVA_HOME}/bin/java - "${SPARK_EXECUTOR_JAVA_OPTS[@]}" + "${SPARK_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" From e01919e834d301e13adc8919932796ebae900576 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 19 Jan 2018 07:36:06 +0900 Subject: [PATCH 137/774] [SPARK-23094] Fix invalid character handling in JsonDataSource ## What changes were proposed in this pull request? There were two related fixes regarding `from_json`, `get_json_object` and `json_tuple` ([Fix #1](https://github.com/apache/spark/commit/c8803c06854683c8761fdb3c0e4c55d5a9e22a95), [Fix #2](https://github.com/apache/spark/commit/86174ea89b39a300caaba6baffac70f3dc702788)), but they weren't comprehensive it seems. I wanted to extend those fixes to all the parsers, and add tests for each case. ## How was this patch tested? Regression tests Author: Burak Yavuz Closes #20302 from brkyvz/json-invfix. --- .../catalyst/json/CreateJacksonParser.scala | 5 +-- .../sources/JsonHadoopFsRelationSuite.scala | 34 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index 025a388aacaa5..b1672e7e2fca2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -40,10 +40,11 @@ private[sql] object CreateJacksonParser extends Serializable { } def text(jsonFactory: JsonFactory, record: Text): JsonParser = { - jsonFactory.createParser(record.getBytes, 0, record.getLength) + val bain = new ByteArrayInputStream(record.getBytes, 0, record.getLength) + jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) } def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = { - jsonFactory.createParser(record) + jsonFactory.createParser(new InputStreamReader(record, "UTF-8")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index 49be30435ad2f..27f398ebf301a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = "json" + private val badJson = "\u0000\u0000\u0000A\u0001AAA" + // JSON does not write data of NullType and does not play well with BinaryType. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { case _: NullType => false @@ -105,4 +107,36 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { ) } } + + test("invalid json with leading nulls - from file (multiLine=true)") { + import testImplicits._ + withTempDir { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) + val expected = s"""$badJson\n{"a":1}\n""" + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = + spark.read.format(dataSourceName).option("multiLine", true).schema(schema).load(path) + checkAnswer(df, Row(null, expected)) + } + } + + test("invalid json with leading nulls - from file (multiLine=false)") { + import testImplicits._ + withTempDir { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = + spark.read.format(dataSourceName).option("multiLine", false).schema(schema).load(path) + checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) + } + } + + test("invalid json with leading nulls - from dataset") { + import testImplicits._ + checkAnswer( + spark.read.json(Seq(badJson).toDS()), + Row(badJson)) + } } From 5d7c4ba4d73a72f26d591108db3c20b4a6c84f3f Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 18 Jan 2018 14:44:22 -0800 Subject: [PATCH 138/774] [SPARK-22962][K8S] Fail fast if submission client local files are used ## What changes were proposed in this pull request? In the Kubernetes mode, fails fast in the submission process if any submission client local dependencies are used as the use case is not supported yet. ## How was this patch tested? Unit tests, integration tests, and manual tests. vanzin foxish Author: Yinan Li Closes #20320 from liyinan926/master. --- docs/running-on-kubernetes.md | 5 ++- .../k8s/submit/DriverConfigOrchestrator.scala | 14 ++++++++- .../DriverConfigOrchestratorSuite.scala | 31 ++++++++++++++++++- 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 08ec34c63ba3f..d6b1735ce5550 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -117,7 +117,10 @@ This URI is the location of the example jar that is already in the Docker image. If your application's dependencies are all hosted in remote locations like HDFS or HTTP servers, they may be referred to by their appropriate remote URIs. Also, application dependencies can be pre-mounted into custom-built Docker images. Those dependencies can be added to the classpath by referencing them with `local://` URIs and/or setting the -`SPARK_EXTRA_CLASSPATH` environment variable in your Dockerfiles. +`SPARK_EXTRA_CLASSPATH` environment variable in your Dockerfiles. The `local://` scheme is also required when referring to +dependencies in custom-built Docker images in `spark-submit`. Note that using application dependencies from the submission +client's local file system is currently not yet supported. + ### Using Remote Dependencies When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index c9cc300d65569..ae70904621184 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -20,7 +20,7 @@ import java.util.UUID import com.google.common.primitives.Longs -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -117,6 +117,12 @@ private[spark] class DriverConfigOrchestrator( .map(_.split(",")) .getOrElse(Array.empty[String]) + // TODO(SPARK-23153): remove once submission client local dependencies are supported. + if (existSubmissionLocalFiles(sparkJars) || existSubmissionLocalFiles(sparkFiles)) { + throw new SparkException("The Kubernetes mode does not yet support referencing application " + + "dependencies in the local file system.") + } + val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { Seq(new DependencyResolutionStep( sparkJars, @@ -162,6 +168,12 @@ private[spark] class DriverConfigOrchestrator( initContainerBootstrapStep } + private def existSubmissionLocalFiles(files: Seq[String]): Boolean = { + files.exists { uri => + Utils.resolveURI(uri).getScheme == "file" + } + } + private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { files.exists { uri => Utils.resolveURI(uri).getScheme != "local" diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index 65274c6f50e01..033d303e946fd 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.deploy.k8s.submit -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.submit.steps._ @@ -117,6 +117,35 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { classOf[DriverMountSecretsStep]) } + test("Submission using client local dependencies") { + val sparkConf = new SparkConf(false) + .set(CONTAINER_IMAGE, DRIVER_IMAGE) + var orchestrator = new DriverConfigOrchestrator( + APP_ID, + LAUNCH_TIME, + Some(JavaMainAppResource("file:///var/apps/jars/main.jar")), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + assertThrows[SparkException] { + orchestrator.getAllConfigurationSteps + } + + sparkConf.set("spark.files", "/path/to/file1,/path/to/file2") + orchestrator = new DriverConfigOrchestrator( + APP_ID, + LAUNCH_TIME, + Some(JavaMainAppResource("local:///var/apps/jars/main.jar")), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + assertThrows[SparkException] { + orchestrator.getAllConfigurationSteps + } + } + private def validateStepTypes( orchestrator: DriverConfigOrchestrator, types: Class[_ <: DriverConfigurationStep]*): Unit = { From 4cd2ecc0c7222fef1337e04f1948333296c3be86 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 18 Jan 2018 16:29:45 -0800 Subject: [PATCH 139/774] [SPARK-23142][SS][DOCS] Added docs for continuous processing ## What changes were proposed in this pull request? Added documentation for continuous processing. Modified two locations. - Modified the overview to have a mention of Continuous Processing. - Added a new section on Continuous Processing at the end. ![image](https://user-images.githubusercontent.com/663212/35083551-a3dd23f6-fbd4-11e7-9e7e-90866f131ca9.png) ![image](https://user-images.githubusercontent.com/663212/35083618-d844027c-fbd4-11e7-9fde-75992cc517bd.png) ## How was this patch tested? N/A Author: Tathagata Das Closes #20308 from tdas/SPARK-23142. --- .../structured-streaming-programming-guide.md | 98 ++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 1779a4215e085..2ddba2f0d942e 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -10,7 +10,9 @@ title: Structured Streaming Programming Guide # Overview Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* -In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. +Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able choose the mode based on your application requirements. + +In this guide, we are going to walk you through the programming model and the APIs. We are going to explain the concepts mostly using the default micro-batch processing model, and then [later](#continuous-processing-experimental) discuss Continuous Processing model. First, let's start with a simple example of a Structured Streaming query - a streaming word count. # Quick Example Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in @@ -2434,6 +2436,100 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat
+# Continuous Processing [Experimental] +**Continuous processing** is a new, experimental streaming execution mode introduced in Spark 2.3 that enables low (~1 ms) end-to-end latency with at-least-once fault-tolerance guarantees. Compare this with the default *micro-batch processing* engine which can achieve exactly-once guarantees but achieve latencies of ~100ms at best. For some types of queries (discussed below), you can choose which mode to execute them in without modifying the application logic (i.e. without changing the DataFrame/Dataset operations). + +To run a supported query in continuous processing mode, all you need to do is specify a **continuous trigger** with the desired checkpoint interval as a parameter. For example, + +
+
+{% highlight scala %} +import org.apache.spark.sql.streaming.Trigger + +spark + .readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("") + +spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .trigger(Trigger.Continuous("1 second")) // only change in query + .start() +{% endhighlight %} +
+
+{% highlight java %} +import org.apache.spark.sql.streaming.Trigger; + +spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .trigger(Trigger.Continuous("1 second")) // only change in query + .start(); +{% endhighlight %} +
+
+{% highlight python %} +spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribe", "topic1") \ + .load() \ + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ + .writeStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("topic", "topic1") \ + .trigger(continuous="1 second") \ # only change in query + .start() + +{% endhighlight %} +
+
+ +A checkpoint interval of 1 second means that the continuous processing engine will records the progress of the query every second. The resulting checkpoints are in a format compatible with the micro-batch engine, hence any query can be restarted with any trigger. For example, a supported query started with the micro-batch mode can be restarted in continuous mode, and vice versa. Note that any time you switch to continuous mode, you will get at-least-once fault-tolerance guarantees. + +## Supported Queries +As of Spark 2.3, only the following type of queries are supported in the continuous processing mode. + +- *Operations*: Only map-like Dataset/DataFrame operations are supported in continuous mode, that is, only projections (`select`, `map`, `flatMap`, `mapPartitions`, etc.) and selections (`where`, `filter`, etc.). + + All SQL functions are supported except aggregation functions (since aggregations are not yet supported), `current_timestamp()` and `current_date()` (deterministic computations using time is challenging). + +- *Sources*: + + Kafka source: All options are supported. + + Rate source: Good for testing. Only options that are supported in the continuous mode are `numPartitions` and `rowsPerSecond`. + +- *Sinks*: + + Kafka sink: All options are supported. + + Memory sink: Good for debugging. + + Console sink: Good for debugging. All options are supported. Note that the console will print every checkpoint interval that you have specified in the continuous trigger. + +See [Input Sources](#input-sources) and [Output Sinks](#output-sinks) sections for more details on them. While the console sink is good for testing, the end-to-end low-latency processing can be best observed with Kafka as the source and sink, as this allows the engine to process the data and make the results available in the output topic within milliseconds of the input data being available in the input topic. + +## Caveats +- Continuous processing engine launches multiple long-running tasks that continuously read data from sources, process it and continuously write to sinks. The number of tasks required by the query depends on how many partitions the query can read from the sources in parallel. Therefore, before starting a continuous processing query, you must ensure there are enough cores in the cluster to all the tasks in parallel. For example, if you are reading from a Kafka topic that has 10 partitions, then the cluster must have at least 10 cores for the query to make progress. +- Stopping a continuous processing stream may produce spurious task termination warnings. These can be safely ignored. +- There are currently no automatic retries of failed tasks. Any failure will lead to the query being stopped and it needs to be manually restarted from the checkpoint. + # Additional Information **Further Reading** From 6121e91b7f5c9513d68674e4d5edbc3a4a5fd5fd Mon Sep 17 00:00:00 2001 From: brandonJY Date: Thu, 18 Jan 2018 18:57:49 -0600 Subject: [PATCH 140/774] [DOCS] change to dataset for java code in structured-streaming-kafka-integration document ## What changes were proposed in this pull request? In latest structured-streaming-kafka-integration document, Java code example for Kafka integration is using `DataFrame`, shouldn't it be changed to `DataSet`? ## How was this patch tested? manual test has been performed to test the updated example Java code in Spark 2.2.1 with Kafka 1.0 Author: brandonJY Closes #20312 from brandonJY/patch-2. --- docs/structured-streaming-kafka-integration.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index bab0be8ddeb9f..461c29ce1ba89 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -61,7 +61,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -70,7 +70,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to multiple topics -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -79,7 +79,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to a pattern -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -171,7 +171,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic defaults to the earliest and latest offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -180,7 +180,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to multiple topics, specifying explicit Kafka offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -191,7 +191,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to a pattern, at the earliest and latest offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") From 568055da93049c207bb830f244ff9b60c638837c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 19 Jan 2018 11:37:08 +0800 Subject: [PATCH 141/774] [SPARK-23054][SQL][PYSPARK][FOLLOWUP] Use sqlType casting when casting PythonUserDefinedType to String. ## What changes were proposed in this pull request? This is a follow-up of #20246. If a UDT in Python doesn't have its corresponding Scala UDT, cast to string will be the raw string of the internal value, e.g. `"org.apache.spark.sql.catalyst.expressions.UnsafeArrayDataxxxxxxxx"` if the internal type is `ArrayType`. This pr fixes it by using its `sqlType` casting. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20306 from ueshin/issues/SPARK-23054/fup1. --- python/pyspark/sql/tests.py | 11 +++++++++++ .../apache/spark/sql/catalyst/expressions/Cast.scala | 2 ++ .../org/apache/spark/sql/test/ExamplePointUDT.scala | 2 ++ 3 files changed, 15 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 25483594f2725..4fee2ecde391b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1189,6 +1189,17 @@ def test_union_with_udt(self): ] ) + def test_cast_to_string_with_udt(self): + from pyspark.sql.tests import ExamplePointUDT, ExamplePoint + from pyspark.sql.functions import col + row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0)) + schema = StructType([StructField("point", ExamplePointUDT(), False), + StructField("pypoint", PythonOnlyUDT(), False)]) + df = self.spark.createDataFrame([row], schema) + + result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head() + self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')) + def test_column_operators(self): ci = self.df.key cs = self.df.value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a95ebe301b9d1..79b051670e9e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -282,6 +282,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case pudt: PythonUserDefinedType => castToString(pudt.sqlType) case udt: UserDefinedType[_] => buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString)) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) @@ -838,6 +839,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx) case udt: UserDefinedType[_] => val udtRef = ctx.addReferenceObj("udt", udt) (c, evPrim, evNull) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index a73e4272950a4..8bab7e1c58762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -34,6 +34,8 @@ private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializab case that: ExamplePoint => this.x == that.x && this.y == that.y case _ => false } + + override def toString(): String = s"($x, $y)" } /** From 9c4b99861cda3f9ec44ca8c1adc81a293508190c Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Fri, 19 Jan 2018 01:38:08 -0800 Subject: [PATCH 142/774] [BUILD][MINOR] Fix java style check issues ## What changes were proposed in this pull request? This patch fixes a few recently introduced java style check errors in master and release branch. As an aside, given that [java linting currently fails](https://github.com/apache/spark/pull/10763 ) on machines with a clean maven cache, it'd be great to find another workaround to [re-enable the java style checks](https://github.com/apache/spark/blob/3a07eff5af601511e97a05e6fea0e3d48f74c4f0/dev/run-tests.py#L577) as part of Spark PRB. /cc zsxwing JoshRosen srowen for any suggestions ## How was this patch tested? Manual Check Author: Sameer Agarwal Closes #20323 from sameeragarwal/java. --- .../spark/sql/sources/v2/writer/DataSourceV2Writer.java | 6 ++++-- .../org/apache/spark/sql/vectorized/ArrowColumnVector.java | 5 +++-- .../apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java | 3 ++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index 317ac45bcfd74..f1ef411423162 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -28,8 +28,10 @@ /** * A data source writer that is returned by * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter(String, long, StructType, OutputMode, DataSourceV2Options)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter(String, StructType, OutputMode, DataSourceV2Options)}. + * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter( + * String, long, StructType, OutputMode, DataSourceV2Options)}/ + * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter( + * String, StructType, OutputMode, DataSourceV2Options)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index eb69001fe677e..bfd1b4cb0ef12 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -556,8 +556,9 @@ final int getArrayOffset(int rowId) { /** * Any call to "get" method will throw UnsupportedOperationException. * - * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses getStruct() method defined - * in the parent class. Any call to "get" method in this class is a bug in the code. + * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses + * getStruct() method defined in the parent class. Any call to "get" method in this class is a + * bug in the code. * */ private static class StructAccessor extends ArrowVectorAccessor { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index 44e5146d7c553..98d6a53b54d28 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -69,7 +69,8 @@ public DataReader createDataReader() { ColumnVector[] vectors = new ColumnVector[2]; vectors[0] = i; vectors[1] = j; - this.batch = new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); + this.batch = + new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); return this; } From 60203fca6a605ad158184e1e0ce5187e144a3ea7 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Fri, 19 Jan 2018 12:43:23 +0200 Subject: [PATCH 143/774] [SPARK-23127][DOC] Update FeatureHasher guide for categoricalCols parameter Update user guide entry for `FeatureHasher` to match the Scala / Python doc, to describe the `categoricalCols` parameter. ## How was this patch tested? Doc only Author: Nick Pentreath Closes #20293 from MLnick/SPARK-23127-catCol-userguide. --- docs/ml-features.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 72643137d96b1..10183c3e78c76 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -222,9 +222,9 @@ The `FeatureHasher` transformer operates on multiple columns. Each column may co numeric or categorical features. Behavior and handling of column data types is as follows: - Numeric columns: For numeric features, the hash value of the column name is used to map the -feature value to its index in the feature vector. Numeric features are never treated as -categorical, even when they are integers. You must explicitly convert numeric columns containing -categorical features to strings first. +feature value to its index in the feature vector. By default, numeric features are not treated +as categorical (even when they are integers). To treat them as categorical, specify the relevant +columns using the `categoricalCols` parameter. - String columns: For categorical features, the hash value of the string "column_name=value" is used to map to the vector index, with an indicator value of `1.0`. Thus, categorical features are "one-hot" encoded (similarly to using [OneHotEncoder](ml-features.html#onehotencoder) with From b74366481cc87490adf4e69d26389ec737548c15 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 19 Jan 2018 12:48:42 +0200 Subject: [PATCH 144/774] [SPARK-23048][ML] Add OneHotEncoderEstimator document and examples ## What changes were proposed in this pull request? We have `OneHotEncoderEstimator` now and `OneHotEncoder` will be deprecated since 2.3.0. We should add `OneHotEncoderEstimator` into mllib document. We also need to provide corresponding examples for `OneHotEncoderEstimator` which are used in the document too. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #20257 from viirya/SPARK-23048. --- docs/ml-features.md | 28 ++++++++----- ...=> JavaOneHotEncoderEstimatorExample.java} | 41 ++++++++----------- ...py => onehot_encoder_estimator_example.py} | 29 +++++++------ ...la => OneHotEncoderEstimatorExample.scala} | 40 ++++++++---------- 4 files changed, 68 insertions(+), 70 deletions(-) rename examples/src/main/java/org/apache/spark/examples/ml/{JavaOneHotEncoderExample.java => JavaOneHotEncoderEstimatorExample.java} (62%) rename examples/src/main/python/ml/{onehot_encoder_example.py => onehot_encoder_estimator_example.py} (65%) rename examples/src/main/scala/org/apache/spark/examples/ml/{OneHotEncoderExample.scala => OneHotEncoderEstimatorExample.scala} (65%) diff --git a/docs/ml-features.md b/docs/ml-features.md index 10183c3e78c76..466a8fbe99cf6 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -775,35 +775,43 @@ for more details on the API.
-## OneHotEncoder +## OneHotEncoder (Deprecated since 2.3.0) -[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. +Because this existing `OneHotEncoder` is a stateless transformer, it is not usable on new data where the number of categories may differ from the training data. In order to fix this, a new `OneHotEncoderEstimator` was created that produces an `OneHotEncoderModel` when fitting. For more detail, please see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030). + +`OneHotEncoder` has been deprecated in 2.3.0 and will be removed in 3.0.0. Please use [OneHotEncoderEstimator](ml-features.html#onehotencoderestimator) instead. + +## OneHotEncoderEstimator + +[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a categorical feature, represented as a label index, to a binary vector with at most a single one-value indicating the presence of a specific feature value from among the set of all feature values. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first. + +`OneHotEncoderEstimator` can transform multiple columns, returning an one-hot-encoded output vector column for each input column. It is common to merge these vectors into a single feature vector using [VectorAssembler](ml-features.html#vectorassembler). + +`OneHotEncoderEstimator` supports the `handleInvalid` parameter to choose how to handle invalid input during transforming data. Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an error). **Examples**
-Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) -for more details on the API. +Refer to the [OneHotEncoderEstimator Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoderEstimator) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %} +{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala %}
-Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) +Refer to the [OneHotEncoderEstimator Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoderEstimator.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %} +{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java %}
-Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) -for more details on the API. +Refer to the [OneHotEncoderEstimator Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoderEstimator) for more details on the API. -{% include_example python/ml/onehot_encoder_example.py %} +{% include_example python/ml/onehot_encoder_estimator_example.py %}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java similarity index 62% rename from examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java rename to examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java index 99af37676ba98..6f93cff94b725 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java @@ -23,9 +23,8 @@ import java.util.Arrays; import java.util.List; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.ml.feature.OneHotEncoderEstimator; +import org.apache.spark.ml.feature.OneHotEncoderModel; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -35,41 +34,37 @@ import org.apache.spark.sql.types.StructType; // $example off$ -public class JavaOneHotEncoderExample { +public class JavaOneHotEncoderEstimatorExample { public static void main(String[] args) { SparkSession spark = SparkSession .builder() - .appName("JavaOneHotEncoderExample") + .appName("JavaOneHotEncoderEstimatorExample") .getOrCreate(); + // Note: categorical features are usually first encoded with StringIndexer // $example on$ List data = Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") + RowFactory.create(0.0, 1.0), + RowFactory.create(1.0, 0.0), + RowFactory.create(2.0, 1.0), + RowFactory.create(0.0, 2.0), + RowFactory.create(0.0, 1.0), + RowFactory.create(2.0, 0.0) ); StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) + new StructField("categoryIndex1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("categoryIndex2", DataTypes.DoubleType, false, Metadata.empty()) }); Dataset df = spark.createDataFrame(data, schema); - StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); - Dataset indexed = indexer.transform(df); + OneHotEncoderEstimator encoder = new OneHotEncoderEstimator() + .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"}) + .setOutputCols(new String[] {"categoryVec1", "categoryVec2"}); - OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); - - Dataset encoded = encoder.transform(indexed); + OneHotEncoderModel model = encoder.fit(df); + Dataset encoded = model.transform(df); encoded.show(); // $example off$ diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_estimator_example.py similarity index 65% rename from examples/src/main/python/ml/onehot_encoder_example.py rename to examples/src/main/python/ml/onehot_encoder_estimator_example.py index e1996c7f0a55b..2723e681cea7c 100644 --- a/examples/src/main/python/ml/onehot_encoder_example.py +++ b/examples/src/main/python/ml/onehot_encoder_estimator_example.py @@ -18,32 +18,31 @@ from __future__ import print_function # $example on$ -from pyspark.ml.feature import OneHotEncoder, StringIndexer +from pyspark.ml.feature import OneHotEncoderEstimator # $example off$ from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ .builder\ - .appName("OneHotEncoderExample")\ + .appName("OneHotEncoderEstimatorExample")\ .getOrCreate() + # Note: categorical features are usually first encoded with StringIndexer # $example on$ df = spark.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - ], ["id", "category"]) + (0.0, 1.0), + (1.0, 0.0), + (2.0, 1.0), + (0.0, 2.0), + (0.0, 1.0), + (2.0, 0.0) + ], ["categoryIndex1", "categoryIndex2"]) - stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") - model = stringIndexer.fit(df) - indexed = model.transform(df) - - encoder = OneHotEncoder(inputCol="categoryIndex", outputCol="categoryVec") - encoded = encoder.transform(indexed) + encoder = OneHotEncoderEstimator(inputCols=["categoryIndex1", "categoryIndex2"], + outputCols=["categoryVec1", "categoryVec2"]) + model = encoder.fit(df) + encoded = model.transform(df) encoded.show() # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala similarity index 65% rename from examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala index 274cc1268f4d1..45d816808ed8e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala @@ -19,38 +19,34 @@ package org.apache.spark.examples.ml // $example on$ -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} +import org.apache.spark.ml.feature.OneHotEncoderEstimator // $example off$ import org.apache.spark.sql.SparkSession -object OneHotEncoderExample { +object OneHotEncoderEstimatorExample { def main(args: Array[String]): Unit = { val spark = SparkSession .builder - .appName("OneHotEncoderExample") + .appName("OneHotEncoderEstimatorExample") .getOrCreate() + // Note: categorical features are usually first encoded with StringIndexer // $example on$ val df = spark.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - )).toDF("id", "category") - - val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) - val indexed = indexer.transform(df) - - val encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec") - - val encoded = encoder.transform(indexed) + (0.0, 1.0), + (1.0, 0.0), + (2.0, 1.0), + (0.0, 2.0), + (0.0, 1.0), + (2.0, 0.0) + )).toDF("categoryIndex1", "categoryIndex2") + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("categoryIndex1", "categoryIndex2")) + .setOutputCols(Array("categoryVec1", "categoryVec2")) + val model = encoder.fit(df) + + val encoded = model.transform(df) encoded.show() // $example off$ From e41400c3c8aace9eb72e6134173f222627fb0faf Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 19 Jan 2018 19:46:48 +0800 Subject: [PATCH 145/774] [SPARK-23089][STS] Recreate session log directory if it doesn't exist ## What changes were proposed in this pull request? When creating a session directory, Thrift should create the parent directory (i.e. /tmp/base_session_log_dir) if it is not present. It is common that many tools delete empty directories, so the directory may be deleted. This can cause the session log to be disabled. This was fixed in HIVE-12262: this PR brings it in Spark too. ## How was this patch tested? manual tests Author: Marco Gaido Closes #20281 from mgaido91/SPARK-23089. --- .../hive/service/cli/session/HiveSessionImpl.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index 47bfaa86021d6..108074cce3d6d 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -223,6 +223,18 @@ private void configureSession(Map sessionConfMap) throws HiveSQL @Override public void setOperationLogSessionDir(File operationLogRootDir) { + if (!operationLogRootDir.exists()) { + LOG.warn("The operation log root directory is removed, recreating: " + + operationLogRootDir.getAbsolutePath()); + if (!operationLogRootDir.mkdirs()) { + LOG.warn("Unable to create operation log root directory: " + + operationLogRootDir.getAbsolutePath()); + } + } + if (!operationLogRootDir.canWrite()) { + LOG.warn("The operation log root directory is not writable: " + + operationLogRootDir.getAbsolutePath()); + } sessionLogDir = new File(operationLogRootDir, sessionHandle.getHandleIdentifier().toString()); isOperationLogEnabled = true; if (!sessionLogDir.exists()) { From e1c33b6cd14e4e1123814f4d040e3520db7d1ec9 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Fri, 19 Jan 2018 08:22:24 -0600 Subject: [PATCH 146/774] [SPARK-23024][WEB-UI] Spark ui about the contents of the form need to have hidden and show features, when the table records very much. ## What changes were proposed in this pull request? Spark ui about the contents of the form need to have hidden and show features, when the table records very much. Because sometimes you do not care about the record of the table, you just want to see the contents of the next table, but you have to scroll the scroll bar for a long time to see the contents of the next table. Currently we have about 500 workers, but I just wanted to see the logs for the running applications table. I had to scroll through the scroll bars for a long time to see the logs for the running applications table. In order to ensure functional consistency, I modified the Master Page, Worker Page, Job Page, Stage Page, Task Page, Configuration Page, Storage Page, Pool Page. fix before: ![1](https://user-images.githubusercontent.com/26266482/34805936-601ed628-f6bb-11e7-8dd3-d8413573a076.png) fix after: ![2](https://user-images.githubusercontent.com/26266482/34805949-6af8afba-f6bb-11e7-89f4-ab16584916fb.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20216 from guoxiaolongzte/SPARK-23024. --- .../org/apache/spark/ui/static/webui.js | 30 ++++++++ .../deploy/master/ui/ApplicationPage.scala | 25 +++++-- .../spark/deploy/master/ui/MasterPage.scala | 63 ++++++++++++++--- .../spark/deploy/worker/ui/WorkerPage.scala | 52 +++++++++++--- .../apache/spark/ui/env/EnvironmentPage.scala | 48 +++++++++++-- .../apache/spark/ui/jobs/AllJobsPage.scala | 39 +++++++++-- .../apache/spark/ui/jobs/AllStagesPage.scala | 67 +++++++++++++++--- .../org/apache/spark/ui/jobs/JobPage.scala | 68 ++++++++++++++++--- .../org/apache/spark/ui/jobs/PoolPage.scala | 13 +++- .../org/apache/spark/ui/jobs/StagePage.scala | 12 +++- .../apache/spark/ui/storage/StoragePage.scala | 12 +++- 11 files changed, 373 insertions(+), 56 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index 0fa1fcf25f8b9..e575c4c78970d 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -50,4 +50,34 @@ function collapseTable(thisName, table){ // to remember if it's collapsed on each page reload $(function() { collapseTablePageLoad('collapse-aggregated-metrics','aggregated-metrics'); + collapseTablePageLoad('collapse-aggregated-executors','aggregated-executors'); + collapseTablePageLoad('collapse-aggregated-removedExecutors','aggregated-removedExecutors'); + collapseTablePageLoad('collapse-aggregated-workers','aggregated-workers'); + collapseTablePageLoad('collapse-aggregated-activeApps','aggregated-activeApps'); + collapseTablePageLoad('collapse-aggregated-activeDrivers','aggregated-activeDrivers'); + collapseTablePageLoad('collapse-aggregated-completedApps','aggregated-completedApps'); + collapseTablePageLoad('collapse-aggregated-completedDrivers','aggregated-completedDrivers'); + collapseTablePageLoad('collapse-aggregated-runningExecutors','aggregated-runningExecutors'); + collapseTablePageLoad('collapse-aggregated-runningDrivers','aggregated-runningDrivers'); + collapseTablePageLoad('collapse-aggregated-finishedExecutors','aggregated-finishedExecutors'); + collapseTablePageLoad('collapse-aggregated-finishedDrivers','aggregated-finishedDrivers'); + collapseTablePageLoad('collapse-aggregated-runtimeInformation','aggregated-runtimeInformation'); + collapseTablePageLoad('collapse-aggregated-sparkProperties','aggregated-sparkProperties'); + collapseTablePageLoad('collapse-aggregated-systemProperties','aggregated-systemProperties'); + collapseTablePageLoad('collapse-aggregated-classpathEntries','aggregated-classpathEntries'); + collapseTablePageLoad('collapse-aggregated-activeJobs','aggregated-activeJobs'); + collapseTablePageLoad('collapse-aggregated-completedJobs','aggregated-completedJobs'); + collapseTablePageLoad('collapse-aggregated-failedJobs','aggregated-failedJobs'); + collapseTablePageLoad('collapse-aggregated-poolTable','aggregated-poolTable'); + collapseTablePageLoad('collapse-aggregated-allActiveStages','aggregated-allActiveStages'); + collapseTablePageLoad('collapse-aggregated-allPendingStages','aggregated-allPendingStages'); + collapseTablePageLoad('collapse-aggregated-allCompletedStages','aggregated-allCompletedStages'); + collapseTablePageLoad('collapse-aggregated-allFailedStages','aggregated-allFailedStages'); + collapseTablePageLoad('collapse-aggregated-activeStages','aggregated-activeStages'); + collapseTablePageLoad('collapse-aggregated-pendingOrSkippedStages','aggregated-pendingOrSkippedStages'); + collapseTablePageLoad('collapse-aggregated-completedStages','aggregated-completedStages'); + collapseTablePageLoad('collapse-aggregated-failedStages','aggregated-failedStages'); + collapseTablePageLoad('collapse-aggregated-poolActiveStages','aggregated-poolActiveStages'); + collapseTablePageLoad('collapse-aggregated-tasks','aggregated-tasks'); + collapseTablePageLoad('collapse-aggregated-rdds','aggregated-rdds'); }); \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 68e57b7564ad1..f699c75085fe1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -100,12 +100,29 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
-

Executor Summary ({allExecutors.length})

- {executorsTable} + +

+ + Executor Summary ({allExecutors.length}) +

+
+
+ {executorsTable} +
{ if (removedExecutors.nonEmpty) { -

Removed Executors ({removedExecutors.length})

++ - removedExecutorsTable + +

+ + Removed Executors ({removedExecutors.length}) +

+
++ +
+ {removedExecutorsTable} +
} }
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index bc0bf6a1d9700..c629937606b51 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -128,15 +128,31 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
-

Workers ({workers.length})

- {workerTable} + +

+ + Workers ({workers.length}) +

+
+
+ {workerTable} +
-

Running Applications ({activeApps.length})

- {activeAppsTable} + +

+ + Running Applications ({activeApps.length}) +

+
+
+ {activeAppsTable} +
@@ -144,8 +160,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {if (hasDrivers) {
-

Running Drivers ({activeDrivers.length})

- {activeDriversTable} + +

+ + Running Drivers ({activeDrivers.length}) +

+
+
+ {activeDriversTable} +
} @@ -154,8 +179,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
-

Completed Applications ({completedApps.length})

- {completedAppsTable} + +

+ + Completed Applications ({completedApps.length}) +

+
+
+ {completedAppsTable} +
@@ -164,8 +198,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { if (hasDrivers) {
-

Completed Drivers ({completedDrivers.length})

- {completedDriversTable} + +

+ + Completed Drivers ({completedDrivers.length}) +

+
+
+ {completedDriversTable} +
} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index ce84bc4dae32c..8b98ae56fc108 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -77,24 +77,60 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
-

Running Executors ({runningExecutors.size})

- {runningExecutorTable} + +

+ + Running Executors ({runningExecutors.size}) +

+
+
+ {runningExecutorTable} +
{ if (runningDrivers.nonEmpty) { -

Running Drivers ({runningDrivers.size})

++ - runningDriverTable + +

+ + Running Drivers ({runningDrivers.size}) +

+
++ +
+ {runningDriverTable} +
} } { if (finishedExecutors.nonEmpty) { -

Finished Executors ({finishedExecutors.size})

++ - finishedExecutorTable + +

+ + Finished Executors ({finishedExecutors.size}) +

+
++ +
+ {finishedExecutorTable} +
} } { if (finishedDrivers.nonEmpty) { -

Finished Drivers ({finishedDrivers.size})

++ - finishedDriverTable + +

+ + Finished Drivers ({finishedDrivers.size}) +

+
++ +
+ {finishedDriverTable} +
} }
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala index 43adab7a35d65..902eb92b854f2 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala @@ -48,10 +48,50 @@ private[ui] class EnvironmentPage( classPathHeaders, classPathRow, appEnv.classpathEntries, fixedWidth = true) val content = -

Runtime Information

{runtimeInformationTable} -

Spark Properties

{sparkPropertiesTable} -

System Properties

{systemPropertiesTable} -

Classpath Entries

{classpathEntriesTable} + +

+ + Runtime Information +

+
+
+ {runtimeInformationTable} +
+ +

+ + Spark Properties +

+
+
+ {sparkPropertiesTable} +
+ +

+ + System Properties +

+
+
+ {systemPropertiesTable} +
+ +

+ + Classpath Entries +

+
+
+ {classpathEntriesTable} +
UIUtils.headerSparkPage("Environment", content, parent) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index ff916bb6a5759..e3b72f1f34859 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -363,16 +363,43 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We store.executorList(false), startTime) if (shouldShowActiveJobs) { - content ++=

Active Jobs ({activeJobs.size})

++ - activeJobsTable + content ++= + +

+ + Active Jobs ({activeJobs.size}) +

+
++ +
+ {activeJobsTable} +
} if (shouldShowCompletedJobs) { - content ++=

Completed Jobs ({completedJobNumStr})

++ - completedJobsTable + content ++= + +

+ + Completed Jobs ({completedJobNumStr}) +

+
++ +
+ {completedJobsTable} +
} if (shouldShowFailedJobs) { - content ++=

Failed Jobs ({failedJobs.size})

++ - failedJobsTable + content ++= + +

+ + Failed Jobs ({failedJobs.size}) +

+
++ +
+ {failedJobsTable} +
} val helpText = """A job is triggered by an action, like count() or saveAsTextFile().""" + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index b1e343451e28e..606dc1e180e5b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -116,26 +116,75 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { var content = summary ++ { if (sc.isDefined && isFairScheduler) { -

Fair Scheduler Pools ({pools.size})

++ poolTable.toNodeSeq + +

+ + Fair Scheduler Pools ({pools.size}) +

+
++ +
+ {poolTable.toNodeSeq} +
} else { Seq.empty[Node] } } if (shouldShowActiveStages) { - content ++=

Active Stages ({activeStages.size})

++ - activeStagesTable.toNodeSeq + content ++= + +

+ + Active Stages ({activeStages.size}) +

+
++ +
+ {activeStagesTable.toNodeSeq} +
} if (shouldShowPendingStages) { - content ++=

Pending Stages ({pendingStages.size})

++ - pendingStagesTable.toNodeSeq + content ++= + +

+ + Pending Stages ({pendingStages.size}) +

+
++ +
+ {pendingStagesTable.toNodeSeq} +
} if (shouldShowCompletedStages) { - content ++=

Completed Stages ({completedStageNumStr})

++ - completedStagesTable.toNodeSeq + content ++= + +

+ + Completed Stages ({completedStageNumStr}) +

+
++ +
+ {completedStagesTable.toNodeSeq} +
} if (shouldShowFailedStages) { - content ++=

Failed Stages ({numFailedStages})

++ - failedStagesTable.toNodeSeq + content ++= + +

+ + Failed Stages ({numFailedStages}) +

+
++ +
+ {failedStagesTable.toNodeSeq} +
} UIUtils.headerSparkPage("Stages for All Jobs", content, parent) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index bf59152c8c0cd..c27f30c21a843 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -340,24 +340,72 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP jobId, store.operationGraphForJob(jobId)) if (shouldShowActiveStages) { - content ++=

Active Stages ({activeStages.size})

++ - activeStagesTable.toNodeSeq + content ++= + +

+ + Active Stages ({activeStages.size}) +

+
++ +
+ {activeStagesTable.toNodeSeq} +
} if (shouldShowPendingStages) { - content ++=

Pending Stages ({pendingOrSkippedStages.size})

++ - pendingOrSkippedStagesTable.toNodeSeq + content ++= + +

+ + Pending Stages ({pendingOrSkippedStages.size}) +

+
++ +
+ {pendingOrSkippedStagesTable.toNodeSeq} +
} if (shouldShowCompletedStages) { - content ++=

Completed Stages ({completedStages.size})

++ - completedStagesTable.toNodeSeq + content ++= + +

+ + Completed Stages ({completedStages.size}) +

+
++ +
+ {completedStagesTable.toNodeSeq} +
} if (shouldShowSkippedStages) { - content ++=

Skipped Stages ({pendingOrSkippedStages.size})

++ - pendingOrSkippedStagesTable.toNodeSeq + content ++= + +

+ + Skipped Stages ({pendingOrSkippedStages.size}) +

+
++ +
+ {pendingOrSkippedStagesTable.toNodeSeq} +
} if (shouldShowFailedStages) { - content ++=

Failed Stages ({failedStages.size})

++ - failedStagesTable.toNodeSeq + content ++= + +

+ + Failed Stages ({failedStages.size}) +

+
++ +
+ {failedStagesTable.toNodeSeq} +
} UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 98fbd7aceaa11..a3e1f13782e30 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -51,7 +51,18 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { val poolTable = new PoolTable(Map(pool -> uiPool), parent) var content =

Summary

++ poolTable.toNodeSeq if (activeStages.nonEmpty) { - content ++=

Active Stages ({activeStages.size})

++ activeStagesTable.toNodeSeq + content ++= + +

+ + Active Stages ({activeStages.size}) +

+
++ +
+ {activeStagesTable.toNodeSeq} +
} UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index af78373ddb4b2..25bee33028393 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -486,8 +486,16 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
{summaryTable.getOrElse("No tasks have reported metrics yet.")}
++ aggMetrics ++ maybeAccumulableTable ++ -

Tasks ({totalTasksNumStr})

++ - taskTableHTML ++ jsForScrollingDownToTaskTable + +

+ + Tasks ({totalTasksNumStr}) +

+
++ +
+ {taskTableHTML ++ jsForScrollingDownToTaskTable} +
UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index b8aec9890247a..68d946574a37b 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -41,8 +41,16 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends Nil } else {
-

RDDs

- {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} + +

+ + RDDs +

+
+
+ {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} +
} } From 6c39654efcb2aa8cb4d082ab7277a6fa38fb48e4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 19 Jan 2018 22:47:18 +0800 Subject: [PATCH 147/774] [SPARK-23000][TEST] Keep Derby DB Location Unchanged After Session Cloning ## What changes were proposed in this pull request? After session cloning in `TestHive`, the conf of the singleton SparkContext for derby DB location is changed to a new directory. The new directory is created in `HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false)`. This PR is to keep the conf value of `ConfVars.METASTORECONNECTURLKEY.varname` unchanged during the session clone. ## How was this patch tested? The issue can be reproduced by the command: > build/sbt -Phive "hive/test-only org.apache.spark.sql.hive.HiveSessionStateSuite org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite" Also added a test case. Author: gatorsmile Closes #20328 from gatorsmile/fixTestFailure. --- .../org/apache/spark/sql/SessionStateSuite.scala | 5 +---- .../apache/spark/sql/hive/test/TestHive.scala | 8 +++++++- .../spark/sql/hive/HiveSessionStateSuite.scala | 16 +++++++++++++--- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 5d75f5835bf9e..4efae4c46c2e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterAll -import org.scalatest.BeforeAndAfterEach import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite @@ -28,8 +26,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.util.QueryExecutionListener -class SessionStateSuite extends SparkFunSuite - with BeforeAndAfterEach with BeforeAndAfterAll { +class SessionStateSuite extends SparkFunSuite { /** * A shared SparkSession for all tests in this suite. Make sure you reset any changes to this diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index b6be00dbb3a73..c84131fc3212a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -180,7 +180,13 @@ private[hive] class TestHiveSparkSession( ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", // scratch directory used by Hive's metastore client ConfVars.SCRATCHDIR.varname -> TestHiveContext.makeScratchDir().toURI.toString, - ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1") + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1") ++ + // After session cloning, the JDBC connect string for a JDBC metastore should not be changed. + existingSharedState.map { state => + val connKey = + state.sparkContext.hadoopConfiguration.get(ConfVars.METASTORECONNECTURLKEY.varname) + ConfVars.METASTORECONNECTURLKEY.varname -> connKey + } metastoreTempConf.foreach { case (k, v) => sc.hadoopConfiguration.set(k, v) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index f7da3c4cbb0aa..ecc09cdcdbeaf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.scalatest.BeforeAndAfterEach +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -25,8 +25,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton /** * Run all tests from `SessionStateSuite` with a Hive based `SessionState`. */ -class HiveSessionStateSuite extends SessionStateSuite - with TestHiveSingleton with BeforeAndAfterEach { +class HiveSessionStateSuite extends SessionStateSuite with TestHiveSingleton { override def beforeAll(): Unit = { // Reuse the singleton session @@ -39,4 +38,15 @@ class HiveSessionStateSuite extends SessionStateSuite activeSession = null super.afterAll() } + + test("Clone then newSession") { + val sparkSession = hiveContext.sparkSession + val conf = sparkSession.sparkContext.hadoopConfiguration + val oldValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) + sparkSession.cloneSession() + sparkSession.sharedState.externalCatalog.client.newSession() + val newValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) + assert(oldValue == newValue, + "cloneSession and then newSession should not affect the Derby directory") + } } From 606a7485f12c5d5377c50258006c353ba5e49c3f Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 19 Jan 2018 09:28:35 -0600 Subject: [PATCH 148/774] [SPARK-23085][ML] API parity for mllib.linalg.Vectors.sparse ## What changes were proposed in this pull request? `ML.Vectors#sparse(size: Int, elements: Seq[(Int, Double)])` support zero-length ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #20275 from zhengruifeng/SparseVector_size. --- .../scala/org/apache/spark/ml/linalg/Vectors.scala | 2 +- .../org/apache/spark/ml/linalg/VectorsSuite.scala | 14 ++++++++++++++ .../org/apache/spark/mllib/linalg/Vectors.scala | 3 +-- .../apache/spark/mllib/linalg/VectorsSuite.scala | 14 ++++++++++++++ 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 941b6eca568d3..5824e463ca1aa 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -565,7 +565,7 @@ class SparseVector @Since("2.0.0") ( // validate the data { - require(size >= 0, "The size of the requested sparse vector must be greater than 0.") + require(size >= 0, "The size of the requested sparse vector must be no less than 0.") require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. You provided ${indices.length} indices and " + s" ${values.length} values.") diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index 79acef8214d88..0a316f57f811b 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -366,4 +366,18 @@ class VectorsSuite extends SparkMLFunSuite { assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2))) assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4))) } + + test("sparse vector only support non-negative length") { + val v1 = Vectors.sparse(0, Array.emptyIntArray, Array.emptyDoubleArray) + val v2 = Vectors.sparse(0, Array.empty[(Int, Double)]) + assert(v1.size === 0) + assert(v2.size === 0) + + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array(1), Array(2.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array((1, 2.0))) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index fd9605c013625..6e68d9684a672 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -326,8 +326,6 @@ object Vectors { */ @Since("1.0.0") def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { - require(size > 0, "The size of the requested sparse vector must be greater than 0.") - val (indices, values) = elements.sortBy(_._1).unzip var prev = -1 indices.foreach { i => @@ -758,6 +756,7 @@ class SparseVector @Since("1.0.0") ( @Since("1.0.0") val indices: Array[Int], @Since("1.0.0") val values: Array[Double]) extends Vector { + require(size >= 0, "The size of the requested sparse vector must be no less than 0.") require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. You provided ${indices.length} indices and " + s" ${values.length} values.") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 4074bead421e6..217b4a35438fd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -495,4 +495,18 @@ class VectorsSuite extends SparkFunSuite with Logging { assert(mlDenseVectorToArray(dv) === mlDenseVectorToArray(newDV)) assert(mlSparseVectorToArray(sv) === mlSparseVectorToArray(newSV)) } + + test("sparse vector only support non-negative length") { + val v1 = Vectors.sparse(0, Array.emptyIntArray, Array.emptyDoubleArray) + val v2 = Vectors.sparse(0, Array.empty[(Int, Double)]) + assert(v1.size === 0) + assert(v2.size === 0) + + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array(1), Array(2.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array((1, 2.0))) + } + } } From d8aaa771e249b3f54b57ce24763e53fd65a0dbf7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 19 Jan 2018 08:58:21 -0800 Subject: [PATCH 149/774] [SPARK-23149][SQL] polish ColumnarBatch ## What changes were proposed in this pull request? Several cleanups in `ColumnarBatch` * remove `schema`. The `ColumnVector`s inside `ColumnarBatch` already have the data type information, we don't need this `schema`. * remove `capacity`. `ColumnarBatch` is just a wrapper of `ColumnVector`s, not builders, it doesn't need a capacity property. * remove `DEFAULT_BATCH_SIZE`. As a wrapper, `ColumnarBatch` can't decide the batch size, it should be decided by the reader, e.g. parquet reader, orc reader, cached table reader. The default batch size should also be defined by the reader. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20316 from cloud-fan/columnar-batch. --- .../orc/OrcColumnarBatchReader.java | 49 +++++++------------ .../SpecificParquetRecordReaderBase.java | 12 ++--- .../VectorizedParquetRecordReader.java | 24 ++++----- .../vectorized/ColumnVectorUtils.java | 18 +++---- .../spark/sql/vectorized/ColumnarBatch.java | 20 +------- .../VectorizedHashMapGenerator.scala | 2 +- .../sql/execution/arrow/ArrowConverters.scala | 2 +- .../columnar/InMemoryTableScanExec.scala | 5 +- .../python/ArrowEvalPythonExec.scala | 8 +-- .../execution/python/ArrowPythonRunner.scala | 2 +- .../sql/sources/v2/JavaBatchDataSourceV2.java | 3 +- .../vectorized/ColumnarBatchSuite.scala | 7 ++- .../sql/sources/v2/DataSourceV2Suite.scala | 3 +- 13 files changed, 61 insertions(+), 94 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 36fdf2bdf84d2..89bae4326e93b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -49,18 +49,8 @@ * After creating, `initialize` and `initBatch` should be called sequentially. */ public class OrcColumnarBatchReader extends RecordReader { - - /** - * The default size of batch. We use this value for ORC reader to make it consistent with Spark's - * columnar batch, because their default batch sizes are different like the following: - * - * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024 - * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024 - */ - private static final int DEFAULT_SIZE = 4 * 1024; - - // ORC File Reader - private Reader reader; + // TODO: make this configurable. + private static final int CAPACITY = 4 * 1024; // Vectorized ORC Row Batch private VectorizedRowBatch batch; @@ -98,22 +88,22 @@ public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark) { @Override - public Void getCurrentKey() throws IOException, InterruptedException { + public Void getCurrentKey() { return null; } @Override - public ColumnarBatch getCurrentValue() throws IOException, InterruptedException { + public ColumnarBatch getCurrentValue() { return columnarBatch; } @Override - public float getProgress() throws IOException, InterruptedException { + public float getProgress() throws IOException { return recordReader.getProgress(); } @Override - public boolean nextKeyValue() throws IOException, InterruptedException { + public boolean nextKeyValue() throws IOException { return nextBatch(); } @@ -134,16 +124,15 @@ public void close() throws IOException { * Please note that `initBatch` is needed to be called after this. */ @Override - public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) - throws IOException, InterruptedException { + public void initialize( + InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException { FileSplit fileSplit = (FileSplit)inputSplit; Configuration conf = taskAttemptContext.getConfiguration(); - reader = OrcFile.createReader( + Reader reader = OrcFile.createReader( fileSplit.getPath(), OrcFile.readerOptions(conf) .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)) .filesystem(fileSplit.getPath().getFileSystem(conf))); - Reader.Options options = OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength()); recordReader = reader.rows(options); @@ -159,7 +148,7 @@ public void initBatch( StructField[] requiredFields, StructType partitionSchema, InternalRow partitionValues) { - batch = orcSchema.createRowBatch(DEFAULT_SIZE); + batch = orcSchema.createRowBatch(CAPACITY); assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. this.requiredFields = requiredFields; @@ -171,19 +160,17 @@ public void initBatch( resultSchema = resultSchema.add(f); } - int capacity = DEFAULT_SIZE; - if (copyToSpark) { if (MEMORY_MODE == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); + columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, resultSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); + columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, resultSchema); } // Initialize the missing columns once. for (int i = 0; i < requiredFields.length; i++) { if (requestedColIds[i] == -1) { - columnVectors[i].putNulls(0, capacity); + columnVectors[i].putNulls(0, CAPACITY); columnVectors[i].setIsConstant(); } } @@ -196,7 +183,7 @@ public void initBatch( } } - columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); + columnarBatch = new ColumnarBatch(columnVectors); } else { // Just wrap the ORC column vector instead of copying it to Spark column vector. orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; @@ -206,8 +193,8 @@ public void initBatch( int colId = requestedColIds[i]; // Initialize the missing columns once. if (colId == -1) { - OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); - missingCol.putNulls(0, capacity); + OnHeapColumnVector missingCol = new OnHeapColumnVector(CAPACITY, dt); + missingCol.putNulls(0, CAPACITY); missingCol.setIsConstant(); orcVectorWrappers[i] = missingCol; } else { @@ -219,14 +206,14 @@ public void initBatch( int partitionIdx = requiredFields.length; for (int i = 0; i < partitionValues.numFields(); i++) { DataType dt = partitionSchema.fields()[i].dataType(); - OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(CAPACITY, dt); ColumnVectorUtils.populate(partitionCol, partitionValues, i); partitionCol.setIsConstant(); orcVectorWrappers[partitionIdx + i] = partitionCol; } } - columnarBatch = new ColumnarBatch(resultSchema, orcVectorWrappers, capacity); + columnarBatch = new ColumnarBatch(orcVectorWrappers); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 80c2f491b48ce..e65cd252c3ddf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -170,7 +170,7 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont * Returns the list of files at 'path' recursively. This skips files that are ignored normally * by MapReduce. */ - public static List listDirectory(File path) throws IOException { + public static List listDirectory(File path) { List result = new ArrayList<>(); if (path.isDirectory()) { for (File f: path.listFiles()) { @@ -231,7 +231,7 @@ protected void initialize(String path, List columns) throws IOException } @Override - public Void getCurrentKey() throws IOException, InterruptedException { + public Void getCurrentKey() { return null; } @@ -259,7 +259,7 @@ public ValuesReaderIntIterator(ValuesReader delegate) { } @Override - int nextInt() throws IOException { + int nextInt() { return delegate.readInteger(); } } @@ -279,15 +279,15 @@ int nextInt() throws IOException { protected static final class NullIntIterator extends IntIterator { @Override - int nextInt() throws IOException { return 0; } + int nextInt() { return 0; } } /** * Creates a reader for definition and repetition levels, returning an optimized one if * the levels are not needed. */ - protected static IntIterator createRLEIterator(int maxLevel, BytesInput bytes, - ColumnDescriptor descriptor) throws IOException { + protected static IntIterator createRLEIterator( + int maxLevel, BytesInput bytes, ColumnDescriptor descriptor) throws IOException { try { if (maxLevel == 0) return new NullIntIterator(); return new RLEIntIterator( diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index cd745b1f0e4e3..bb1b23611a7d7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -50,6 +50,9 @@ * TODO: make this always return ColumnarBatches. */ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBase { + // TODO: make this configurable. + private static final int CAPACITY = 4 * 1024; + /** * Batch of rows that we assemble and the current index we've returned. Every time this * batch is used up (batchIdx == numBatched), we populated the batch. @@ -152,7 +155,7 @@ public void close() throws IOException { } @Override - public boolean nextKeyValue() throws IOException, InterruptedException { + public boolean nextKeyValue() throws IOException { resultBatch(); if (returnColumnarBatch) return nextBatch(); @@ -165,13 +168,13 @@ public boolean nextKeyValue() throws IOException, InterruptedException { } @Override - public Object getCurrentValue() throws IOException, InterruptedException { + public Object getCurrentValue() { if (returnColumnarBatch) return columnarBatch; return columnarBatch.getRow(batchIdx - 1); } @Override - public float getProgress() throws IOException, InterruptedException { + public float getProgress() { return (float) rowsReturned / totalRowCount; } @@ -181,7 +184,7 @@ public float getProgress() throws IOException, InterruptedException { // Columns 0,1: data columns // Column 2: partitionValues[0] // Column 3: partitionValues[1] - public void initBatch( + private void initBatch( MemoryMode memMode, StructType partitionColumns, InternalRow partitionValues) { @@ -195,13 +198,12 @@ public void initBatch( } } - int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE; if (memMode == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema); + columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, batchSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema); + columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, batchSchema); } - columnarBatch = new ColumnarBatch(batchSchema, columnVectors, capacity); + columnarBatch = new ColumnarBatch(columnVectors); if (partitionColumns != null) { int partitionIdx = sparkSchema.fields().length; for (int i = 0; i < partitionColumns.fields().length; i++) { @@ -213,13 +215,13 @@ public void initBatch( // Initialize missing columns with nulls. for (int i = 0; i < missingColumns.length; i++) { if (missingColumns[i]) { - columnVectors[i].putNulls(0, columnarBatch.capacity()); + columnVectors[i].putNulls(0, CAPACITY); columnVectors[i].setIsConstant(); } } } - public void initBatch() { + private void initBatch() { initBatch(MEMORY_MODE, null, null); } @@ -255,7 +257,7 @@ public boolean nextBatch() throws IOException { if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); - int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned); + int num = (int) Math.min((long) CAPACITY, totalCountLoadedSoFar - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { if (columnReaders[i] == null) continue; columnReaders[i].readBatch(num, columnVectors[i]); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index b5cbe8e2839ba..5ee8cc8da2309 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -118,19 +118,19 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o) } } else { if (t == DataTypes.BooleanType) { - dst.appendBoolean(((Boolean)o).booleanValue()); + dst.appendBoolean((Boolean) o); } else if (t == DataTypes.ByteType) { - dst.appendByte(((Byte) o).byteValue()); + dst.appendByte((Byte) o); } else if (t == DataTypes.ShortType) { - dst.appendShort(((Short)o).shortValue()); + dst.appendShort((Short) o); } else if (t == DataTypes.IntegerType) { - dst.appendInt(((Integer)o).intValue()); + dst.appendInt((Integer) o); } else if (t == DataTypes.LongType) { - dst.appendLong(((Long)o).longValue()); + dst.appendLong((Long) o); } else if (t == DataTypes.FloatType) { - dst.appendFloat(((Float)o).floatValue()); + dst.appendFloat((Float) o); } else if (t == DataTypes.DoubleType) { - dst.appendDouble(((Double)o).doubleValue()); + dst.appendDouble((Double) o); } else if (t == DataTypes.StringType) { byte[] b =((String)o).getBytes(StandardCharsets.UTF_8); dst.appendByteArray(b, 0, b.length); @@ -192,7 +192,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Row src, i */ public static ColumnarBatch toBatch( StructType schema, MemoryMode memMode, Iterator row) { - int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE; + int capacity = 4 * 1024; WritableColumnVector[] columnVectors; if (memMode == MemoryMode.OFF_HEAP) { columnVectors = OffHeapColumnVector.allocateColumns(capacity, schema); @@ -208,7 +208,7 @@ public static ColumnarBatch toBatch( } n++; } - ColumnarBatch batch = new ColumnarBatch(schema, columnVectors, capacity); + ColumnarBatch batch = new ColumnarBatch(columnVectors); batch.setNumRows(n); return batch; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index 9ae1c6d9993f0..4dc826cf60c15 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; -import org.apache.spark.sql.types.StructType; /** * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this @@ -28,10 +27,6 @@ * the entire data loading process. */ public final class ColumnarBatch { - public static final int DEFAULT_BATCH_SIZE = 4 * 1024; - - private final StructType schema; - private final int capacity; private int numRows; private final ColumnVector[] columns; @@ -82,7 +77,6 @@ public void remove() { * Sets the number of rows in this batch. */ public void setNumRows(int numRows) { - assert(numRows <= this.capacity); this.numRows = numRows; } @@ -96,16 +90,6 @@ public void setNumRows(int numRows) { */ public int numRows() { return numRows; } - /** - * Returns the schema that makes up this batch. - */ - public StructType schema() { return schema; } - - /** - * Returns the max capacity (in number of rows) for this batch. - */ - public int capacity() { return capacity; } - /** * Returns the column at `ordinal`. */ @@ -120,10 +104,8 @@ public InternalRow getRow(int rowId) { return row; } - public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { - this.schema = schema; + public ColumnarBatch(ColumnVector[] columns) { this.columns = columns; - this.capacity = capacity; this.row = new MutableColumnarRow(columns); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0cf9b53ce1d5d..eb48584d0c1ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -94,7 +94,7 @@ class VectorizedHashMapGenerator( | | public $generatedClassName() { | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema); - | batch = new ${classOf[ColumnarBatch].getName}(schema, vectors, capacity); + | batch = new ${classOf[ColumnarBatch].getName}(vectors); | | // Generates a projection to return the aggregate buffer only. | ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index bcd1aa0890ba3..7487564ed64da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -175,7 +175,7 @@ private[sql] object ArrowConverters { new ArrowColumnVector(vector).asInstanceOf[ColumnVector] }.toArray - val batch = new ColumnarBatch(schemaRead, columns, root.getRowCount) + val batch = new ColumnarBatch(columns) batch.setNumRows(root.getRowCount) batch.rowIterator().asScala } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 3565ee3af1b9f..28b3875505cd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -78,11 +78,10 @@ case class InMemoryTableScanExec( } else { OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) } - val columnarBatch = new ColumnarBatch( - columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount) + val columnarBatch = new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]]) columnarBatch.setNumRows(rowCount) - for (i <- 0 until attributes.length) { + for (i <- attributes.indices) { ColumnAccessor.decompress( cachedColumnarBatch.buffers(columnIndices(i)), columnarBatch.column(i).asInstanceOf[WritableColumnVector], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index c06bc7b66ff39..47b146f076b62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -74,8 +74,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi schema: StructType, context: TaskContext): Iterator[InternalRow] = { - val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex - .map { case (attr, i) => attr.withName(s"_$i") }) + val outputTypes = output.drop(child.output.length).map(_.dataType) // DO NOT use iter.grouped(). See BatchIterator. val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) @@ -90,8 +89,9 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi private var currentIter = if (columnarBatchIter.hasNext) { val batch = columnarBatchIter.next() - assert(schemaOut.equals(batch.schema), - s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") + val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType()) + assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " + + s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}") batch.rowIterator.asScala } else { Iterator.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index dc5ba96e69aec..5fcdcddca7d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -138,7 +138,7 @@ class ArrowPythonRunner( if (reader != null && batchLoaded) { batchLoaded = reader.loadNextBatch() if (batchLoaded) { - val batch = new ColumnarBatch(schema, vectors, root.getRowCount) + val batch = new ColumnarBatch(vectors) batch.setNumRows(root.getRowCount) batch } else { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index 98d6a53b54d28..a5d77a90ece42 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -69,8 +69,7 @@ public DataReader createDataReader() { ColumnVector[] vectors = new ColumnVector[2]; vectors[0] = i; vectors[1] = j; - this.batch = - new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); + this.batch = new ColumnarBatch(vectors); return this; } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 675f06b31b970..cd90681ecabc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -875,14 +875,13 @@ class ColumnarBatchSuite extends SparkFunSuite { .add("intCol2", IntegerType) .add("string", BinaryType) - val capacity = ColumnarBatch.DEFAULT_BATCH_SIZE + val capacity = 4 * 1024 val columns = schema.fields.map { field => allocate(capacity, field.dataType, memMode) } - val batch = new ColumnarBatch(schema, columns.toArray, ColumnarBatch.DEFAULT_BATCH_SIZE) + val batch = new ColumnarBatch(columns.toArray) assert(batch.numCols() == 4) assert(batch.numRows() == 0) - assert(batch.capacity() > 0) assert(batch.rowIterator().hasNext == false) // Add a row [1, 1.1, NULL] @@ -1153,7 +1152,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2)) val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType))) - val batch = new ColumnarBatch(schema, columnVectors.toArray[ColumnVector], 11) + val batch = new ColumnarBatch(columnVectors.toArray) batch.setNumRows(11) assert(batch.numCols() == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index a89f7c55bf4f7..0ca29524c6d05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -333,8 +333,7 @@ class BatchReadTask(start: Int, end: Int) private final val BATCH_SIZE = 20 private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val batch = new ColumnarBatch( - new StructType().add("i", "int").add("j", "int"), Array(i, j), BATCH_SIZE) + private lazy val batch = new ColumnarBatch(Array(i, j)) private var current = start From 73d3b230f3816a854a181c0912d87b180e347271 Mon Sep 17 00:00:00 2001 From: foxish Date: Fri, 19 Jan 2018 10:23:13 -0800 Subject: [PATCH 150/774] [SPARK-23104][K8S][DOCS] Changes to Kubernetes scheduler documentation ## What changes were proposed in this pull request? Docs changes: - Adding a warning that the backend is experimental. - Removing a defunct internal-only option from documentation - Clarifying that node selectors can be used right away, and other minor cosmetic changes ## How was this patch tested? Docs only change Author: foxish Closes #20314 from foxish/ambiguous-docs. --- docs/cluster-overview.md | 4 ++-- docs/running-on-kubernetes.md | 43 ++++++++++++++++------------------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 658e67f99dd71..7277e2fb2731d 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -52,8 +52,8 @@ The system currently supports three cluster managers: * [Apache Mesos](running-on-mesos.html) -- a general cluster manager that can also run Hadoop MapReduce and service applications. * [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2. -* [Kubernetes](running-on-kubernetes.html) -- [Kubernetes](https://kubernetes.io/docs/concepts/overview/what-is-kubernetes/) -is an open-source platform that provides container-centric infrastructure. +* [Kubernetes](running-on-kubernetes.html) -- an open-source system for automating deployment, scaling, + and management of containerized applications. A third-party project (not supported by the Spark project) exists to add support for [Nomad](https://github.com/hashicorp/nomad-spark) as a cluster manager. diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index d6b1735ce5550..3c7586e8544ba 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -8,6 +8,10 @@ title: Running Spark on Kubernetes Spark can run on clusters managed by [Kubernetes](https://kubernetes.io). This feature makes use of native Kubernetes scheduler that has been added to Spark. +**The Kubernetes scheduler is currently experimental. +In future versions, there may be behavioral changes around configuration, +container images and entrypoints.** + # Prerequisites * A runnable distribution of Spark 2.3 or above. @@ -41,11 +45,10 @@ logs and remains in "completed" state in the Kubernetes API until it's eventuall Note that in the completed state, the driver pod does *not* use any computational or memory resources. -The driver and executor pod scheduling is handled by Kubernetes. It will be possible to affect Kubernetes scheduling -decisions for driver and executor pods using advanced primitives like -[node selectors](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#nodeselector) -and [node/pod affinities](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#affinity-and-anti-affinity) -in a future release. +The driver and executor pod scheduling is handled by Kubernetes. It is possible to schedule the +driver and executor pods on a subset of available nodes through a [node selector](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#nodeselector) +using the configuration property for it. It will be possible to use more advanced +scheduling hints like [node/pod affinities](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#affinity-and-anti-affinity) in a future release. # Submitting Applications to Kubernetes @@ -62,8 +65,10 @@ use with the Kubernetes backend. Example usage is: - ./bin/docker-image-tool.sh -r -t my-tag build - ./bin/docker-image-tool.sh -r -t my-tag push +```bash +$ ./bin/docker-image-tool.sh -r -t my-tag build +$ ./bin/docker-image-tool.sh -r -t my-tag push +``` ## Cluster Mode @@ -94,7 +99,7 @@ must consist of lower case alphanumeric characters, `-`, and `.` and must start If you have a Kubernetes cluster setup, one way to discover the apiserver URL is by executing `kubectl cluster-info`. ```bash -kubectl cluster-info +$ kubectl cluster-info Kubernetes master is running at http://127.0.0.1:6443 ``` @@ -105,7 +110,7 @@ authenticating proxy, `kubectl proxy` to communicate to the Kubernetes API. The local proxy can be started by: ```bash -kubectl proxy +$ kubectl proxy ``` If the local proxy is running at localhost:8001, `--master k8s://http://127.0.0.1:8001` can be used as the argument to @@ -173,7 +178,7 @@ Logs can be accessed using the Kubernetes API and the `kubectl` CLI. When a Spar to stream logs from the application using: ```bash -kubectl -n= logs -f +$ kubectl -n= logs -f ``` The same logs can also be accessed through the @@ -186,7 +191,7 @@ The UI associated with any application can be accessed locally using [`kubectl port-forward`](https://kubernetes.io/docs/tasks/access-application-cluster/port-forward-access-application-cluster/#forward-a-local-port-to-a-port-on-the-pod). ```bash -kubectl port-forward 4040:4040 +$ kubectl port-forward 4040:4040 ``` Then, the Spark driver UI can be accessed on `http://localhost:4040`. @@ -200,13 +205,13 @@ are errors during the running of the application, often, the best way to investi To get some basic information about the scheduling decisions made around the driver pod, you can run: ```bash -kubectl describe pod +$ kubectl describe pod ``` If the pod has encountered a runtime error, the status can be probed further using: ```bash -kubectl logs +$ kubectl logs ``` Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark @@ -254,7 +259,7 @@ To create a custom service account, a user can use the `kubectl create serviceac following command creates a service account named `spark`: ```bash -kubectl create serviceaccount spark +$ kubectl create serviceaccount spark ``` To grant a service account a `Role` or `ClusterRole`, a `RoleBinding` or `ClusterRoleBinding` is needed. To create @@ -263,7 +268,7 @@ for `ClusterRoleBinding`) command. For example, the following command creates an namespace and grants it to the `spark` service account created above: ```bash -kubectl create clusterrolebinding spark-role --clusterrole=edit --serviceaccount=default:spark --namespace=default +$ kubectl create clusterrolebinding spark-role --clusterrole=edit --serviceaccount=default:spark --namespace=default ``` Note that a `Role` can only be used to grant access to resources (like pods) within a single namespace, whereas a @@ -543,14 +548,6 @@ specific to Spark on Kubernetes. to avoid name conflicts. - - spark.kubernetes.executor.podNamePrefix - (none) - - Prefix for naming the executor pods. - If not set, the executor pod name is set to driver pod name suffixed by an integer. - - spark.kubernetes.executor.lostCheck.maxAttempts 10 From 07296a61c29eb074553956b6c0f92810ecf7bab2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 10:25:18 -0800 Subject: [PATCH 151/774] [INFRA] Close stale PR. Closes #20185. From fed2139f053fac4a9a6952ff0ab1cc2a5f657bd0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 13:26:37 -0600 Subject: [PATCH 152/774] [SPARK-20664][CORE] Delete stale application data from SHS. Detect the deletion of event log files from storage, and remove data about the related application attempt in the SHS. Also contains code to fix SPARK-21571 based on code by ericvandenbergfb. Author: Marcelo Vanzin Closes #20138 from vanzin/SPARK-20664. --- .../deploy/history/FsHistoryProvider.scala | 297 +++++++++++------- .../history/FsHistoryProviderSuite.scala | 117 ++++++- .../deploy/history/HistoryServerSuite.scala | 4 +- 3 files changed, 306 insertions(+), 112 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 94c80ebd55e74..f9d0b5ee4e23e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} import java.util.{Date, ServiceLoader, UUID} -import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} +import java.util.concurrent.{ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ @@ -29,7 +29,7 @@ import scala.xml.Node import com.fasterxml.jackson.annotation.JsonIgnore import com.google.common.io.ByteStreams -import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} +import com.google.common.util.concurrent.MoreExecutors import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.hdfs.DistributedFileSystem @@ -116,8 +116,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs // and applications between check task and clean task. - private val pool = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder() - .setNameFormat("spark-history-task-%d").setDaemon(true).build()) + private val pool = ThreadUtils.newDaemonSingleThreadScheduledExecutor("spark-history-task-%d") // The modification time of the newest log detected during the last scan. Currently only // used for logging msgs (logs are re-scanned based on file size, rather than modtime) @@ -174,7 +173,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Fixed size thread pool to fetch and parse log files. */ private val replayExecutor: ExecutorService = { - if (!conf.contains("spark.testing")) { + if (Utils.isTesting) { ThreadUtils.newDaemonFixedThreadPool(NUM_PROCESSING_THREADS, "log-replay-executor") } else { MoreExecutors.sameThreadExecutor() @@ -275,7 +274,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) try { Some(load(appId).toApplicationInfo()) } catch { - case e: NoSuchElementException => + case _: NoSuchElementException => None } } @@ -405,49 +404,70 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) try { val newLastScanTime = getNewLastScanTime() logDebug(s"Scanning $logDir with lastScanTime==$lastScanTime") - // scan for modified applications, replay and merge them - val logInfos = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) + + val updated = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) .filter { entry => !entry.isDirectory() && // FsHistoryProvider generates a hidden file which can't be read. Accidentally // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. !entry.getPath().getName().startsWith(".") && - SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) && - recordedFileSize(entry.getPath()) < entry.getLen() + SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) + } + .filter { entry => + try { + val info = listing.read(classOf[LogInfo], entry.getPath().toString()) + if (info.fileSize < entry.getLen()) { + // Log size has changed, it should be parsed. + true + } else { + // If the SHS view has a valid application, update the time the file was last seen so + // that the entry is not deleted from the SHS listing. + if (info.appId.isDefined) { + listing.write(info.copy(lastProcessed = newLastScanTime)) + } + false + } + } catch { + case _: NoSuchElementException => + // If the file is currently not being tracked by the SHS, add an entry for it and try + // to parse it. This will allow the cleaner code to detect the file as stale later on + // if it was not possible to parse it. + listing.write(LogInfo(entry.getPath().toString(), newLastScanTime, None, None, + entry.getLen())) + entry.getLen() > 0 + } } .sortWith { case (entry1, entry2) => entry1.getModificationTime() > entry2.getModificationTime() } - if (logInfos.nonEmpty) { - logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") + if (updated.nonEmpty) { + logDebug(s"New/updated attempts found: ${updated.size} ${updated.map(_.getPath)}") } - var tasks = mutable.ListBuffer[Future[_]]() - - try { - for (file <- logInfos) { - tasks += replayExecutor.submit(new Runnable { - override def run(): Unit = mergeApplicationListing(file) + val tasks = updated.map { entry => + try { + replayExecutor.submit(new Runnable { + override def run(): Unit = mergeApplicationListing(entry, newLastScanTime) }) + } catch { + // let the iteration over the updated entries break, since an exception on + // replayExecutor.submit (..) indicates the ExecutorService is unable + // to take any more submissions at this time + case e: Exception => + logError(s"Exception while submitting event log for replay", e) + null } - } catch { - // let the iteration over logInfos break, since an exception on - // replayExecutor.submit (..) indicates the ExecutorService is unable - // to take any more submissions at this time - - case e: Exception => - logError(s"Exception while submitting event log for replay", e) - } + }.filter(_ != null) pendingReplayTasksCount.addAndGet(tasks.size) + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. tasks.foreach { task => try { - // Wait for all tasks to finish. This makes sure that checkForLogs - // is not scheduled again while some tasks are already running in - // the replayExecutor. task.get() } catch { case e: InterruptedException => @@ -459,13 +479,70 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + // Delete all information about applications whose log files disappeared from storage. + // This is done by identifying the event logs which were not touched by the current + // directory scan. + // + // Only entries with valid applications are cleaned up here. Cleaning up invalid log + // files is done by the periodic cleaner task. + val stale = listing.view(classOf[LogInfo]) + .index("lastProcessed") + .last(newLastScanTime - 1) + .asScala + .toList + stale.foreach { log => + log.appId.foreach { appId => + cleanAppData(appId, log.attemptId, log.logPath) + listing.delete(classOf[LogInfo], log.logPath) + } + } + lastScanTime.set(newLastScanTime) } catch { case e: Exception => logError("Exception in checking for event log updates", e) } } - private def getNewLastScanTime(): Long = { + private def cleanAppData(appId: String, attemptId: Option[String], logPath: String): Unit = { + try { + val app = load(appId) + val (attempt, others) = app.attempts.partition(_.info.attemptId == attemptId) + + assert(attempt.isEmpty || attempt.size == 1) + val isStale = attempt.headOption.exists { a => + if (a.logPath != new Path(logPath).getName()) { + // If the log file name does not match, then probably the old log file was from an + // in progress application. Just return that the app should be left alone. + false + } else { + val maybeUI = synchronized { + activeUIs.remove(appId -> attemptId) + } + + maybeUI.foreach { ui => + ui.invalidate() + ui.ui.store.close() + } + + diskManager.foreach(_.release(appId, attemptId, delete = true)) + true + } + } + + if (isStale) { + if (others.nonEmpty) { + val newAppInfo = new ApplicationInfoWrapper(app.info, others) + listing.write(newAppInfo) + } else { + listing.delete(classOf[ApplicationInfoWrapper], appId) + } + } + } catch { + case _: NoSuchElementException => + } + } + + private[history] def getNewLastScanTime(): Long = { val fileName = "." + UUID.randomUUID().toString val path = new Path(logDir, fileName) val fos = fs.create(path) @@ -530,7 +607,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the given log file, saving the application in the listing db. */ - protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { + protected def mergeApplicationListing(fileStatus: FileStatus, scanTime: Long): Unit = { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || eventString.startsWith(APPL_END_EVENT_PREFIX) || @@ -544,73 +621,78 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) bus.addListener(listener) replay(fileStatus, bus, eventsFilter = eventsFilter) - listener.applicationInfo.foreach { app => - // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a - // discussion on the UI lifecycle. - synchronized { - activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => - ui.invalidate() - ui.ui.store.close() + val (appId, attemptId) = listener.applicationInfo match { + case Some(app) => + // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a + // discussion on the UI lifecycle. + synchronized { + activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => + ui.invalidate() + ui.ui.store.close() + } } - } - addListing(app) + addListing(app) + (Some(app.info.id), app.attempts.head.info.attemptId) + + case _ => + // If the app hasn't written down its app ID to the logs, still record the entry in the + // listing db, with an empty ID. This will make the log eligible for deletion if the app + // does not make progress after the configured max log age. + (None, None) } - listing.write(new LogInfo(logPath.toString(), fileStatus.getLen())) + listing.write(LogInfo(logPath.toString(), scanTime, appId, attemptId, fileStatus.getLen())) } /** * Delete event logs from the log directory according to the clean policy defined by the user. */ - private[history] def cleanLogs(): Unit = { - var iterator: Option[KVStoreIterator[ApplicationInfoWrapper]] = None - try { - val maxTime = clock.getTimeMillis() - conf.get(MAX_LOG_AGE_S) * 1000 - - // Iterate descending over all applications whose oldest attempt happened before maxTime. - iterator = Some(listing.view(classOf[ApplicationInfoWrapper]) - .index("oldestAttempt") - .reverse() - .first(maxTime) - .closeableIterator()) - - iterator.get.asScala.foreach { app => - // Applications may have multiple attempts, some of which may not need to be deleted yet. - val (remaining, toDelete) = app.attempts.partition { attempt => - attempt.info.lastUpdated.getTime() >= maxTime - } + private[history] def cleanLogs(): Unit = Utils.tryLog { + val maxTime = clock.getTimeMillis() - conf.get(MAX_LOG_AGE_S) * 1000 - if (remaining.nonEmpty) { - val newApp = new ApplicationInfoWrapper(app.info, remaining) - listing.write(newApp) - } + val expired = listing.view(classOf[ApplicationInfoWrapper]) + .index("oldestAttempt") + .reverse() + .first(maxTime) + .asScala + .toList + expired.foreach { app => + // Applications may have multiple attempts, some of which may not need to be deleted yet. + val (remaining, toDelete) = app.attempts.partition { attempt => + attempt.info.lastUpdated.getTime() >= maxTime + } - toDelete.foreach { attempt => - val logPath = new Path(logDir, attempt.logPath) - try { - listing.delete(classOf[LogInfo], logPath.toString()) - } catch { - case _: NoSuchElementException => - logDebug(s"Log info entry for $logPath not found.") - } - try { - fs.delete(logPath, true) - } catch { - case e: AccessControlException => - logInfo(s"No permission to delete ${attempt.logPath}, ignoring.") - case t: IOException => - logError(s"IOException in cleaning ${attempt.logPath}", t) - } - } + if (remaining.nonEmpty) { + val newApp = new ApplicationInfoWrapper(app.info, remaining) + listing.write(newApp) + } - if (remaining.isEmpty) { - listing.delete(app.getClass(), app.id) - } + toDelete.foreach { attempt => + logInfo(s"Deleting expired event log for ${attempt.logPath}") + val logPath = new Path(logDir, attempt.logPath) + listing.delete(classOf[LogInfo], logPath.toString()) + cleanAppData(app.id, attempt.info.attemptId, logPath.toString()) + deleteLog(logPath) + } + + if (remaining.isEmpty) { + listing.delete(app.getClass(), app.id) + } + } + + // Delete log files that don't have a valid application and exceed the configured max age. + val stale = listing.view(classOf[LogInfo]) + .index("lastProcessed") + .reverse() + .first(maxTime) + .asScala + .toList + stale.foreach { log => + if (log.appId.isEmpty) { + logInfo(s"Deleting invalid / corrupt event log ${log.logPath}") + deleteLog(new Path(log.logPath)) + listing.delete(classOf[LogInfo], log.logPath) } - } catch { - case t: Exception => logError("Exception while cleaning logs", t) - } finally { - iterator.foreach(_.close()) } } @@ -631,12 +713,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // an error the other way -- if we report a size bigger (ie later) than the file that is // actually read, we may never refresh the app. FileStatus is guaranteed to be static // after it's created, so we get a file size that is no bigger than what is actually read. - val logInput = EventLoggingListener.openEventLog(logPath, fs) - try { - bus.replay(logInput, logPath.toString, !isCompleted, eventsFilter) + Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => + bus.replay(in, logPath.toString, !isCompleted, eventsFilter) logInfo(s"Finished parsing $logPath") - } finally { - logInput.close() } } @@ -703,18 +782,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) | application count=$count}""".stripMargin } - /** - * Return the last known size of the given event log, recorded the last time the file - * system scanner detected a change in the file. - */ - private def recordedFileSize(log: Path): Long = { - try { - listing.read(classOf[LogInfo], log.toString()).fileSize - } catch { - case _: NoSuchElementException => 0L - } - } - private def load(appId: String): ApplicationInfoWrapper = { listing.read(classOf[ApplicationInfoWrapper], appId) } @@ -773,11 +840,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) logInfo(s"Leasing disk manager space for app $appId / ${attempt.info.attemptId}...") val lease = dm.lease(status.getLen(), isCompressed) val newStorePath = try { - val store = KVUtils.open(lease.tmpPath, metadata) - try { + Utils.tryWithResource(KVUtils.open(lease.tmpPath, metadata)) { store => rebuildAppStore(store, status, attempt.info.lastUpdated.getTime()) - } finally { - store.close() } lease.commit(appId, attempt.info.attemptId) } catch { @@ -806,6 +870,17 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) throw new NoSuchElementException(s"Cannot find attempt $attemptId of $appId.")) } + private def deleteLog(log: Path): Unit = { + try { + fs.delete(log, true) + } catch { + case _: AccessControlException => + logInfo(s"No permission to delete $log, ignoring.") + case ioe: IOException => + logError(s"IOException in cleaning $log", ioe) + } + } + } private[history] object FsHistoryProvider { @@ -832,8 +907,16 @@ private[history] case class FsHistoryProviderMetadata( uiVersion: Long, logDir: String) +/** + * Tracking info for event logs detected in the configured log directory. Tracks both valid and + * invalid logs (e.g. unparseable logs, recorded as logs with no app ID) so that the cleaner + * can know what log files are safe to delete. + */ private[history] case class LogInfo( @KVIndexParam logPath: String, + @KVIndexParam("lastProcessed") lastProcessed: Long, + appId: Option[String], + attemptId: Option[String], fileSize: Long) private[history] class AttemptInfoWrapper( diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 84ee01c7f5aaf..787de59edf465 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any -import org.mockito.Mockito.{mock, spy, verify} +import org.mockito.Mockito.{doReturn, mock, spy, verify} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ @@ -149,8 +149,10 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc class TestFsHistoryProvider extends FsHistoryProvider(createTestConf()) { var mergeApplicationListingCall = 0 - override protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { - super.mergeApplicationListing(fileStatus) + override protected def mergeApplicationListing( + fileStatus: FileStatus, + lastSeen: Long): Unit = { + super.mergeApplicationListing(fileStatus, lastSeen) mergeApplicationListingCall += 1 } } @@ -663,6 +665,115 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc freshUI.get.ui.store.job(0) } + test("clean up stale app information") { + val storeDir = Utils.createTempDir() + val conf = createTestConf().set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) + val provider = spy(new FsHistoryProvider(conf)) + val appId = "new1" + + // Write logs for two app attempts. + doReturn(1L).when(provider).getNewLastScanTime() + val attempt1 = newLogFile(appId, Some("1"), inProgress = false) + writeFile(attempt1, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("1")), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) + ) + val attempt2 = newLogFile(appId, Some("2"), inProgress = false) + writeFile(attempt2, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("2")), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) + ) + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).id === appId) + assert(list(0).attempts.size === 2) + } + + // Load the app's UI. + val ui = provider.getAppUI(appId, Some("1")) + assert(ui.isDefined) + + // Delete the underlying log file for attempt 1 and rescan. The UI should go away, but since + // attempt 2 still exists, listing data should be there. + doReturn(2L).when(provider).getNewLastScanTime() + attempt1.delete() + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).id === appId) + assert(list(0).attempts.size === 1) + } + assert(!ui.get.valid) + assert(provider.getAppUI(appId, None) === None) + + // Delete the second attempt's log file. Now everything should go away. + doReturn(3L).when(provider).getNewLastScanTime() + attempt2.delete() + updateAndCheck(provider) { list => + assert(list.isEmpty) + } + } + + test("SPARK-21571: clean up removes invalid history files") { + // TODO: "maxTime" becoming negative in cleanLogs() causes this test to fail, so avoid that + // until we figure out what's causing the problem. + val clock = new ManualClock(TimeUnit.DAYS.toMillis(120)) + val conf = createTestConf().set(MAX_LOG_AGE_S.key, s"2d") + val provider = new FsHistoryProvider(conf, clock) { + override def getNewLastScanTime(): Long = clock.getTimeMillis() + } + + // Create 0-byte size inprogress and complete files + var logCount = 0 + var validLogCount = 0 + + val emptyInProgress = newLogFile("emptyInprogressLogFile", None, inProgress = true) + emptyInProgress.createNewFile() + emptyInProgress.setLastModified(clock.getTimeMillis()) + logCount += 1 + + val slowApp = newLogFile("slowApp", None, inProgress = true) + slowApp.createNewFile() + slowApp.setLastModified(clock.getTimeMillis()) + logCount += 1 + + val emptyFinished = newLogFile("emptyFinishedLogFile", None, inProgress = false) + emptyFinished.createNewFile() + emptyFinished.setLastModified(clock.getTimeMillis()) + logCount += 1 + + // Create an incomplete log file, has an end record but no start record. + val corrupt = newLogFile("nonEmptyCorruptLogFile", None, inProgress = false) + writeFile(corrupt, true, None, SparkListenerApplicationEnd(0)) + corrupt.setLastModified(clock.getTimeMillis()) + logCount += 1 + + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === logCount) + + // Move the clock forward 1 day and scan the files again. They should still be there. + clock.advance(TimeUnit.DAYS.toMillis(1)) + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === logCount) + + // Update the slow app to contain valid info. Code should detect the change and not clean + // it up. + writeFile(slowApp, true, None, + SparkListenerApplicationStart(slowApp.getName(), Some(slowApp.getName()), 1L, "test", None)) + slowApp.setLastModified(clock.getTimeMillis()) + validLogCount += 1 + + // Move the clock forward another 2 days and scan the files again. This time the cleaner should + // pick up the invalid files and get rid of them. + clock.advance(TimeUnit.DAYS.toMillis(2)) + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === validLogCount) + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 87778dda0e2c8..7aa60f2b60796 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -48,7 +48,7 @@ import org.apache.spark.deploy.history.config._ import org.apache.spark.status.api.v1.ApplicationInfo import org.apache.spark.status.api.v1.JobData import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ResetSystemProperties, Utils} +import org.apache.spark.util.{ResetSystemProperties, ShutdownHookManager, Utils} /** * A collection of tests against the historyserver, including comparing responses from the json @@ -564,7 +564,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(jobcount === getNumJobs("/jobs")) // no need to retain the test dir now the tests complete - logDir.deleteOnExit() + ShutdownHookManager.registerShutdownDeleteDir(logDir) } test("ui and api authorization checks") { From aa3a1276f9e23ffbb093d00743e63cd4369f9f57 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 13:32:20 -0600 Subject: [PATCH 153/774] [SPARK-23103][CORE] Ensure correct sort order for negative values in LevelDB. The code was sorting "0" as "less than" negative values, which is a little wrong. Fix is simple, most of the changes are the added test and related cleanup. Author: Marcelo Vanzin Closes #20284 from vanzin/SPARK-23103. --- .../spark/util/kvstore/LevelDBTypeInfo.java | 2 +- .../spark/util/kvstore/DBIteratorSuite.java | 7 +- .../spark/util/kvstore/LevelDBSuite.java | 77 ++++++++++--------- .../spark/status/AppStatusListenerSuite.scala | 8 +- 4 files changed, 52 insertions(+), 42 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java index 232ee41dd0b1f..f4d359234cb9e 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java @@ -493,7 +493,7 @@ byte[] toKey(Object value, byte prefix) { byte[] key = new byte[bytes * 2 + 2]; long longValue = ((Number) value).longValue(); key[0] = prefix; - key[1] = longValue > 0 ? POSITIVE_MARKER : NEGATIVE_MARKER; + key[1] = longValue >= 0 ? POSITIVE_MARKER : NEGATIVE_MARKER; for (int i = 0; i < key.length - 2; i++) { int masked = (int) ((longValue >>> (4 * i)) & 0xF); diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java index 9a81f86812cde..1e062437d1803 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java @@ -73,7 +73,9 @@ default BaseComparator reverse() { private static final BaseComparator NATURAL_ORDER = (t1, t2) -> t1.key.compareTo(t2.key); private static final BaseComparator REF_INDEX_ORDER = (t1, t2) -> t1.id.compareTo(t2.id); private static final BaseComparator COPY_INDEX_ORDER = (t1, t2) -> t1.name.compareTo(t2.name); - private static final BaseComparator NUMERIC_INDEX_ORDER = (t1, t2) -> t1.num - t2.num; + private static final BaseComparator NUMERIC_INDEX_ORDER = (t1, t2) -> { + return Integer.valueOf(t1.num).compareTo(t2.num); + }; private static final BaseComparator CHILD_INDEX_ORDER = (t1, t2) -> t1.child.compareTo(t2.child); /** @@ -112,7 +114,8 @@ public void setup() throws Exception { t.key = "key" + i; t.id = "id" + i; t.name = "name" + RND.nextInt(MAX_ENTRIES); - t.num = RND.nextInt(MAX_ENTRIES); + // Force one item to have an integer value of zero to test the fix for SPARK-23103. + t.num = (i != 0) ? (int) RND.nextLong() : 0; t.child = "child" + (i % MIN_ENTRIES); allEntries.add(t); } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index 2b07d249d2022..b8123ac81d29a 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.List; import java.util.NoSuchElementException; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import org.apache.commons.io.FileUtils; import org.iq80.leveldb.DBIterator; @@ -74,11 +76,7 @@ public void testReopenAndVersionCheckDb() throws Exception { @Test public void testObjectWriteReadDelete() throws Exception { - CustomType1 t = new CustomType1(); - t.key = "key"; - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); try { db.read(CustomType1.class, t.key); @@ -106,17 +104,9 @@ public void testObjectWriteReadDelete() throws Exception { @Test public void testMultipleObjectWriteReadDelete() throws Exception { - CustomType1 t1 = new CustomType1(); - t1.key = "key1"; - t1.id = "id"; - t1.name = "name1"; - t1.child = "child1"; - - CustomType1 t2 = new CustomType1(); - t2.key = "key2"; - t2.id = "id"; - t2.name = "name2"; - t2.child = "child2"; + CustomType1 t1 = createCustomType1(1); + CustomType1 t2 = createCustomType1(2); + t2.id = t1.id; db.write(t1); db.write(t2); @@ -142,11 +132,7 @@ public void testMultipleObjectWriteReadDelete() throws Exception { @Test public void testMultipleTypesWriteReadDelete() throws Exception { - CustomType1 t1 = new CustomType1(); - t1.key = "1"; - t1.id = "id"; - t1.name = "name1"; - t1.child = "child1"; + CustomType1 t1 = createCustomType1(1); IntKeyType t2 = new IntKeyType(); t2.key = 2; @@ -188,10 +174,7 @@ public void testMultipleTypesWriteReadDelete() throws Exception { public void testMetadata() throws Exception { assertNull(db.getMetadata(CustomType1.class)); - CustomType1 t = new CustomType1(); - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); db.setMetadata(t); assertEquals(t, db.getMetadata(CustomType1.class)); @@ -202,11 +185,7 @@ public void testMetadata() throws Exception { @Test public void testUpdate() throws Exception { - CustomType1 t = new CustomType1(); - t.key = "key"; - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); db.write(t); @@ -222,13 +201,7 @@ public void testUpdate() throws Exception { @Test public void testSkip() throws Exception { for (int i = 0; i < 10; i++) { - CustomType1 t = new CustomType1(); - t.key = "key" + i; - t.id = "id" + i; - t.name = "name" + i; - t.child = "child" + i; - - db.write(t); + db.write(createCustomType1(i)); } KVStoreIterator it = db.view(CustomType1.class).closeableIterator(); @@ -240,6 +213,36 @@ public void testSkip() throws Exception { assertFalse(it.hasNext()); } + @Test + public void testNegativeIndexValues() throws Exception { + List expected = Arrays.asList(-100, -50, 0, 50, 100); + + expected.stream().forEach(i -> { + try { + db.write(createCustomType1(i)); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + List results = StreamSupport + .stream(db.view(CustomType1.class).index("int").spliterator(), false) + .map(e -> e.num) + .collect(Collectors.toList()); + + assertEquals(expected, results); + } + + private CustomType1 createCustomType1(int i) { + CustomType1 t = new CustomType1(); + t.key = "key" + i; + t.id = "id" + i; + t.name = "name" + i; + t.num = i; + t.child = "child" + i; + return t; + } + private int countKeys(Class type) throws Exception { byte[] prefix = db.getTypeInfo(type).keyPrefix(); int count = 0; diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index ca66b6b9db890..e7981bec6d64b 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -894,15 +894,19 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val dropped = stages.drop(1).head // Cache some quantiles by calling AppStatusStore.taskSummary(). For quantiles to be - // calculcated, we need at least one finished task. + // calculated, we need at least one finished task. The code in AppStatusStore uses + // `executorRunTime` to detect valid tasks, so that metric needs to be updated in the + // task end event. time += 1 val task = createTasks(1, Array("1")).head listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptId, task)) time += 1 task.markFinished(TaskState.FINISHED, time) + val metrics = TaskMetrics.empty + metrics.setExecutorRunTime(42L) listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptId, - "taskType", Success, task, null)) + "taskType", Success, task, metrics)) new AppStatusStore(store) .taskSummary(dropped.stageId, dropped.attemptId, Array(0.25d, 0.50d, 0.75d)) From f6da41b0150725fe96ccb2ee3b48840b207f47eb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 13:14:24 -0800 Subject: [PATCH 154/774] [SPARK-23135][UI] Fix rendering of accumulators in the stage page. This follows the behavior of 2.2: only named accumulators with a value are rendered. Screenshot: ![accs](https://user-images.githubusercontent.com/1694083/35065700-df409114-fb82-11e7-87c1-550c3f674371.png) Author: Marcelo Vanzin Closes #20299 from vanzin/SPARK-23135. --- .../org/apache/spark/ui/jobs/StagePage.scala | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 25bee33028393..0eb3190205c3e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -260,7 +260,11 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") def accumulableRow(acc: AccumulableInfo): Seq[Node] = { - {acc.name}{acc.value} + if (acc.name != null && acc.value != null) { + {acc.name}{acc.value} + } else { + Nil + } } val accumulableTable = UIUtils.listingTable( accumulableHeaders, @@ -864,7 +868,7 @@ private[ui] class TaskPagedTable( {formatBytes(task.taskMetrics.map(_.peakExecutionMemory))} {if (hasAccumulators(stage)) { - accumulatorsInfo(task) + {accumulatorsInfo(task)} }} {if (hasInput(stage)) { metricInfo(task) { m => @@ -920,8 +924,12 @@ private[ui] class TaskPagedTable( } private def accumulatorsInfo(task: TaskData): Seq[Node] = { - task.accumulatorUpdates.map { acc => - Unparsed(StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}")) + task.accumulatorUpdates.flatMap { acc => + if (acc.name != null && acc.update.isDefined) { + Unparsed(StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")) ++
+ } else { + Nil + } } } @@ -985,7 +993,9 @@ private object ApiHelper { "Shuffle Spill (Disk)" -> TaskIndexNames.DISK_SPILL, "Errors" -> TaskIndexNames.ERROR) - def hasAccumulators(stageData: StageData): Boolean = stageData.accumulatorUpdates.size > 0 + def hasAccumulators(stageData: StageData): Boolean = { + stageData.accumulatorUpdates.exists { acc => acc.name != null && acc.value != null } + } def hasInput(stageData: StageData): Boolean = stageData.inputBytes > 0 From 793841c6b8b98b918dcf241e29f60ef125914db9 Mon Sep 17 00:00:00 2001 From: Kent Yao <11215016@zju.edu.cn> Date: Fri, 19 Jan 2018 15:49:29 -0800 Subject: [PATCH 155/774] [SPARK-21771][SQL] remove useless hive client in SparkSQLEnv ## What changes were proposed in this pull request? Once a meta hive client is created, it generates its SessionState which creates a lot of session related directories, some deleteOnExit, some does not. if a hive client is useless we may not create it at the very start. ## How was this patch tested? N/A cc hvanhovell cloud-fan Author: Kent Yao <11215016@zju.edu.cn> Closes #18983 from yaooqinn/patch-1. --- .../org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 6b19f971b73bb..cbd75ad12d430 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -50,8 +50,7 @@ private[hive] object SparkSQLEnv extends Logging { sqlContext = sparkSession.sqlContext val metadataHive = sparkSession - .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] - .client.newSession() + .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) From 396cdfbea45232bacbc03bfaf8be4ea85d47d3fd Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 19 Jan 2018 22:46:34 -0800 Subject: [PATCH 156/774] [SPARK-23091][ML] Incorrect unit test for approxQuantile ## What changes were proposed in this pull request? Narrow bound on approx quantile test to epsilon from 2*epsilon to match paper ## How was this patch tested? Existing tests. Author: Sean Owen Closes #20324 from srowen/SPARK-23091. --- .../apache/spark/sql/DataFrameStatSuite.scala | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 5169d2b5fc6b2..8eae35325faea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -154,24 +154,24 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val Array(d1, d2) = df.stat.approxQuantile("doubles", Array(q1, q2), epsilon) val Array(s1, s2) = df.stat.approxQuantile("singles", Array(q1, q2), epsilon) - val error_single = 2 * 1000 * epsilon - val error_double = 2 * 2000 * epsilon + val errorSingle = 1000 * epsilon + val errorDouble = 2.0 * errorSingle - assert(math.abs(single1 - q1 * n) < error_single) - assert(math.abs(double2 - 2 * q2 * n) < error_double) - assert(math.abs(s1 - q1 * n) < error_single) - assert(math.abs(s2 - q2 * n) < error_single) - assert(math.abs(d1 - 2 * q1 * n) < error_double) - assert(math.abs(d2 - 2 * q2 * n) < error_double) + assert(math.abs(single1 - q1 * n) <= errorSingle) + assert(math.abs(double2 - 2 * q2 * n) <= errorDouble) + assert(math.abs(s1 - q1 * n) <= errorSingle) + assert(math.abs(s2 - q2 * n) <= errorSingle) + assert(math.abs(d1 - 2 * q1 * n) <= errorDouble) + assert(math.abs(d2 - 2 * q2 * n) <= errorDouble) // Multiple columns val Array(Array(ms1, ms2), Array(md1, md2)) = df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilon) - assert(math.abs(ms1 - q1 * n) < error_single) - assert(math.abs(ms2 - q2 * n) < error_single) - assert(math.abs(md1 - 2 * q1 * n) < error_double) - assert(math.abs(md2 - 2 * q2 * n) < error_double) + assert(math.abs(ms1 - q1 * n) <= errorSingle) + assert(math.abs(ms2 - q2 * n) <= errorSingle) + assert(math.abs(md1 - 2 * q1 * n) <= errorDouble) + assert(math.abs(md2 - 2 * q2 * n) <= errorDouble) } // quantile should be in the range [0.0, 1.0] From 84a076e0e9a38a26edf7b702c24fdbbcf1e697b9 Mon Sep 17 00:00:00 2001 From: Shashwat Anand Date: Sat, 20 Jan 2018 14:34:37 -0800 Subject: [PATCH 157/774] [SPARK-23165][DOC] Spelling mistake fix in quick-start doc. ## What changes were proposed in this pull request? Fix spelling in quick-start doc. ## How was this patch tested? Doc only. Author: Shashwat Anand Closes #20336 from ashashwat/SPARK-23165. --- docs/cloud-integration.md | 4 ++-- docs/configuration.md | 14 +++++++------- docs/graphx-programming-guide.md | 4 ++-- docs/monitoring.md | 8 ++++---- docs/quick-start.md | 6 +++--- docs/running-on-mesos.md | 2 +- docs/running-on-yarn.md | 2 +- docs/security.md | 2 +- docs/sql-programming-guide.md | 8 ++++---- docs/storage-openstack-swift.md | 2 +- docs/streaming-programming-guide.md | 4 ++-- docs/structured-streaming-kafka-integration.md | 4 ++-- docs/structured-streaming-programming-guide.md | 6 +++--- docs/submitting-applications.md | 8 ++++---- 14 files changed, 37 insertions(+), 37 deletions(-) diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index 751a192da4ffd..c150d9efc06ff 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -180,10 +180,10 @@ under the path, not the number of *new* files, so it can become a slow operation The size of the window needs to be set to handle this. 1. Files only appear in an object store once they are completely written; there -is no need for a worklow of write-then-rename to ensure that files aren't picked up +is no need for a workflow of write-then-rename to ensure that files aren't picked up while they are still being written. Applications can write straight to the monitored directory. -1. Streams should only be checkpointed to an store implementing a fast and +1. Streams should only be checkpointed to a store implementing a fast and atomic `rename()` operation Otherwise the checkpointing may be slow and potentially unreliable. ## Further Reading diff --git a/docs/configuration.md b/docs/configuration.md index eecb39dcafc9e..e7f2419cc2fa4 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -79,7 +79,7 @@ Then, you can supply configuration values at runtime: {% endhighlight %} The Spark shell and [`spark-submit`](submitting-applications.html) -tool support two ways to load configurations dynamically. The first are command line options, +tool support two ways to load configurations dynamically. The first is command line options, such as `--master`, as shown above. `spark-submit` can accept any Spark property using the `--conf` flag, but uses special flags for properties that play a part in launching the Spark application. Running `./bin/spark-submit --help` will show the entire list of these options. @@ -413,7 +413,7 @@ Apart from these, the following properties are also available, and may be useful false Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), - or it will be displayed before the driver exiting. It also can be dumped into disk by + or it will be displayed before the driver exits. It also can be dumped into disk by sc.dump_profiles(path). If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. @@ -446,7 +446,7 @@ Apart from these, the following properties are also available, and may be useful true Reuse Python worker or not. If yes, it will use a fixed number of Python workers, - does not need to fork() a Python process for every tasks. It will be very useful + does not need to fork() a Python process for every task. It will be very useful if there is large broadcast, then the broadcast will not be needed to transferred from JVM to Python worker for every task. @@ -1294,7 +1294,7 @@ Apart from these, the following properties are also available, and may be useful spark.files.openCostInBytes 4194304 (4 MB) - The estimated cost to open a file, measured by the number of bytes could be scanned in the same + The estimated cost to open a file, measured by the number of bytes could be scanned at the same time. This is used when putting multiple files into a partition. It is better to over estimate, then the partitions with small files will be faster than partitions with bigger files. @@ -1855,8 +1855,8 @@ Apart from these, the following properties are also available, and may be useful spark.user.groups.mapping org.apache.spark.security.ShellBasedGroupsMappingProvider - The list of groups for a user are determined by a group mapping service defined by the trait - org.apache.spark.security.GroupMappingServiceProvider which can configured by this property. + The list of groups for a user is determined by a group mapping service defined by the trait + org.apache.spark.security.GroupMappingServiceProvider which can be configured by this property. A default unix shell based implementation is provided org.apache.spark.security.ShellBasedGroupsMappingProvider which can be specified to resolve a list of groups for a user. Note: This implementation supports only a Unix/Linux based environment. Windows environment is @@ -2465,7 +2465,7 @@ should be included on Spark's classpath: The location of these configuration files varies across Hadoop versions, but a common location is inside of `/etc/hadoop/conf`. Some tools create -configurations on-the-fly, but offer a mechanisms to download copies of them. +configurations on-the-fly, but offer a mechanism to download copies of them. To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/conf/spark-env.sh` to a location containing the configuration files. diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 46225dc598da8..5c97a248df4bc 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -708,7 +708,7 @@ messages remaining. > messaging function. These constraints allow additional optimization within GraphX. The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch* -of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodcally +of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodically checkpoint graph and messages by setting "spark.graphx.pregel.checkpointInterval" to a positive number, say 10. And set checkpoint directory as well using SparkContext.setCheckpointDir(directory: String)): @@ -928,7 +928,7 @@ switch to 2D-partitioning or other heuristics included in GraphX.

-Once the edges have be partitioned the key challenge to efficient graph-parallel computation is +Once the edges have been partitioned the key challenge to efficient graph-parallel computation is efficiently joining vertex attributes with the edges. Because real-world graphs typically have more edges than vertices, we move vertex attributes to the edges. Because not all partitions will contain edges adjacent to all vertices we internally maintain a routing table which identifies where diff --git a/docs/monitoring.md b/docs/monitoring.md index f8d3ce91a0691..6f6cfc1288d73 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -118,7 +118,7 @@ The history server can be configured as follows: The number of applications to retain UI data for in the cache. If this cap is exceeded, then the oldest applications will be removed from the cache. If an application is not in the cache, - it will have to be loaded from disk if its accessed from the UI. + it will have to be loaded from disk if it is accessed from the UI. @@ -407,7 +407,7 @@ can be identified by their `[attempt-id]`. In the API listed below, when running -The number of jobs and stages which can retrieved is constrained by the same retention +The number of jobs and stages which can be retrieved is constrained by the same retention mechanism of the standalone Spark UI; `"spark.ui.retainedJobs"` defines the threshold value triggering garbage collection on jobs, and `spark.ui.retainedStages` that for stages. Note that the garbage collection takes place on playback: it is possible to retrieve @@ -422,10 +422,10 @@ These endpoints have been strongly versioned to make it easier to develop applic * Individual fields will never be removed for any given endpoint * New endpoints may be added * New fields may be added to existing endpoints -* New versions of the api may be added in the future at a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. +* New versions of the api may be added in the future as a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. * Api versions may be dropped, but only after at least one minor release of co-existing with a new api version. -Note that even when examining the UI of a running applications, the `applications/[app-id]` portion is +Note that even when examining the UI of running applications, the `applications/[app-id]` portion is still required, though there is only one application available. Eg. to see the list of jobs for the running app, you would go to `http://localhost:4040/api/v1/applications/[app-id]/jobs`. This is to keep the paths consistent in both modes. diff --git a/docs/quick-start.md b/docs/quick-start.md index 200b97230e866..07c520cbee6be 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -67,7 +67,7 @@ res3: Long = 15 ./bin/pyspark -Or if PySpark is installed with pip in your current enviroment: +Or if PySpark is installed with pip in your current environment: pyspark @@ -156,7 +156,7 @@ One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can i >>> wordCounts = textFile.select(explode(split(textFile.value, "\s+")).alias("word")).groupBy("word").count() {% endhighlight %} -Here, we use the `explode` function in `select`, to transfrom a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of 2 columns: "word" and "count". To collect the word counts in our shell, we can call `collect`: +Here, we use the `explode` function in `select`, to transform a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of 2 columns: "word" and "count". To collect the word counts in our shell, we can call `collect`: {% highlight python %} >>> wordCounts.collect() @@ -422,7 +422,7 @@ $ YOUR_SPARK_HOME/bin/spark-submit \ Lines with a: 46, Lines with b: 23 {% endhighlight %} -If you have PySpark pip installed into your enviroment (e.g., `pip install pyspark`), you can run your application with the regular Python interpreter or use the provided 'spark-submit' as you prefer. +If you have PySpark pip installed into your environment (e.g., `pip install pyspark`), you can run your application with the regular Python interpreter or use the provided 'spark-submit' as you prefer. {% highlight bash %} # Use the Python interpreter to run your application diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 382cbfd5301b0..2bb5ecf1b8509 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -154,7 +154,7 @@ can find the results of the driver from the Mesos Web UI. To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. -By setting the Mesos proxy config property (requires mesos version >= 1.4), `--conf spark.mesos.proxy.baseURL=http://localhost:5050` when launching the dispacther, the mesos sandbox URI for each driver is added to the mesos dispatcher UI. +By setting the Mesos proxy config property (requires mesos version >= 1.4), `--conf spark.mesos.proxy.baseURL=http://localhost:5050` when launching the dispatcher, the mesos sandbox URI for each driver is added to the mesos dispatcher UI. If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` not yet supports multiple instances for HA. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e7edec5990363..e4f5a0c659e66 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -445,7 +445,7 @@ To use a custom metrics.properties for the application master and executors, upd yarn.nodemanager.log-aggregation.roll-monitoring-interval-seconds should be configured in yarn-site.xml. This feature can only be used with Hadoop 2.6.4+. The Spark log4j appender needs be changed to use - FileAppender or another appender that can handle the files being removed while its running. Based + FileAppender or another appender that can handle the files being removed while it is running. Based on the file name configured in the log4j configuration (like spark.log), the user should set the regex (spark*) to include all the log files that need to be aggregated. diff --git a/docs/security.md b/docs/security.md index 15aadf07cf873..bebc28ddbfb0e 100644 --- a/docs/security.md +++ b/docs/security.md @@ -62,7 +62,7 @@ component-specific configuration namespaces used to override the default setting -The full breakdown of available SSL options can be found on the [configuration page](configuration.html). +The full breakdown of available SSL options can be found on the [configuration page](configuration.html). SSL must be configured on each node and configured for each component involved in communication using the particular protocol. ### YARN mode diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3e2e48a0ef249..502c0a8c37e01 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1253,7 +1253,7 @@ provide a ClassTag. (Note that this is different than the Spark SQL JDBC server, which allows other applications to run queries using Spark SQL). -To get started you will need to include the JDBC driver for you particular database on the +To get started you will need to include the JDBC driver for your particular database on the spark classpath. For example, to connect to postgres from the Spark Shell you would run the following command: @@ -1793,7 +1793,7 @@ options. - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. - - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant to SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes + - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant with SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. @@ -1821,7 +1821,7 @@ options. transformations (e.g., `map`, `filter`, and `groupByKey`) and untyped transformations (e.g., `select` and `groupBy`) are available on the Dataset class. Since compile-time type-safety in Python and R is not a language feature, the concept of Dataset does not apply to these languages’ - APIs. Instead, `DataFrame` remains the primary programing abstraction, which is analogous to the + APIs. Instead, `DataFrame` remains the primary programming abstraction, which is analogous to the single-node data frame notion in these languages. - Dataset and DataFrame API `unionAll` has been deprecated and replaced by `union` @@ -1997,7 +1997,7 @@ Java and Python users will need to update their code. Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users -of either language should use `SQLContext` and `DataFrame`. In general theses classes try to +of either language should use `SQLContext` and `DataFrame`. In general these classes try to use types that are usable from both languages (i.e. `Array` instead of language specific collections). In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading is used instead. diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md index f4bb2353e3c49..1dd54719b21aa 100644 --- a/docs/storage-openstack-swift.md +++ b/docs/storage-openstack-swift.md @@ -42,7 +42,7 @@ Create core-site.xml and place it inside Spark's conf The main category of parameters that should be configured are the authentication parameters required by Keystone. -The following table contains a list of Keystone mandatory parameters. PROVIDER can be +The following table contains a list of Keystone mandatory parameters. PROVIDER can be any (alphanumeric) name. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 868acc41226dc..ffda36d64a770 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -74,7 +74,7 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3 // Create a local StreamingContext with two working thread and batch interval of 1 second. -// The master requires 2 cores to prevent from a starvation scenario. +// The master requires 2 cores to prevent a starvation scenario. val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") val ssc = new StreamingContext(conf, Seconds(1)) @@ -172,7 +172,7 @@ each line will be split into multiple words and the stream of words is represent `words` DStream. Note that we defined the transformation using a [FlatMapFunction](api/scala/index.html#org.apache.spark.api.java.function.FlatMapFunction) object. As we will discover along the way, there are a number of such convenience classes in the Java API -that help define DStream transformations. +that help defines DStream transformations. Next, we want to count these words. diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 461c29ce1ba89..5647ec6bc5797 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -125,7 +125,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") ### Creating a Kafka Source for Batch Queries If you have a use case that is better suited to batch processing, -you can create an Dataset/DataFrame for a defined range of offsets. +you can create a Dataset/DataFrame for a defined range of offsets.
@@ -597,7 +597,7 @@ Note that the following Kafka params cannot be set and the Kafka source or sink - **key.serializer**: Keys are always serialized with ByteArraySerializer or StringSerializer. Use DataFrame operations to explicitly serialize the keys into either strings or byte arrays. - **value.serializer**: values are always serialized with ByteArraySerializer or StringSerializer. Use -DataFrame oeprations to explicitly serialize the values into either strings or byte arrays. +DataFrame operations to explicitly serialize the values into either strings or byte arrays. - **enable.auto.commit**: Kafka source doesn't commit any offset. - **interceptor.classes**: Kafka source always read keys and values as byte arrays. It's not safe to use ConsumerInterceptor as it may break the query. diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 2ddba2f0d942e..2ef5d3168a87b 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -10,7 +10,7 @@ title: Structured Streaming Programming Guide # Overview Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* -Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able choose the mode based on your application requirements. +Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able to choose the mode based on your application requirements. In this guide, we are going to walk you through the programming model and the APIs. We are going to explain the concepts mostly using the default micro-batch processing model, and then [later](#continuous-processing-experimental) discuss Continuous Processing model. First, let's start with a simple example of a Structured Streaming query - a streaming word count. @@ -1121,7 +1121,7 @@ Let’s discuss the different types of supported stream-stream joins and how to ##### Inner Joins with optional Watermarking Inner joins on any kind of columns along with any kind of join conditions are supported. However, as the stream runs, the size of streaming state will keep growing indefinitely as -*all* past input must be saved as the any new input can match with any input from the past. +*all* past input must be saved as any new input can match with any input from the past. To avoid unbounded state, you have to define additional join conditions such that indefinitely old inputs cannot match with future inputs and therefore can be cleared from the state. In other words, you will have to do the following additional steps in the join. @@ -1839,7 +1839,7 @@ aggDF \ .format("console") \ .start() -# Have all the aggregates in an in memory table. The query name will be the table name +# Have all the aggregates in an in-memory table. The query name will be the table name aggDF \ .writeStream \ .queryName("aggregates") \ diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 0473ab73a5e6c..a3643bf0838a1 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -5,7 +5,7 @@ title: Submitting Applications The `spark-submit` script in Spark's `bin` directory is used to launch applications on a cluster. It can use all of Spark's supported [cluster managers](cluster-overview.html#cluster-manager-types) -through a uniform interface so you don't have to configure your application specially for each one. +through a uniform interface so you don't have to configure your application especially for each one. # Bundling Your Application's Dependencies If your code depends on other projects, you will need to package them alongside @@ -58,7 +58,7 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between -the drivers and the executors. Currently, standalone mode does not support cluster mode for Python +the drivers and the executors. Currently, the standalone mode does not support cluster mode for Python applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, @@ -68,7 +68,7 @@ There are a few options available that are specific to the [cluster manager](cluster-overview.html#cluster-manager-types) that is being used. For example, with a [Spark standalone cluster](spark-standalone.html) with `cluster` deploy mode, you can also specify `--supervise` to make sure that the driver is automatically restarted if it -fails with non-zero exit code. To enumerate all such options available to `spark-submit`, +fails with a non-zero exit code. To enumerate all such options available to `spark-submit`, run it with `--help`. Here are a few examples of common options: {% highlight bash %} @@ -192,7 +192,7 @@ debugging information by running `spark-submit` with the `--verbose` option. # Advanced Dependency Management When using `spark-submit`, the application jar along with any jars included with the `--jars` option -will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included on the driver and executor classpaths. Directory expansion does not work with `--jars`. +will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included in the driver and executor classpaths. Directory expansion does not work with `--jars`. Spark uses the following URL scheme to allow different strategies for disseminating jars: From 00d169156d4b1c91d2bcfd788b254b03c509dc41 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Sat, 20 Jan 2018 14:49:49 -0800 Subject: [PATCH 158/774] [SPARK-21786][SQL] The 'spark.sql.parquet.compression.codec' and 'spark.sql.orc.compression.codec' configuration doesn't take effect on hive table writing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [SPARK-21786][SQL] The 'spark.sql.parquet.compression.codec' and 'spark.sql.orc.compression.codec' configuration doesn't take effect on hive table writing What changes were proposed in this pull request? Pass ‘spark.sql.parquet.compression.codec’ value to ‘parquet.compression’. Pass ‘spark.sql.orc.compression.codec’ value to ‘orc.compress’. How was this patch tested? Add test. Note: This is the same issue mentioned in #19218 . That branch was deleted mistakenly, so make a new pr instead. gatorsmile maropu dongjoon-hyun discipleforteen Author: fjh100456 Author: Takeshi Yamamuro Author: Wenchen Fan Author: gatorsmile Author: Yinan Li Author: Marcelo Vanzin Author: Juliusz Sompolski Author: Felix Cheung Author: jerryshao Author: Li Jin Author: Gera Shegalov Author: chetkhatri Author: Joseph K. Bradley Author: Bago Amirbekian Author: Xianjin YE Author: Bruce Robbins Author: zuotingbing Author: Kent Yao Author: hyukjinkwon Author: Adrian Ionescu Closes #20087 from fjh100456/HiveTableWriting. --- .../datasources/orc/OrcOptions.scala | 2 + .../datasources/parquet/ParquetOptions.scala | 6 +- .../sql/hive/execution/HiveOptions.scala | 22 ++ .../sql/hive/execution/SaveAsHiveFile.scala | 20 +- .../sql/hive/CompressionCodecSuite.scala | 353 ++++++++++++++++++ 5 files changed, 397 insertions(+), 6 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index c866dd834a525..0ad3862f6cf01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -67,4 +67,6 @@ object OrcOptions { "snappy" -> "SNAPPY", "zlib" -> "ZLIB", "lzo" -> "LZO") + + def getORCCompressionCodecName(name: String): String = shortOrcCompressionCodecNames(name) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index ef67ea7d17cea..f36a89a4c3c5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf /** * Options for the Parquet data source. */ -private[parquet] class ParquetOptions( +class ParquetOptions( @transient private val parameters: CaseInsensitiveMap[String], @transient private val sqlConf: SQLConf) extends Serializable { @@ -82,4 +82,8 @@ object ParquetOptions { "snappy" -> CompressionCodecName.SNAPPY, "gzip" -> CompressionCodecName.GZIP, "lzo" -> CompressionCodecName.LZO) + + def getParquetCompressionCodecName(name: String): String = { + shortParquetCompressionCodecNames(name).name() + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index 5c515515b9b9c..802ddafdbee4d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -19,7 +19,16 @@ package org.apache.spark.sql.hive.execution import java.util.Locale +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.orc.OrcConf.COMPRESS +import org.apache.parquet.hadoop.ParquetOutputFormat + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.internal.SQLConf /** * Options for the Hive data source. Note that rule `DetermineHiveSerde` will extract Hive @@ -102,4 +111,17 @@ object HiveOptions { "collectionDelim" -> "colelction.delim", "mapkeyDelim" -> "mapkey.delim", "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase(Locale.ROOT) -> v } + + def getHiveWriteCompression(tableInfo: TableDesc, sqlConf: SQLConf): Option[(String, String)] = { + val tableProps = tableInfo.getProperties.asScala.toMap + tableInfo.getOutputFileFormatClassName.toLowerCase(Locale.ROOT) match { + case formatName if formatName.endsWith("parquetoutputformat") => + val compressionCodec = new ParquetOptions(tableProps, sqlConf).compressionCodecClassName + Option((ParquetOutputFormat.COMPRESSION, compressionCodec)) + case formatName if formatName.endsWith("orcoutputformat") => + val compressionCodec = new OrcOptions(tableProps, sqlConf).compressionCodec + Option((COMPRESS.getAttribute, compressionCodec)) + case _ => None + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 9a6607f2f2c6c..e484356906e87 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -55,18 +55,28 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty, partitionAttributes: Seq[Attribute] = Nil): Set[String] = { - val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean + val isCompressed = + fileSinkConf.getTableInfo.getOutputFileFormatClassName.toLowerCase(Locale.ROOT) match { + case formatName if formatName.endsWith("orcoutputformat") => + // For ORC,"mapreduce.output.fileoutputformat.compress", + // "mapreduce.output.fileoutputformat.compress.codec", and + // "mapreduce.output.fileoutputformat.compress.type" + // have no impact because it uses table properties to store compression information. + false + case _ => hadoopConf.get("hive.exec.compress.output", "false").toBoolean + } + if (isCompressed) { - // Please note that isCompressed, "mapreduce.output.fileoutputformat.compress", - // "mapreduce.output.fileoutputformat.compress.codec", and - // "mapreduce.output.fileoutputformat.compress.type" - // have no impact on ORC because it uses table properties to store compression information. hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") fileSinkConf.setCompressed(true) fileSinkConf.setCompressCodec(hadoopConf .get("mapreduce.output.fileoutputformat.compress.codec")) fileSinkConf.setCompressType(hadoopConf .get("mapreduce.output.fileoutputformat.compress.type")) + } else { + // Set compression by priority + HiveOptions.getHiveWriteCompression(fileSinkConf.getTableInfo, sparkSession.sessionState.conf) + .foreach { case (compression, codec) => hadoopConf.set(compression, codec) } } val committer = FileCommitProtocol.instantiate( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala new file mode 100644 index 0000000000000..d10a6f25c64fc --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.orc.OrcConf.COMPRESS +import org.apache.parquet.hadoop.ParquetOutputFormat +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetTest} +import org.apache.spark.sql.hive.orc.OrcFileOperator +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf + +class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with BeforeAndAfterAll { + import spark.implicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + (0 until maxRecordNum).toDF("a").createOrReplaceTempView("table_source") + } + + override def afterAll(): Unit = { + try { + spark.catalog.dropTempView("table_source") + } finally { + super.afterAll() + } + } + + private val maxRecordNum = 50 + + private def getConvertMetastoreConfName(format: String): String = format.toLowerCase match { + case "parquet" => HiveUtils.CONVERT_METASTORE_PARQUET.key + case "orc" => HiveUtils.CONVERT_METASTORE_ORC.key + } + + private def getSparkCompressionConfName(format: String): String = format.toLowerCase match { + case "parquet" => SQLConf.PARQUET_COMPRESSION.key + case "orc" => SQLConf.ORC_COMPRESSION.key + } + + private def getHiveCompressPropName(format: String): String = format.toLowerCase match { + case "parquet" => ParquetOutputFormat.COMPRESSION + case "orc" => COMPRESS.getAttribute + } + + private def normalizeCodecName(format: String, name: String): String = { + format.toLowerCase match { + case "parquet" => ParquetOptions.getParquetCompressionCodecName(name) + case "orc" => OrcOptions.getORCCompressionCodecName(name) + } + } + + private def getTableCompressionCodec(path: String, format: String): Seq[String] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = format.toLowerCase match { + case "parquet" => for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + case "orc" => new File(path).listFiles().filter { file => + file.isFile && !file.getName.endsWith(".crc") && file.getName != "_SUCCESS" + }.map { orcFile => + OrcFileOperator.getFileReader(orcFile.toPath.toString).get.getCompression.toString + }.toSeq + } + codecs.distinct + } + + private def createTable( + rootDir: File, + tableName: String, + isPartitioned: Boolean, + format: String, + compressionCodec: Option[String]): Unit = { + val tblProperties = compressionCodec match { + case Some(prop) => s"TBLPROPERTIES('${getHiveCompressPropName(format)}'='$prop')" + case _ => "" + } + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p string)" else "" + sql( + s""" + |CREATE TABLE $tableName(a int) + |$partitionCreate + |STORED AS $format + |LOCATION '${rootDir.toURI.toString.stripSuffix("/")}/$tableName' + |$tblProperties + """.stripMargin) + } + + private def writeDataToTable( + tableName: String, + partitionValue: Option[String]): Unit = { + val partitionInsert = partitionValue.map(p => s"partition (p='$p')").mkString + sql( + s""" + |INSERT INTO TABLE $tableName + |$partitionInsert + |SELECT * FROM table_source + """.stripMargin) + } + + private def writeDateToTableUsingCTAS( + rootDir: File, + tableName: String, + partitionValue: Option[String], + format: String, + compressionCodec: Option[String]): Unit = { + val partitionCreate = partitionValue.map(p => s"PARTITIONED BY (p)").mkString + val compressionOption = compressionCodec.map { codec => + s",'${getHiveCompressPropName(format)}'='$codec'" + }.mkString + val partitionSelect = partitionValue.map(p => s",'$p' AS p").mkString + sql( + s""" + |CREATE TABLE $tableName + |USING $format + |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName' $compressionOption) + |$partitionCreate + |AS SELECT * $partitionSelect FROM table_source + """.stripMargin) + } + + private def getPreparedTablePath( + tmpDir: File, + tableName: String, + isPartitioned: Boolean, + format: String, + compressionCodec: Option[String], + usingCTAS: Boolean): String = { + val partitionValue = if (isPartitioned) Some("test") else None + if (usingCTAS) { + writeDateToTableUsingCTAS(tmpDir, tableName, partitionValue, format, compressionCodec) + } else { + createTable(tmpDir, tableName, isPartitioned, format, compressionCodec) + writeDataToTable(tableName, partitionValue) + } + getTablePartitionPath(tmpDir, tableName, partitionValue) + } + + private def getTableSize(path: String): Long = { + val dir = new File(path) + val files = dir.listFiles().filter(_.getName.startsWith("part-")) + files.map(_.length()).sum + } + + private def getTablePartitionPath( + dir: File, + tableName: String, + partitionValue: Option[String]) = { + val partitionPath = partitionValue.map(p => s"p=$p").mkString + s"${dir.getPath.stripSuffix("/")}/$tableName/$partitionPath" + } + + private def getUncompressedDataSizeByFormat( + format: String, isPartitioned: Boolean, usingCTAS: Boolean): Long = { + var totalSize = 0L + val tableName = s"tbl_$format" + val codecName = normalizeCodecName(format, "uncompressed") + withSQLConf(getSparkCompressionConfName(format) -> codecName) { + withTempDir { tmpDir => + withTable(tableName) { + val compressionCodec = Option(codecName) + val path = getPreparedTablePath( + tmpDir, tableName, isPartitioned, format, compressionCodec, usingCTAS) + totalSize = getTableSize(path) + } + } + } + assert(totalSize > 0L) + totalSize + } + + private def checkCompressionCodecForTable( + format: String, + isPartitioned: Boolean, + compressionCodec: Option[String], + usingCTAS: Boolean) + (assertion: (String, Long) => Unit): Unit = { + val tableName = + if (usingCTAS) s"tbl_$format$isPartitioned" else s"tbl_$format${isPartitioned}_CAST" + withTempDir { tmpDir => + withTable(tableName) { + val path = getPreparedTablePath( + tmpDir, tableName, isPartitioned, format, compressionCodec, usingCTAS) + val relCompressionCodecs = getTableCompressionCodec(path, format) + assert(relCompressionCodecs.length == 1) + val tableSize = getTableSize(path) + assertion(relCompressionCodecs.head, tableSize) + } + } + } + + private def checkTableCompressionCodecForCodecs( + format: String, + isPartitioned: Boolean, + convertMetastore: Boolean, + usingCTAS: Boolean, + compressionCodecs: List[String], + tableCompressionCodecs: List[String]) + (assertionCompressionCodec: (Option[String], String, String, Long) => Unit): Unit = { + withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString) { + tableCompressionCodecs.foreach { tableCompression => + compressionCodecs.foreach { sessionCompressionCodec => + withSQLConf(getSparkCompressionConfName(format) -> sessionCompressionCodec) { + // 'tableCompression = null' means no table-level compression + val compression = Option(tableCompression) + checkCompressionCodecForTable(format, isPartitioned, compression, usingCTAS) { + case (realCompressionCodec, tableSize) => + assertionCompressionCodec( + compression, sessionCompressionCodec, realCompressionCodec, tableSize) + } + } + } + } + } + } + + // When the amount of data is small, compressed data size may be larger than uncompressed one, + // so we just check the difference when compressionCodec is not NONE or UNCOMPRESSED. + private def checkTableSize( + format: String, + compressionCodec: String, + isPartitioned: Boolean, + convertMetastore: Boolean, + usingCTAS: Boolean, + tableSize: Long): Boolean = { + val uncompressedSize = getUncompressedDataSizeByFormat(format, isPartitioned, usingCTAS) + compressionCodec match { + case "UNCOMPRESSED" if format == "parquet" => tableSize == uncompressedSize + case "NONE" if format == "orc" => tableSize == uncompressedSize + case _ => tableSize != uncompressedSize + } + } + + def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. + Seq(false).foreach { usingCTAS => + checkTableCompressionCodecForCodecs( + format, + isPartitioned, + convertMetastore, + usingCTAS, + compressionCodecs = compressCodecs, + tableCompressionCodecs = compressCodecs) { + case (tableCodec, sessionCodec, realCodec, tableSize) => + // For non-partitioned table and when convertMetastore is true, Expect session-level + // take effect, and in other cases expect table-level take effect + // TODO: It should always be table-level taking effect when the bug(SPARK-22926) + // is fixed + val expectCodec = + if (convertMetastore && !isPartitioned) sessionCodec else tableCodec.get + assert(expectCodec == realCodec) + assert(checkTableSize( + format, expectCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) + } + } + } + } + } + + def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. + Seq(false).foreach { usingCTAS => + checkTableCompressionCodecForCodecs( + format, + isPartitioned, + convertMetastore, + usingCTAS, + compressionCodecs = compressCodecs, + tableCompressionCodecs = List(null)) { + case (tableCodec, sessionCodec, realCodec, tableSize) => + // Always expect session-level take effect + assert(sessionCodec == realCodec) + assert(checkTableSize( + format, sessionCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) + } + } + } + } + } + + test("both table-level and session-level compression are set") { + checkForTableWithCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + } + + test("table-level compression is not set but session-level compressions is set ") { + checkForTableWithoutCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithoutCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + } + + def checkTableWriteWithCompressionCodecs(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + withTempDir { tmpDir => + val tableName = s"tbl_$format$isPartitioned" + createTable(tmpDir, tableName, isPartitioned, format, None) + withTable(tableName) { + compressCodecs.foreach { compressionCodec => + val partitionValue = if (isPartitioned) Some(compressionCodec) else None + withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString, + getSparkCompressionConfName(format) -> compressionCodec + ) { writeDataToTable(tableName, partitionValue) } + } + val tablePath = getTablePartitionPath(tmpDir, tableName, None) + val realCompressionCodecs = + if (isPartitioned) compressCodecs.flatMap { codec => + getTableCompressionCodec(s"$tablePath/p=$codec", format) + } else { + getTableCompressionCodec(tablePath, format) + } + + assert(realCompressionCodecs.distinct.sorted == compressCodecs.sorted) + val recordsNum = sql(s"SELECT * from $tableName").count() + assert(recordsNum == maxRecordNum * compressCodecs.length) + } + } + } + } + } + + test("test table containing mixed compression codec") { + checkTableWriteWithCompressionCodecs("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkTableWriteWithCompressionCodecs("orc", List("NONE", "SNAPPY", "ZLIB")) + } +} From 121dc96f088a7b157d5b2cffb626b0e22d1fc052 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 20 Jan 2018 22:39:49 -0800 Subject: [PATCH 159/774] [SPARK-23087][SQL] CheckCartesianProduct too restrictive when condition is false/null ## What changes were proposed in this pull request? CheckCartesianProduct raises an AnalysisException also when the join condition is always false/null. In this case, we shouldn't raise it, since the result will not be a cartesian product. ## How was this patch tested? added UT Author: Marco Gaido Closes #20333 from mgaido91/SPARK-23087. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 10 +++++++--- .../org/apache/spark/sql/DataFrameJoinSuite.scala | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c794ba8619322..0f9daa5f04c76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1108,15 +1108,19 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { */ def isCartesianProduct(join: Join): Boolean = { val conditions = join.condition.map(splitConjunctivePredicates).getOrElse(Nil) - !conditions.map(_.references).exists(refs => refs.exists(join.left.outputSet.contains) - && refs.exists(join.right.outputSet.contains)) + + conditions match { + case Seq(Literal.FalseLiteral) | Seq(Literal(null, BooleanType)) => false + case _ => !conditions.map(_.references).exists(refs => + refs.exists(join.left.outputSet.contains) && refs.exists(join.right.outputSet.contains)) + } } def apply(plan: LogicalPlan): LogicalPlan = if (SQLConf.get.crossJoinEnabled) { plan } else plan transform { - case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, condition) + case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _) if isCartesianProduct(j) => throw new AnalysisException( s"""Detected cartesian product for ${j.joinType.sql} join between logical plans diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index aef0d7f3e425b..1656f290ee19c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -274,4 +274,18 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { checkAnswer(innerJoin, Row(1) :: Nil) } + test("SPARK-23087: don't throw Analysis Exception in CheckCartesianProduct when join condition " + + "is false or null") { + val df = spark.range(10) + val dfNull = spark.range(10).select(lit(null).as("b")) + val planNull = df.join(dfNull, $"id" === $"b", "left").queryExecution.analyzed + + spark.sessionState.executePlan(planNull).optimizedPlan + + val dfOne = df.select(lit(1).as("a")) + val dfTwo = spark.range(10).select(lit(2).as("b")) + val planFalse = dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.analyzed + + spark.sessionState.executePlan(planFalse).optimizedPlan + } } From 4f43d27c9e97be8605b120b3d7c11c7c61e3ca6f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 21 Jan 2018 08:51:12 -0600 Subject: [PATCH 160/774] [SPARK-22119][ML] Add cosine distance to KMeans ## What changes were proposed in this pull request? Currently, KMeans assumes the only possible distance measure to be used is the Euclidean. This PR aims to add the cosine distance support to the KMeans algorithm. ## How was this patch tested? existing and added UTs. Author: Marco Gaido Author: Marco Gaido Closes #19340 from mgaido91/SPARK-22119. --- .../apache/spark/ml/clustering/KMeans.scala | 22 +- .../mllib/clustering/BisectingKMeans.scala | 11 +- .../spark/mllib/clustering/KMeans.scala | 216 ++++++++++++++---- .../spark/mllib/clustering/KMeansModel.scala | 74 +++++- .../spark/mllib/clustering/LocalKMeans.scala | 10 +- .../spark/ml/clustering/KMeansSuite.scala | 42 +++- .../spark/mllib/clustering/KMeansSuite.scala | 6 +- 7 files changed, 315 insertions(+), 66 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index f2af7fe082b41..c8145de564cbe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD @@ -71,6 +71,15 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitMode: String = $(initMode) + @Since("2.4.0") + final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " + + "Supported options: 'euclidean' and 'cosine'.", + (value: String) => MLlibKMeans.validateDistanceMeasure(value)) + + /** @group expertGetParam */ + @Since("2.4.0") + def getDistanceMeasure: String = $(distanceMeasure) + /** * Param for the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 2 is almost always enough. Must be > 0. Default: 2. @@ -260,7 +269,8 @@ class KMeans @Since("1.5.0") ( maxIter -> 20, initMode -> MLlibKMeans.K_MEANS_PARALLEL, initSteps -> 2, - tol -> 1e-4) + tol -> 1e-4, + distanceMeasure -> DistanceMeasure.EUCLIDEAN) @Since("1.5.0") override def copy(extra: ParamMap): KMeans = defaultCopy(extra) @@ -284,6 +294,10 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") def setInitMode(value: String): this.type = set(initMode, value) + /** @group expertSetParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + /** @group expertSetParam */ @Since("1.5.0") def setInitSteps(value: Int): this.type = set(initSteps, value) @@ -314,7 +328,8 @@ class KMeans @Since("1.5.0") ( } val instr = Instrumentation.create(this, instances) - instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol) + instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, + maxIter, seed, tol) val algo = new MLlibKMeans() .setK($(k)) .setInitializationMode($(initMode)) @@ -322,6 +337,7 @@ class KMeans @Since("1.5.0") ( .setMaxIterations($(maxIter)) .setSeed($(seed)) .setEpsilon($(tol)) + .setDistanceMeasure($(distanceMeasure)) val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 9b9c70cfe5109..2221f4c0edc17 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -350,7 +350,7 @@ private object BisectingKMeans extends Serializable { val newClusterChildren = children.filter(newClusterCenters.contains(_)) if (newClusterChildren.nonEmpty) { val selected = newClusterChildren.minBy { child => - KMeans.fastSquaredDistance(newClusterCenters(child), v) + EuclideanDistanceMeasure.fastSquaredDistance(newClusterCenters(child), v) } (selected, v) } else { @@ -387,7 +387,7 @@ private object BisectingKMeans extends Serializable { val rightIndex = rightChildIndex(rawIndex) val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_)) val height = math.sqrt(indexes.map { childIndex => - KMeans.fastSquaredDistance(center, clusters(childIndex).center) + EuclideanDistanceMeasure.fastSquaredDistance(center, clusters(childIndex).center) }.max) val children = indexes.map(buildSubTree(_)).toArray new ClusteringTreeNode(index, size, center, cost, height, children) @@ -457,7 +457,7 @@ private[clustering] class ClusteringTreeNode private[clustering] ( this :: Nil } else { val selected = children.minBy { child => - KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm) + EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm) } selected :: selected.predictPath(pointWithNorm) } @@ -475,7 +475,8 @@ private[clustering] class ClusteringTreeNode private[clustering] ( * Predicts the cluster index and the cost of the input point. */ private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = { - predict(pointWithNorm, KMeans.fastSquaredDistance(centerWithNorm, pointWithNorm)) + predict(pointWithNorm, + EuclideanDistanceMeasure.fastSquaredDistance(centerWithNorm, pointWithNorm)) } /** @@ -490,7 +491,7 @@ private[clustering] class ClusteringTreeNode private[clustering] ( (index, cost) } else { val (selectedChild, minCost) = children.map { child => - (child, KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) + (child, EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) }.minBy(_._2) selectedChild.predict(pointWithNorm, minCost) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 49043b5acb807..607145cb59fba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.clustering.{KMeans => NewKMeans} import org.apache.spark.ml.util.Instrumentation import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} +import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -46,14 +46,23 @@ class KMeans private ( private var initializationMode: String, private var initializationSteps: Int, private var epsilon: Double, - private var seed: Long) extends Serializable with Logging { + private var seed: Long, + private var distanceMeasure: String) extends Serializable with Logging { + + @Since("0.8.0") + private def this(k: Int, maxIterations: Int, initializationMode: String, initializationSteps: Int, + epsilon: Double, seed: Long) = + this(k, maxIterations, initializationMode, initializationSteps, + epsilon, seed, DistanceMeasure.EUCLIDEAN) /** * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, - * initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random}. + * initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random, + * distanceMeasure: "euclidean"}. */ @Since("0.8.0") - def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) + def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong(), + DistanceMeasure.EUCLIDEAN) /** * Number of clusters to create (k). @@ -184,6 +193,22 @@ class KMeans private ( this } + /** + * The distance suite used by the algorithm. + */ + @Since("2.4.0") + def getDistanceMeasure: String = distanceMeasure + + /** + * Set the distance suite used by the algorithm. + */ + @Since("2.4.0") + def setDistanceMeasure(distanceMeasure: String): this.type = { + KMeans.validateDistanceMeasure(distanceMeasure) + this.distanceMeasure = distanceMeasure + this + } + // Initial cluster centers can be provided as a KMeansModel object rather than using the // random or k-means|| initializationMode private var initialModel: Option[KMeansModel] = None @@ -246,6 +271,8 @@ class KMeans private ( val initStartTime = System.nanoTime() + val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure) + val centers = initialModel match { case Some(kMeansCenters) => kMeansCenters.clusterCenters.map(new VectorWithNorm(_)) @@ -253,7 +280,7 @@ class KMeans private ( if (initializationMode == KMeans.RANDOM) { initRandom(data) } else { - initKMeansParallel(data) + initKMeansParallel(data, distanceMeasureInstance) } } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 @@ -281,7 +308,7 @@ class KMeans private ( val counts = Array.fill(thisCenters.length)(0L) points.foreach { point => - val (bestCenter, cost) = KMeans.findClosest(thisCenters, point) + val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point) costAccum.add(cost) val sum = sums(bestCenter) axpy(1.0, point.vector, sum) @@ -302,7 +329,8 @@ class KMeans private ( // Update the cluster centers and costs converged = true newCenters.foreach { case (j, newCenter) => - if (converged && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) { + if (converged && + !distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) { converged = false } centers(j) = newCenter @@ -323,7 +351,7 @@ class KMeans private ( logInfo(s"The cost is $cost.") - new KMeansModel(centers.map(_.vector)) + new KMeansModel(centers.map(_.vector), distanceMeasure) } /** @@ -345,7 +373,8 @@ class KMeans private ( * * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf. */ - private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { + private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm], + distanceMeasureInstance: DistanceMeasure): Array[VectorWithNorm] = { // Initialize empty centers and point costs. var costs = data.map(_ => Double.PositiveInfinity) @@ -369,7 +398,7 @@ class KMeans private ( bcNewCentersList += bcNewCenters val preCosts = costs costs = data.zip(preCosts).map { case (point, cost) => - math.min(KMeans.pointCost(bcNewCenters.value, point), cost) + math.min(distanceMeasureInstance.pointCost(bcNewCenters.value, point), cost) }.persist(StorageLevel.MEMORY_AND_DISK) val sumCosts = costs.sum() @@ -397,7 +426,9 @@ class KMeans private ( // candidate by the number of points in the dataset mapping to it and run a local k-means++ // on the weighted centers to pick k of them val bcCenters = data.context.broadcast(distinctCenters) - val countMap = data.map(KMeans.findClosest(bcCenters.value, _)._1).countByValue() + val countMap = data + .map(distanceMeasureInstance.findClosest(bcCenters.value, _)._1) + .countByValue() bcCenters.destroy(blocking = false) @@ -546,10 +577,110 @@ object KMeans { .run(data) } + private[spark] def validateInitMode(initMode: String): Boolean = { + initMode match { + case KMeans.RANDOM => true + case KMeans.K_MEANS_PARALLEL => true + case _ => false + } + } + + private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = { + distanceMeasure match { + case DistanceMeasure.EUCLIDEAN => true + case DistanceMeasure.COSINE => true + case _ => false + } + } +} + +/** + * A vector with its norm for fast distance computation. + */ +private[clustering] class VectorWithNorm(val vector: Vector, val norm: Double) + extends Serializable { + + def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0)) + + def this(array: Array[Double]) = this(Vectors.dense(array)) + + /** Converts the vector to a dense vector. */ + def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm) +} + + +private[spark] abstract class DistanceMeasure extends Serializable { + + /** + * @return the index of the closest center to the given point, as well as the cost. + */ + def findClosest( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + val currentDistance = distance(center, point) + if (currentDistance < bestDistance) { + bestDistance = currentDistance + bestIndex = i + } + i += 1 + } + (bestIndex, bestDistance) + } + /** - * Returns the index of the closest center to the given point, as well as the squared distance. + * @return the K-means cost of a given point against the given cluster centers. */ - private[mllib] def findClosest( + def pointCost( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): Double = { + findClosest(centers, point)._2 + } + + /** + * @return whether a center converged or not, given the epsilon parameter. + */ + def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + distance(oldCenter, newCenter) <= epsilon + } + + /** + * @return the cosine distance between two points. + */ + def distance( + v1: VectorWithNorm, + v2: VectorWithNorm): Double + +} + +@Since("2.4.0") +object DistanceMeasure { + + @Since("2.4.0") + val EUCLIDEAN = "euclidean" + @Since("2.4.0") + val COSINE = "cosine" + + private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure = + distanceMeasure match { + case EUCLIDEAN => new EuclideanDistanceMeasure + case COSINE => new CosineDistanceMeasure + case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " + + s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.") + } +} + +private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { + /** + * @return the index of the closest center to the given point, as well as the squared distance. + */ + override def findClosest( centers: TraversableOnce[VectorWithNorm], point: VectorWithNorm): (Int, Double) = { var bestDistance = Double.PositiveInfinity @@ -561,7 +692,7 @@ object KMeans { var lowerBoundOfSqDist = center.norm - point.norm lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist if (lowerBoundOfSqDist < bestDistance) { - val distance: Double = fastSquaredDistance(center, point) + val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point) if (distance < bestDistance) { bestDistance = distance bestIndex = i @@ -573,15 +704,29 @@ object KMeans { } /** - * Returns the K-means cost of a given point against the given cluster centers. + * @return whether a center converged or not, given the epsilon parameter. */ - private[mllib] def pointCost( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): Double = - findClosest(centers, point)._2 + override def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon + } + + /** + * @param v1: first vector + * @param v2: second vector + * @return the Euclidean distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2)) + } +} + +private[spark] object EuclideanDistanceMeasure { /** - * Returns the squared Euclidean distance between two vectors computed by + * @return the squared Euclidean distance between two vectors computed by * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]]. */ private[clustering] def fastSquaredDistance( @@ -589,28 +734,15 @@ object KMeans { v2: VectorWithNorm): Double = { MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) } - - private[spark] def validateInitMode(initMode: String): Boolean = { - initMode match { - case KMeans.RANDOM => true - case KMeans.K_MEANS_PARALLEL => true - case _ => false - } - } } -/** - * A vector with its norm for fast distance computation. - * - * @see [[org.apache.spark.mllib.clustering.KMeans#fastSquaredDistance]] - */ -private[clustering] -class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable { - - def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0)) - - def this(array: Array[Double]) = this(Vectors.dense(array)) - - /** Converts the vector to a dense vector. */ - def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm) +private[spark] class CosineDistanceMeasure extends DistanceMeasure { + /** + * @param v1: first vector + * @param v2: second vector + * @return the cosine distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 3ad08c46d204d..a78c21e838e44 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -36,12 +36,20 @@ import org.apache.spark.sql.{Row, SparkSession} * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ @Since("0.8.0") -class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector]) +class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vector], + @Since("2.4.0") val distanceMeasure: String) extends Saveable with Serializable with PMMLExportable { + private val distanceMeasureInstance: DistanceMeasure = + DistanceMeasure.decodeFromString(distanceMeasure) + private val clusterCentersWithNorm = if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_)) + @Since("1.1.0") + def this(clusterCenters: Array[Vector]) = + this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN) + /** * A Java-friendly constructor that takes an Iterable of Vectors. */ @@ -59,7 +67,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec */ @Since("0.8.0") def predict(point: Vector): Int = { - KMeans.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1 + distanceMeasureInstance.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1 } /** @@ -68,7 +76,8 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec @Since("1.0.0") def predict(points: RDD[Vector]): RDD[Int] = { val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm) - points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1) + points.map(p => + distanceMeasureInstance.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1) } /** @@ -85,8 +94,9 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec @Since("0.8.0") def computeCost(data: RDD[Vector]): Double = { val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm) - val cost = data - .map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum() + val cost = data.map(p => + distanceMeasureInstance.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))) + .sum() bcCentersWithNorm.destroy(blocking = false) cost } @@ -94,7 +104,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { - KMeansModel.SaveLoadV1_0.save(sc, this, path) + KMeansModel.SaveLoadV2_0.save(sc, this, path) } override protected def formatVersion: String = "1.0" @@ -105,7 +115,20 @@ object KMeansModel extends Loader[KMeansModel] { @Since("1.4.0") override def load(sc: SparkContext, path: String): KMeansModel = { - KMeansModel.SaveLoadV1_0.load(sc, path) + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + val classNameV2_0 = SaveLoadV2_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path) + case (className, "2.0") if className == classNameV2_0 => + SaveLoadV2_0.load(sc, path) + case _ => throw new Exception( + s"KMeansModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)\n" + + s" ($classNameV2_0, 2.0)") + } } private case class Cluster(id: Int, point: Vector) @@ -116,8 +139,7 @@ object KMeansModel extends Loader[KMeansModel] { } } - private[clustering] - object SaveLoadV1_0 { + private[clustering] object SaveLoadV1_0 { private val thisFormatVersion = "1.0" @@ -149,4 +171,38 @@ object KMeansModel extends Loader[KMeansModel] { new KMeansModel(localCentroids.sortBy(_.id).map(_.point)) } } + + private[clustering] object SaveLoadV2_0 { + + private val thisFormatVersion = "2.0" + + private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" + + def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) + ~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) => + Cluster(id, p.vector) + } + spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): KMeansModel = { + implicit val formats = DefaultFormats + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val k = (metadata \ "k").extract[Int] + val centroids = spark.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Cluster](centroids.schema) + val localCentroids = centroids.rdd.map(Cluster.apply).collect() + assert(k == localCentroids.length) + val distanceMeasure = (metadata \ "distanceMeasure").extract[String] + new KMeansModel(localCentroids.sortBy(_.id).map(_.point), distanceMeasure) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala index 53587670a5db0..4a08c0a55e68f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -46,7 +46,7 @@ private[mllib] object LocalKMeans extends Logging { // Initialize centers by sampling using the k-means++ procedure. centers(0) = pickWeighted(rand, points, weights).toDense - val costArray = points.map(KMeans.fastSquaredDistance(_, centers(0))) + val costArray = points.map(EuclideanDistanceMeasure.fastSquaredDistance(_, centers(0))) for (i <- 1 until k) { val sum = costArray.zip(weights).map(p => p._1 * p._2).sum @@ -67,11 +67,15 @@ private[mllib] object LocalKMeans extends Logging { // update costArray for (p <- points.indices) { - costArray(p) = math.min(KMeans.fastSquaredDistance(points(p), centers(i)), costArray(p)) + costArray(p) = math.min( + EuclideanDistanceMeasure.fastSquaredDistance(points(p), centers(i)), + costArray(p)) } } + val distanceMeasureInstance = new EuclideanDistanceMeasure + // Run up to maxIterations iterations of Lloyd's algorithm val oldClosest = Array.fill(points.length)(-1) var iteration = 0 @@ -83,7 +87,7 @@ private[mllib] object LocalKMeans extends Logging { var i = 0 while (i < points.length) { val p = points(i) - val index = KMeans.findClosest(centers, p)._1 + val index = distanceMeasureInstance.findClosest(centers, p)._1 axpy(weights(i), p.vector, sums(index)) counts(index) += weights(i) if (index != oldClosest(i)) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 119fe1dead9a9..e4506f23feb31 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} @@ -50,6 +50,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 2) assert(kmeans.getTol === 1e-4) + assert(kmeans.getDistanceMeasure === DistanceMeasure.EUCLIDEAN) val model = kmeans.setMaxIter(1).fit(dataset) MLTestingUtils.checkCopyAndUids(kmeans, model) @@ -68,6 +69,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR .setInitSteps(3) .setSeed(123) .setTol(1e-3) + .setDistanceMeasure(DistanceMeasure.COSINE) assert(kmeans.getK === 9) assert(kmeans.getFeaturesCol === "test_feature") @@ -77,6 +79,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitSteps === 3) assert(kmeans.getSeed === 123) assert(kmeans.getTol === 1e-3) + assert(kmeans.getDistanceMeasure === DistanceMeasure.COSINE) } test("parameters validation") { @@ -89,6 +92,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR intercept[IllegalArgumentException] { new KMeans().setInitSteps(0) } + intercept[IllegalArgumentException] { + new KMeans().setDistanceMeasure("no_such_a_measure") + } } test("fit, transform and summary") { @@ -144,6 +150,37 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(model.getPredictionCol == predictionColName) } + test("KMeans using cosine distance") { + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(1.0, 1.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5), + Vectors.dense(10.0, 4.4), + Vectors.dense(-1.0, 1.0), + Vectors.dense(-100.0, 90.0) + )).map(v => TestRow(v))) + + val model = new KMeans() + .setK(3) + .setSeed(1) + .setInitMode(MLlibKMeans.RANDOM) + .setTol(1e-6) + .setDistanceMeasure(DistanceMeasure.COSINE) + .fit(df) + + val predictionDf = model.transform(df) + assert(predictionDf.select("prediction").distinct().count() == 3) + val predictionsMap = predictionDf.collect().map(row => + row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap + assert(predictionsMap(Vectors.dense(1.0, 1.0)) == + predictionsMap(Vectors.dense(10.0, 10.0))) + assert(predictionsMap(Vectors.dense(1.0, 0.5)) == + predictionsMap(Vectors.dense(10.0, 4.4))) + assert(predictionsMap(Vectors.dense(-1.0, 1.0)) == + predictionsMap(Vectors.dense(-100.0, 90.0))) + + } + test("read/write") { def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters) @@ -182,6 +219,7 @@ object KMeansSuite { "predictionCol" -> "myPrediction", "k" -> 3, "maxIter" -> 2, - "tol" -> 0.01 + "tol" -> 0.01, + "distanceMeasure" -> DistanceMeasure.EUCLIDEAN ) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 00d7e2f2d3864..1b98250061c7a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -89,7 +89,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { .setInitializationMode("k-means||") .setInitializationSteps(10) .setSeed(seed) - val initialCenters = km.initKMeansParallel(normedData).map(_.vector) + + val distanceMeasureInstance = new EuclideanDistanceMeasure + val initialCenters = km.initKMeansParallel(normedData, distanceMeasureInstance).map(_.vector) assert(initialCenters.length === initialCenters.distinct.length) assert(initialCenters.length <= numDistinctPoints) @@ -104,7 +106,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { .setInitializationMode("k-means||") .setInitializationSteps(10) .setSeed(seed) - val initialCenters2 = km2.initKMeansParallel(normedData).map(_.vector) + val initialCenters2 = km2.initKMeansParallel(normedData, distanceMeasureInstance).map(_.vector) assert(initialCenters2.length === initialCenters2.distinct.length) assert(initialCenters2.length === k) From 2239d7a410e906ccd40aa8e84d637e9d06cd7b8a Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 21 Jan 2018 11:23:51 -0800 Subject: [PATCH 161/774] [SPARK-21293][SS][SPARKR] Add doc example for streaming join, dedup ## What changes were proposed in this pull request? streaming programming guide changes ## How was this patch tested? manually Author: Felix Cheung Closes #20340 from felixcheung/rstreamdoc. --- .../structured-streaming-programming-guide.md | 74 ++++++++++++++++++- 1 file changed, 72 insertions(+), 2 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 2ef5d3168a87b..62589a62ac4c4 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1100,6 +1100,21 @@ streamingDf.join(staticDf, "type") # inner equi-join with a static DF streamingDf.join(staticDf, "type", "right_join") # right outer join with a static DF {% endhighlight %} +
+ +
+ +{% highlight r %} +staticDf <- read.df(...) +streamingDf <- read.stream(...) +joined <- merge(streamingDf, staticDf, sort = FALSE) # inner equi-join with a static DF +joined <- join( + staticDf, + streamingDf, + streamingDf$value == staticDf$value, + "right_outer") # right outer join with a static DF +{% endhighlight %} +
@@ -1227,6 +1242,30 @@ impressionsWithWatermark.join( {% endhighlight %} + +
+ +{% highlight r %} +impressions <- read.stream(...) +clicks <- read.stream(...) + +# Apply watermarks on event-time columns +impressionsWithWatermark <- withWatermark(impressions, "impressionTime", "2 hours") +clicksWithWatermark <- withWatermark(clicks, "clickTime", "3 hours") + +# Join with event-time constraints +joined <- join( + impressionsWithWatermark, + clicksWithWatermark, + expr( + paste( + "clickAdId = impressionAdId AND", + "clickTime >= impressionTime AND", + "clickTime <= impressionTime + interval 1 hour" +))) + +{% endhighlight %} +
@@ -1287,6 +1326,23 @@ impressionsWithWatermark.join( {% endhighlight %} + +
+ +{% highlight r %} +joined <- join( + impressionsWithWatermark, + clicksWithWatermark, + expr( + paste( + "clickAdId = impressionAdId AND", + "clickTime >= impressionTime AND", + "clickTime <= impressionTime + interval 1 hour"), + "left_outer" # can be "inner", "left_outer", "right_outer" +)) + +{% endhighlight %} +
@@ -1441,15 +1497,29 @@ streamingDf {% highlight python %} streamingDf = spark.readStream. ... -// Without watermark using guid column +# Without watermark using guid column streamingDf.dropDuplicates("guid") -// With watermark using guid and eventTime columns +# With watermark using guid and eventTime columns streamingDf \ .withWatermark("eventTime", "10 seconds") \ .dropDuplicates("guid", "eventTime") {% endhighlight %} + +
+ +{% highlight r %} +streamingDf <- read.stream(...) + +# Without watermark using guid column +streamingDf <- dropDuplicates(streamingDf, "guid") + +# With watermark using guid and eventTime columns +streamingDf <- withWatermark(streamingDf, "eventTime", "10 seconds") +streamingDf <- dropDuplicates(streamingDf, "guid", "eventTime") +{% endhighlight %} +
From 12faae295e42820b99a695ba49826051944244e1 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 22 Jan 2018 09:45:27 +0900 Subject: [PATCH 162/774] [SPARK-23169][INFRA][R] Run lintr on the changes of lint-r script and .lintr configuration ## What changes were proposed in this pull request? When running the `run-tests` script, seems we don't run lintr on the changes of `lint-r` script and `.lintr` configuration. ## How was this patch tested? Jenkins builds Author: hyukjinkwon Closes #20339 from HyukjinKwon/check-r-changed. --- dev/run-tests.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 7e6f7ff060351..fb270c4ee0508 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -578,7 +578,10 @@ def main(): pass if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() - if not changed_files or any(f.endswith(".R") for f in changed_files): + if not changed_files or any(f.endswith(".R") + or f.endswith("lint-r") + or f.endswith(".lintr") + for f in changed_files): run_sparkr_style_checks() # determine if docs were changed and if we're inside the amplab environment From 602c6d82d893a7f34b37d674642669048eb59b03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=99=93=E5=93=B2?= Date: Mon, 22 Jan 2018 10:43:12 +0900 Subject: [PATCH 163/774] [SPARK-20947][PYTHON] Fix encoding/decoding error in pipe action MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Pipe action convert objects into strings using a way that was affected by the default encoding setting of Python environment. This patch fixed the problem. The detailed description is added here: https://issues.apache.org/jira/browse/SPARK-20947 ## How was this patch tested? Run the following statement in pyspark-shell, and it will NOT raise exception if this patch is applied: ```python sc.parallelize([u'\u6d4b\u8bd5']).pipe('cat').collect() ``` Author: 王晓哲 Closes #18277 from chaoslawful/fix_pipe_encoding_error. --- python/pyspark/rdd.py | 2 +- python/pyspark/tests.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 340bc3a6b7470..1b3915548fb14 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -766,7 +766,7 @@ def func(iterator): def pipe_objs(out): for obj in iterator: - s = str(obj).rstrip('\n') + '\n' + s = unicode(obj).rstrip('\n') + '\n' out.write(s.encode('utf-8')) out.close() Thread(target=pipe_objs, args=[pipe.stdin]).start() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index da99872da2f0e..511585763cb01 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1239,6 +1239,13 @@ def test_pipe_functions(self): self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) self.assertEqual([], rdd.pipe('grep 4').collect()) + def test_pipe_unicode(self): + # Regression test for SPARK-20947 + data = [u'\u6d4b\u8bd5', '1'] + rdd = self.sc.parallelize(data) + result = rdd.pipe('cat').collect() + self.assertEqual(data, result) + class ProfilerTests(PySparkTestCase): From 11daeb833222b1cd349fb1410307d64ab33981db Mon Sep 17 00:00:00 2001 From: Russell Spitzer Date: Mon, 22 Jan 2018 12:27:51 +0800 Subject: [PATCH 164/774] [SPARK-22976][CORE] Cluster mode driver dir removed while running ## What changes were proposed in this pull request? The clean up logic on the worker perviously determined the liveness of a particular applicaiton based on whether or not it had running executors. This would fail in the case that a directory was made for a driver running in cluster mode if that driver had no running executors on the same machine. To preserve driver directories we consider both executors and running drivers when checking directory liveness. ## How was this patch tested? Manually started up two node cluster with a single core on each node. Turned on worker directory cleanup and set the interval to 1 second and liveness to one second. Without the patch the driver directory is removed immediately after the app is launched. With the patch it is not ### Without Patch ``` INFO 2018-01-05 23:48:24,693 Logging.scala:54 - Asked to launch driver driver-20180105234824-0000 INFO 2018-01-05 23:48:25,293 Logging.scala:54 - Changing view acls to: cassandra INFO 2018-01-05 23:48:25,293 Logging.scala:54 - Changing modify acls to: cassandra INFO 2018-01-05 23:48:25,294 Logging.scala:54 - Changing view acls groups to: INFO 2018-01-05 23:48:25,294 Logging.scala:54 - Changing modify acls groups to: INFO 2018-01-05 23:48:25,294 Logging.scala:54 - SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(cassandra); groups with view permissions: Set(); users with modify permissions: Set(cassandra); groups with modify permissions: Set() INFO 2018-01-05 23:48:25,330 Logging.scala:54 - Copying user jar file:/home/automaton/writeRead-0.1.jar to /var/lib/spark/worker/driver-20180105234824-0000/writeRead-0.1.jar INFO 2018-01-05 23:48:25,332 Logging.scala:54 - Copying /home/automaton/writeRead-0.1.jar to /var/lib/spark/worker/driver-20180105234824-0000/writeRead-0.1.jar INFO 2018-01-05 23:48:25,361 Logging.scala:54 - Launch Command: "/usr/lib/jvm/jdk1.8.0_40//bin/java" .... **** INFO 2018-01-05 23:48:56,577 Logging.scala:54 - Removing directory: /var/lib/spark/worker/driver-20180105234824-0000 ### << Cleaned up **** -- One minute passes while app runs (app has 1 minute sleep built in) -- WARN 2018-01-05 23:49:58,080 ShuffleSecretManager.java:73 - Attempted to unregister application app-20180105234831-0000 when it is not registered INFO 2018-01-05 23:49:58,081 ExternalShuffleBlockResolver.java:163 - Application app-20180105234831-0000 removed, cleanupLocalDirs = false INFO 2018-01-05 23:49:58,081 ExternalShuffleBlockResolver.java:163 - Application app-20180105234831-0000 removed, cleanupLocalDirs = false INFO 2018-01-05 23:49:58,082 ExternalShuffleBlockResolver.java:163 - Application app-20180105234831-0000 removed, cleanupLocalDirs = true INFO 2018-01-05 23:50:00,999 Logging.scala:54 - Driver driver-20180105234824-0000 exited successfully ``` With Patch ``` INFO 2018-01-08 23:19:54,603 Logging.scala:54 - Asked to launch driver driver-20180108231954-0002 INFO 2018-01-08 23:19:54,975 Logging.scala:54 - Changing view acls to: automaton INFO 2018-01-08 23:19:54,976 Logging.scala:54 - Changing modify acls to: automaton INFO 2018-01-08 23:19:54,976 Logging.scala:54 - Changing view acls groups to: INFO 2018-01-08 23:19:54,976 Logging.scala:54 - Changing modify acls groups to: INFO 2018-01-08 23:19:54,976 Logging.scala:54 - SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(automaton); groups with view permissions: Set(); users with modify permissions: Set(automaton); groups with modify permissions: Set() INFO 2018-01-08 23:19:55,029 Logging.scala:54 - Copying user jar file:/home/automaton/writeRead-0.1.jar to /var/lib/spark/worker/driver-20180108231954-0002/writeRead-0.1.jar INFO 2018-01-08 23:19:55,031 Logging.scala:54 - Copying /home/automaton/writeRead-0.1.jar to /var/lib/spark/worker/driver-20180108231954-0002/writeRead-0.1.jar INFO 2018-01-08 23:19:55,038 Logging.scala:54 - Launch Command: ...... INFO 2018-01-08 23:21:28,674 ShuffleSecretManager.java:69 - Unregistered shuffle secret for application app-20180108232000-0000 INFO 2018-01-08 23:21:28,675 ExternalShuffleBlockResolver.java:163 - Application app-20180108232000-0000 removed, cleanupLocalDirs = false INFO 2018-01-08 23:21:28,675 ExternalShuffleBlockResolver.java:163 - Application app-20180108232000-0000 removed, cleanupLocalDirs = false INFO 2018-01-08 23:21:28,681 ExternalShuffleBlockResolver.java:163 - Application app-20180108232000-0000 removed, cleanupLocalDirs = true INFO 2018-01-08 23:21:31,703 Logging.scala:54 - Driver driver-20180108231954-0002 exited successfully ***** INFO 2018-01-08 23:21:32,346 Logging.scala:54 - Removing directory: /var/lib/spark/worker/driver-20180108231954-0002 ### < Happening AFTER the Run completes rather than during it ***** ``` Author: Russell Spitzer Closes #20298 from RussellSpitzer/SPARK-22976-master. --- core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 3962d422f81d3..563b84934f264 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -441,7 +441,7 @@ private[deploy] class Worker( // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker // rpcEndpoint. // Copy ids so that it can be used in the cleanup thread. - val appIds = executors.values.map(_.appId).toSet + val appIds = (executors.values.map(_.appId) ++ drivers.values.map(_.driverId)).toSet val cleanupFuture = concurrent.Future { val appDirs = workDir.listFiles() if (appDirs == null) { From 8142a3b883a5fe6fc620a2c5b25b6bde4fda32e5 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 22 Jan 2018 15:18:57 +0900 Subject: [PATCH 165/774] [MINOR][SQL] Fix wrong comments on org.apache.spark.sql.parquet.row.attributes ## What changes were proposed in this pull request? This PR fixes the wrong comment on `org.apache.spark.sql.parquet.row.attributes` which is useful for UDTs like Vector/Matrix. Please see [SPARK-22320](https://issues.apache.org/jira/browse/SPARK-22320) for the usage. Originally, [SPARK-19411](https://github.com/apache/spark/commit/bf493686eb17006727b3ec81849b22f3df68fdef#diff-ee26d4c4be21e92e92a02e9f16dbc285L314) left this behind during removing optional column metadatas. In the same PR, the same comment was removed at line 310-311. ## How was this patch tested? N/A (This is about comments). Author: Dongjoon Hyun Closes #20346 from dongjoon-hyun/minor_comment_parquet. --- .../sql/execution/datasources/parquet/ParquetFileFormat.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 45bedf70f975c..f53a97ba45a26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -108,8 +108,7 @@ class ParquetFileFormat ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) - // We want to clear this temporary metadata from saving into Parquet file. - // This metadata is only useful for detecting optional columns when pushdowning filters. + // This metadata is useful for keeping UDTs like Vector/Matrix. ParquetWriteSupport.setSchema(dataSchema, conf) // Sets flags for `ParquetWriteSupport`, which converts Catalyst schema to Parquet From ec228976156619ed8df21a85bceb5fd3bdeb5855 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 22 Jan 2018 14:49:12 +0800 Subject: [PATCH 166/774] [SPARK-23020][CORE] Fix races in launcher code, test. The race in the code is because the handle might update its state to the wrong state if the connection handling thread is still processing incoming data; so the handle needs to wait for the connection to finish up before checking the final state. The race in the test is because when waiting for a handle to reach a final state, the waitFor() method needs to wait until all handle state is updated (which also includes waiting for the connection thread above to finish). Otherwise, waitFor() may return too early, which would cause a bunch of different races (like the listener not being yet notified of the state change, or being in the middle of being notified, or the handle not being properly disposed and causing postChecks() to assert). On top of that I found, by code inspection, a couple of potential races that could make a handle end up in the wrong state when being killed. The original version of this fix introduced the flipped version of the first race described above; the connection closing might override the handle state before the handle might have a chance to do cleanup. The fix there is to only dispose of the handle from the connection when there is an error, and let the handle dispose itself in the normal case. The fix also caused a bug in YarnClusterSuite to be surfaced; the code was checking for a file in the classpath that was not expected to be there in client mode. Because of the above issues, the error was not propagating correctly and the (buggy) test was incorrectly passing. Tested by running the existing unit tests a lot (and not seeing the errors I was seeing before). Author: Marcelo Vanzin Closes #20297 from vanzin/SPARK-23020. --- .../spark/launcher/SparkLauncherSuite.java | 53 +++++++++++-------- .../spark/launcher/AbstractAppHandle.java | 22 ++++++-- .../spark/launcher/ChildProcAppHandle.java | 18 ++++--- .../spark/launcher/InProcessAppHandle.java | 17 +++--- .../spark/launcher/LauncherConnection.java | 18 +++---- .../apache/spark/launcher/LauncherServer.java | 49 +++++++++++++---- .../org/apache/spark/launcher/BaseSuite.java | 42 ++++++++++++--- .../spark/launcher/LauncherServerSuite.java | 23 ++++---- .../spark/deploy/yarn/YarnClusterSuite.scala | 4 +- 9 files changed, 165 insertions(+), 81 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index dffa609f1cbdf..a042375c6ae91 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.Arrays; import java.util.ArrayList; import java.util.HashMap; @@ -25,13 +26,13 @@ import java.util.Properties; import java.util.concurrent.TimeUnit; -import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.*; import static org.junit.Assume.*; import static org.mockito.Mockito.*; import org.apache.spark.SparkContext; +import org.apache.spark.SparkContext$; import org.apache.spark.internal.config.package$; import org.apache.spark.util.Utils; @@ -121,8 +122,7 @@ public void testChildProcLauncher() throws Exception { assertEquals(0, app.waitFor()); } - // TODO: [SPARK-23020] Re-enable this - @Ignore + @Test public void testInProcessLauncher() throws Exception { // Because this test runs SparkLauncher in process and in client mode, it pollutes the system // properties, and that can cause test failures down the test pipeline. So restore the original @@ -139,7 +139,9 @@ public void testInProcessLauncher() throws Exception { // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. // See SPARK-23019 and SparkContext.stop() for details. - TimeUnit.MILLISECONDS.sleep(500); + eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { + assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty()); + }); } } @@ -148,26 +150,35 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle.Listener listener = mock(SparkAppHandle.Listener.class); doAnswer(invocation -> { SparkAppHandle h = (SparkAppHandle) invocation.getArguments()[0]; - transitions.add(h.getState()); + synchronized (transitions) { + transitions.add(h.getState()); + } return null; }).when(listener).stateChanged(any(SparkAppHandle.class)); - SparkAppHandle handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); - - waitFor(handle); - assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); - - // Matches the behavior of LocalSchedulerBackend. - List expected = Arrays.asList( - SparkAppHandle.State.CONNECTED, - SparkAppHandle.State.RUNNING, - SparkAppHandle.State.FINISHED); - assertEquals(expected, transitions); + SparkAppHandle handle = null; + try { + handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + waitFor(handle); + assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); + + // Matches the behavior of LocalSchedulerBackend. + List expected = Arrays.asList( + SparkAppHandle.State.CONNECTED, + SparkAppHandle.State.RUNNING, + SparkAppHandle.State.FINISHED); + assertEquals(expected, transitions); + } finally { + if (handle != null) { + handle.kill(); + } + } } public static class SparkLauncherTestApp { diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index df1e7316861d4..daf0972f824dd 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -33,7 +33,7 @@ abstract class AbstractAppHandle implements SparkAppHandle { private List listeners; private State state; private String appId; - private boolean disposed; + private volatile boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; @@ -70,8 +70,7 @@ public void stop() { @Override public synchronized void disconnect() { - if (!disposed) { - disposed = true; + if (!isDisposed()) { if (connection != null) { try { connection.close(); @@ -79,7 +78,7 @@ public synchronized void disconnect() { // no-op. } } - server.unregister(this); + dispose(); } } @@ -95,6 +94,21 @@ boolean isDisposed() { return disposed; } + /** + * Mark the handle as disposed, and set it as LOST in case the current state is not final. + */ + synchronized void dispose() { + if (!isDisposed()) { + // Unregister first to make sure that the connection with the app has been really + // terminated. + server.unregister(this); + if (!getState().isFinal()) { + setState(State.LOST); + } + this.disposed = true; + } + } + void setState(State s) { setState(s, false); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 8b3f427b7750e..2b99461652e1f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -48,14 +48,16 @@ public synchronized void disconnect() { @Override public synchronized void kill() { - disconnect(); - if (childProc != null) { - if (childProc.isAlive()) { - childProc.destroyForcibly(); + if (!isDisposed()) { + setState(State.KILLED); + disconnect(); + if (childProc != null) { + if (childProc.isAlive()) { + childProc.destroyForcibly(); + } + childProc = null; } - childProc = null; } - setState(State.KILLED); } void setChildProc(Process childProc, String loggerName, InputStream logStream) { @@ -94,8 +96,6 @@ void monitorChild() { return; } - disconnect(); - int ec; try { ec = proc.exitValue(); @@ -118,6 +118,8 @@ void monitorChild() { if (newState != null) { setState(newState, true); } + + disconnect(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index acd64c962604f..f04263cb74a58 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -39,15 +39,16 @@ class InProcessAppHandle extends AbstractAppHandle { @Override public synchronized void kill() { - LOG.warning("kill() may leave the underlying app running in in-process mode."); - disconnect(); - - // Interrupt the thread. This is not guaranteed to kill the app, though. - if (app != null) { - app.interrupt(); + if (!isDisposed()) { + LOG.warning("kill() may leave the underlying app running in in-process mode."); + setState(State.KILLED); + disconnect(); + + // Interrupt the thread. This is not guaranteed to kill the app, though. + if (app != null) { + app.interrupt(); + } } - - setState(State.KILLED); } synchronized void start(String appName, Method main, String[] args) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java index b4a8719e26053..e8ab3f5e369ab 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -53,7 +53,7 @@ abstract class LauncherConnection implements Closeable, Runnable { public void run() { try { FilteredObjectInputStream in = new FilteredObjectInputStream(socket.getInputStream()); - while (!closed) { + while (isOpen()) { Message msg = (Message) in.readObject(); handle(msg); } @@ -95,15 +95,15 @@ protected synchronized void send(Message msg) throws IOException { } @Override - public void close() throws IOException { - if (!closed) { - synchronized (this) { - if (!closed) { - closed = true; - socket.close(); - } - } + public synchronized void close() throws IOException { + if (isOpen()) { + closed = true; + socket.close(); } } + boolean isOpen() { + return !closed; + } + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index b8999a1d7a4f4..8091885c4f562 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -217,6 +217,33 @@ void unregister(AbstractAppHandle handle) { break; } } + + // If there is a live connection for this handle, we need to wait for it to finish before + // returning, otherwise there might be a race between the connection thread processing + // buffered data and the handle cleaning up after itself, leading to potentially the wrong + // state being reported for the handle. + ServerConnection conn = null; + synchronized (clients) { + for (ServerConnection c : clients) { + if (c.handle == handle) { + conn = c; + break; + } + } + } + + if (conn != null) { + synchronized (conn) { + if (conn.isOpen()) { + try { + conn.wait(); + } catch (InterruptedException ie) { + // Ignore. + } + } + } + } + unref(); } @@ -288,7 +315,7 @@ private String createSecret() { private class ServerConnection extends LauncherConnection { private TimerTask timeout; - private AbstractAppHandle handle; + volatile AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); @@ -313,7 +340,7 @@ protected void handle(Message msg) throws IOException { } else { if (handle == null) { throw new IllegalArgumentException("Expected hello, got: " + - msg != null ? msg.getClass().getName() : null); + msg != null ? msg.getClass().getName() : null); } if (msg instanceof SetAppId) { SetAppId set = (SetAppId) msg; @@ -331,6 +358,9 @@ protected void handle(Message msg) throws IOException { timeout.cancel(); } close(); + if (handle != null) { + handle.dispose(); + } } finally { timeoutTimer.purge(); } @@ -338,16 +368,17 @@ protected void handle(Message msg) throws IOException { @Override public void close() throws IOException { + if (!isOpen()) { + return; + } + synchronized (clients) { clients.remove(this); } - super.close(); - if (handle != null) { - if (!handle.getState().isFinal()) { - LOG.log(Level.WARNING, "Lost connection to spark application."); - handle.setState(SparkAppHandle.State.LOST); - } - handle.disconnect(); + + synchronized (this) { + super.close(); + notifyAll(); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3e1a90eae98d4..3722a59d9438e 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.concurrent.TimeUnit; import org.junit.After; @@ -47,19 +48,46 @@ public void postChecks() { assertNull(server); } - protected void waitFor(SparkAppHandle handle) throws Exception { - long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + protected void waitFor(final SparkAppHandle handle) throws Exception { try { - while (!handle.getState().isFinal()) { - assertTrue("Timed out waiting for handle to transition to final state.", - System.nanoTime() < deadline); - TimeUnit.MILLISECONDS.sleep(10); - } + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is not in final state.", handle.getState().isFinal()); + }); } finally { if (!handle.getState().isFinal()) { handle.kill(); } } + + // Wait until the handle has been marked as disposed, to make sure all cleanup tasks + // have been performed. + AbstractAppHandle ahandle = (AbstractAppHandle) handle; + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is still not marked as disposed.", ahandle.isDisposed()); + }); + } + + /** + * Call a closure that performs a check every "period" until it succeeds, or the timeout + * elapses. + */ + protected void eventually(Duration timeout, Duration period, Runnable check) throws Exception { + assertTrue("Timeout needs to be larger than period.", timeout.compareTo(period) > 0); + long deadline = System.nanoTime() + timeout.toNanos(); + int count = 0; + while (true) { + try { + count++; + check.run(); + return; + } catch (Throwable t) { + if (System.nanoTime() >= deadline) { + String msg = String.format("Failed check after %d tries: %s.", count, t.getMessage()); + throw new IllegalStateException(msg, t); + } + Thread.sleep(period.toMillis()); + } + } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 7e2b09ce25c9b..024efac33c391 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -23,12 +23,14 @@ import java.net.InetAddress; import java.net.Socket; import java.net.SocketException; +import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import static org.junit.Assert.*; @@ -143,7 +145,8 @@ public void infoChanged(SparkAppHandle handle) { assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS)); // Make sure the server matched the client to the handle. assertNotNull(handle.getConnection()); - close(client); + client.close(); + handle.dispose(); assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS)); assertEquals(SparkAppHandle.State.LOST, handle.getState()); } finally { @@ -197,28 +200,20 @@ private void close(Closeable c) { * server-side close immediately. */ private void waitForError(TestClient client, String secret) throws Exception { - boolean helloSent = false; - int maxTries = 10; - for (int i = 0; i < maxTries; i++) { + final AtomicBoolean helloSent = new AtomicBoolean(); + eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { try { - if (!helloSent) { + if (!helloSent.get()) { client.send(new Hello(secret, "1.4.0")); - helloSent = true; + helloSent.set(true); } else { client.send(new SetAppId("appId")); } fail("Expected error but message went through."); } catch (IllegalStateException | IOException e) { // Expected. - break; - } catch (AssertionError e) { - if (i < maxTries - 1) { - Thread.sleep(100); - } else { - throw new AssertionError("Test failed after " + maxTries + " attempts.", e); - } } - } + }); } private static class TestClient extends LauncherConnection { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 061f653b97b7a..e9dcfaf6ba4f0 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -381,7 +381,9 @@ private object YarnClusterDriver extends Logging with Matchers { // Verify that the config archive is correctly placed in the classpath of all containers. val confFile = "/" + Client.SPARK_CONF_FILE - assert(getClass().getResource(confFile) != null) + if (conf.getOption(SparkLauncher.DEPLOY_MODE) == Some("cluster")) { + assert(getClass().getResource(confFile) != null) + } val configFromExecutors = sc.parallelize(1 to 4, 4) .map { _ => Option(getClass().getResource(confFile)).map(_.toString).orNull } .collect() From 60175e959f275d2961798fbc5a9150dac9de51ff Mon Sep 17 00:00:00 2001 From: Arseniy Tashoyan Date: Mon, 22 Jan 2018 20:17:05 +0800 Subject: [PATCH 167/774] [MINOR][DOC] Fix the path to the examples jar ## What changes were proposed in this pull request? The example jar file is now in ./examples/jars directory of Spark distribution. Author: Arseniy Tashoyan Closes #20349 from tashoyan/patch-1. --- docs/running-on-yarn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e4f5a0c659e66..c010af35f8d2e 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -35,7 +35,7 @@ For example: --executor-memory 2g \ --executor-cores 1 \ --queue thequeue \ - lib/spark-examples*.jar \ + examples/jars/spark-examples*.jar \ 10 The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. From 73281161fc7fddd645c712986ec376ac2b1bd213 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 22 Jan 2018 04:27:59 -0800 Subject: [PATCH 168/774] [SPARK-23122][PYSPARK][FOLLOW-UP] Update the docs for UDF Registration ## What changes were proposed in this pull request? This PR is to update the docs for UDF registration ## How was this patch tested? N/A Author: gatorsmile Closes #20348 from gatorsmile/testUpdateDoc. --- python/pyspark/sql/udf.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index c77f19f89a442..134badb8485f5 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -199,8 +199,8 @@ def __init__(self, sparkSession): @ignore_unicode_prefix @since("1.3.1") def register(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a user-defined function - in SQL statements. + """Register a Python function (including lambda function) or a user-defined function + as a SQL function. :param name: name of the user-defined function in SQL statements. :param f: a Python function, or a user-defined function. The user-defined function can @@ -210,6 +210,10 @@ def register(self, name, f, returnType=None): be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. :return: a user-defined function. + To register a nondeterministic Python function, users need to first build + a nondeterministic user-defined function for the Python function and then register it + as a SQL function. + `returnType` can be optionally specified when `f` is a Python function but not when `f` is a user-defined function. Please see below. @@ -297,7 +301,7 @@ def register(self, name, f, returnType=None): @ignore_unicode_prefix @since(2.3) def registerJavaFunction(self, name, javaClassName, returnType=None): - """Register a Java user-defined function so it can be used in SQL statements. + """Register a Java user-defined function as a SQL function. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not specified we would infer it via reflection. @@ -334,7 +338,7 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): @ignore_unicode_prefix @since(2.3) def registerJavaUDAF(self, name, javaClassName): - """Register a Java user-defined aggregate function so it can be used in SQL statements. + """Register a Java user-defined aggregate function as a SQL function. :param name: name of the user-defined aggregate function :param javaClassName: fully qualified name of java class From 78801881c405de47f7e53eea3e0420dd69593dbd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 22 Jan 2018 04:31:24 -0800 Subject: [PATCH 169/774] [SPARK-23170][SQL] Dump the statistics of effective runs of analyzer and optimizer rules ## What changes were proposed in this pull request? Dump the statistics of effective runs of analyzer and optimizer rules. ## How was this patch tested? Do a manual run of TPCDSQuerySuite ``` === Metrics of Analyzer/Optimizer Rules === Total number of runs: 175899 Total time: 25.486559948 seconds Rule Effective Time / Total Time Effective Runs / Total Runs org.apache.spark.sql.catalyst.optimizer.ColumnPruning 1603280450 / 2868461549 761 / 1877 org.apache.spark.sql.catalyst.analysis.Analyzer$CTESubstitution 2045860009 / 2056602674 37 / 788 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions 440719059 / 1693110949 38 / 1982 org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries 1429834919 / 1446016225 39 / 285 org.apache.spark.sql.catalyst.optimizer.PruneFilters 33273083 / 1389586938 3 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences 821183615 / 1266668754 616 / 1982 org.apache.spark.sql.catalyst.optimizer.ReorderJoin 775837028 / 866238225 132 / 1592 org.apache.spark.sql.catalyst.analysis.DecimalPrecision 550683593 / 748854507 211 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubquery 513075345 / 634370596 49 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$FixNullability 33475731 / 606406532 12 / 742 org.apache.spark.sql.catalyst.analysis.TypeCoercion$ImplicitTypeCasts 193144298 / 545403925 86 / 1982 org.apache.spark.sql.catalyst.optimizer.BooleanSimplification 18651497 / 495725004 7 / 1592 org.apache.spark.sql.catalyst.optimizer.PushPredicateThroughJoin 369257217 / 489934378 709 / 1592 org.apache.spark.sql.catalyst.optimizer.RemoveRedundantAliases 3707000 / 468291609 9 / 1592 org.apache.spark.sql.catalyst.optimizer.InferFiltersFromConstraints 410155900 / 435254175 192 / 285 org.apache.spark.sql.execution.datasources.FindDataSourceTable 348885539 / 371855866 233 / 1982 org.apache.spark.sql.catalyst.optimizer.NullPropagation 11307645 / 307531225 26 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions 120324545 / 304948785 294 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$FunctionArgumentConversion 92323199 / 286695007 38 / 1982 org.apache.spark.sql.catalyst.optimizer.PushDownPredicate 230084193 / 265845972 785 / 1592 org.apache.spark.sql.catalyst.analysis.TypeCoercion$PromoteStrings 45938401 / 265144009 40 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$InConversion 14888776 / 261499450 1 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$CaseWhenCoercion 113796384 / 244913861 29 / 1982 org.apache.spark.sql.catalyst.optimizer.ConstantFolding 65008069 / 236548480 126 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractGenerator 0 / 226338929 0 / 1982 org.apache.spark.sql.catalyst.analysis.ResolveTimeZone 98134906 / 221323770 417 / 1982 org.apache.spark.sql.catalyst.optimizer.ReorderAssociativeOperator 0 / 208421703 0 / 1592 org.apache.spark.sql.catalyst.optimizer.OptimizeIn 8762534 / 199351958 16 / 1592 org.apache.spark.sql.catalyst.analysis.TypeCoercion$DateTimeOperations 11980016 / 190779046 27 / 1982 org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison 0 / 188887385 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals 0 / 186812106 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyCaseConversionExpressions 0 / 183885230 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyCasts 17128295 / 182901910 69 / 1592 org.apache.spark.sql.catalyst.analysis.TypeCoercion$Division 14579110 / 180309340 8 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$BooleanEquality 0 / 176740516 0 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$IfCoercion 0 / 170781986 0 / 1982 org.apache.spark.sql.catalyst.optimizer.LikeSimplification 771605 / 164136736 1 / 1592 org.apache.spark.sql.catalyst.optimizer.RemoveDispensableExpressions 0 / 155958962 0 / 1592 org.apache.spark.sql.catalyst.analysis.ResolveCreateNamedStruct 0 / 151222943 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder 7534632 / 146596355 14 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$EltCoercion 0 / 144488654 0 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$ConcatCoercion 0 / 142403338 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame 12067635 / 141500665 21 / 1982 org.apache.spark.sql.catalyst.analysis.TimeWindowing 0 / 140431958 0 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$WindowFrameCoercion 0 / 125471960 0 / 1982 org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin 14226972 / 124922019 11 / 1592 org.apache.spark.sql.catalyst.analysis.TypeCoercion$StackCoercion 0 / 123613887 0 / 1982 org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery 8491071 / 121179056 7 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics 55526073 / 120290529 11 / 1982 org.apache.spark.sql.catalyst.optimizer.ConstantPropagation 0 / 113886790 0 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveDeserializer 52383759 / 107160222 148 / 1982 org.apache.spark.sql.catalyst.analysis.CleanupAliases 52543524 / 102091518 344 / 1086 org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject 40682895 / 94403652 342 / 1877 org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractWindowExpressions 38473816 / 89740578 23 / 1982 org.apache.spark.sql.catalyst.optimizer.CollapseProject 46806090 / 83315506 281 / 1877 org.apache.spark.sql.catalyst.optimizer.FoldablePropagation 0 / 78750087 0 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases 13742765 / 77227258 47 / 1982 org.apache.spark.sql.catalyst.optimizer.CombineFilters 53386729 / 76960344 448 / 1592 org.apache.spark.sql.execution.datasources.DataSourceAnalysis 68034341 / 75724186 24 / 742 org.apache.spark.sql.catalyst.analysis.Analyzer$LookupFunctions 0 / 71151084 0 / 750 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveMissingReferences 12139848 / 67599140 8 / 1982 org.apache.spark.sql.catalyst.optimizer.PullupCorrelatedPredicates 45017938 / 65968777 23 / 285 org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource 0 / 60937767 0 / 285 org.apache.spark.sql.catalyst.optimizer.CollapseRepartition 0 / 59897237 0 / 1592 org.apache.spark.sql.catalyst.optimizer.PushProjectionThroughUnion 8547262 / 53941370 10 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$HandleNullInputsForUDF 0 / 52735976 0 / 742 org.apache.spark.sql.catalyst.analysis.TypeCoercion$WidenSetOperationTypes 9797713 / 52401665 9 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$PullOutNondeterministic 0 / 51741500 0 / 742 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRelations 28614911 / 51061186 233 / 1990 org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions 0 / 50621510 0 / 285 org.apache.spark.sql.catalyst.optimizer.CombineUnions 2777800 / 50262112 17 / 1877 org.apache.spark.sql.catalyst.analysis.Analyzer$GlobalAggregates 1640641 / 49633909 46 / 1982 org.apache.spark.sql.catalyst.optimizer.DecimalAggregates 20198374 / 48488419 100 / 385 org.apache.spark.sql.catalyst.optimizer.LimitPushDown 0 / 45052523 0 / 1592 org.apache.spark.sql.catalyst.optimizer.CombineLimits 0 / 44719443 0 / 1592 org.apache.spark.sql.catalyst.optimizer.EliminateSorts 0 / 44216930 0 / 1592 org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery 36235699 / 44165786 148 / 285 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNewInstance 0 / 42750307 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast 0 / 41811748 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveOrdinalInOrderByAndGroupBy 3819476 / 41776562 4 / 1982 org.apache.spark.sql.catalyst.optimizer.ComputeCurrentTime 0 / 40527808 0 / 285 org.apache.spark.sql.catalyst.optimizer.CollapseWindow 0 / 36832538 0 / 1592 org.apache.spark.sql.catalyst.optimizer.EliminateSerialization 0 / 36120667 0 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggAliasInGroupBy 0 / 32435826 0 / 1982 org.apache.spark.sql.execution.datasources.PreprocessTableCreation 0 / 32145218 0 / 742 org.apache.spark.sql.execution.datasources.ResolveSQLOnFile 0 / 30295614 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolvePivot 0 / 30111655 0 / 1982 org.apache.spark.sql.catalyst.expressions.codegen.package$ExpressionCanonicalizer$CleanExpressions 59930 / 28038201 26 / 8280 org.apache.spark.sql.catalyst.analysis.ResolveInlineTables 0 / 27808108 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubqueryColumnAliases 0 / 27066690 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGenerate 0 / 26660210 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNaturalAndUsingJoin 0 / 25255184 0 / 1982 org.apache.spark.sql.catalyst.analysis.ResolveTableValuedFunctions 0 / 24663088 0 / 1990 org.apache.spark.sql.catalyst.analysis.SubstituteUnresolvedOrdinals 9709079 / 24450670 4 / 788 org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveBroadcastHints 0 / 23776535 0 / 750 org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions 0 / 22697895 0 / 285 org.apache.spark.sql.catalyst.optimizer.CheckCartesianProducts 0 / 22523798 0 / 285 org.apache.spark.sql.catalyst.optimizer.ReplaceDistinctWithAggregate 988593 / 21535410 15 / 300 org.apache.spark.sql.catalyst.optimizer.EliminateMapObjects 0 / 20269996 0 / 285 org.apache.spark.sql.catalyst.optimizer.RewriteDistinctAggregates 0 / 19388592 0 / 285 org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases 17675532 / 18971185 215 / 285 org.apache.spark.sql.catalyst.optimizer.GetCurrentDatabase 0 / 18271152 0 / 285 org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelation 2077097 / 17190855 3 / 288 org.apache.spark.sql.catalyst.analysis.EliminateBarriers 0 / 16736359 0 / 1086 org.apache.spark.sql.execution.OptimizeMetadataOnlyQuery 0 / 16669341 0 / 285 org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences 0 / 14470235 0 / 742 org.apache.spark.sql.catalyst.optimizer.ReplaceExceptWithAntiJoin 6715625 / 12190561 1 / 300 org.apache.spark.sql.catalyst.optimizer.ReplaceIntersectWithSemiJoin 3451793 / 11431432 7 / 300 org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate 0 / 10810568 0 / 285 org.apache.spark.sql.catalyst.optimizer.RemoveRepetitionFromGroupExpressions 344198 / 10475276 1 / 286 org.apache.spark.sql.catalyst.analysis.Analyzer$WindowsSubstitution 0 / 10386630 0 / 788 org.apache.spark.sql.catalyst.analysis.EliminateUnions 0 / 10096526 0 / 788 org.apache.spark.sql.catalyst.analysis.AliasViewChild 0 / 9991706 0 / 742 org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation 0 / 9649334 0 / 288 org.apache.spark.sql.catalyst.analysis.ResolveHints$RemoveAllHints 0 / 8739109 0 / 750 org.apache.spark.sql.execution.datasources.PreprocessTableInsertion 0 / 8420889 0 / 742 org.apache.spark.sql.catalyst.analysis.EliminateView 0 / 8319134 0 / 285 org.apache.spark.sql.catalyst.optimizer.RemoveLiteralFromGroupExpressions 0 / 7392627 0 / 286 org.apache.spark.sql.catalyst.optimizer.ReplaceExceptWithFilter 0 / 7170516 0 / 300 org.apache.spark.sql.catalyst.optimizer.SimplifyCreateArrayOps 0 / 7109643 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyCreateStructOps 0 / 6837590 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyCreateMapOps 0 / 6617848 0 / 1592 org.apache.spark.sql.catalyst.optimizer.CombineConcats 0 / 5768406 0 / 1592 org.apache.spark.sql.catalyst.optimizer.ReplaceDeduplicateWithAggregate 0 / 5349831 0 / 285 org.apache.spark.sql.catalyst.optimizer.CombineTypedFilters 0 / 5186642 0 / 285 org.apache.spark.sql.catalyst.optimizer.EliminateDistinct 0 / 2427686 0 / 285 org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder 0 / 2420436 0 / 285 ``` Author: gatorsmile Closes #20342 from gatorsmile/reportExecution. --- .../rules/QueryExecutionMetering.scala | 91 +++++++++++++++++++ .../sql/catalyst/rules/RuleExecutor.scala | 32 +++---- .../apache/spark/sql/BenchmarkQueryTest.scala | 2 +- .../apache/spark/sql/SQLQueryTestSuite.scala | 2 +- .../execution/HiveCompatibilitySuite.scala | 2 +- 5 files changed, 109 insertions(+), 20 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala new file mode 100644 index 0000000000000..62f7541150a6e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.rules + +import scala.collection.JavaConverters._ + +import com.google.common.util.concurrent.AtomicLongMap + +case class QueryExecutionMetering() { + private val timeMap = AtomicLongMap.create[String]() + private val numRunsMap = AtomicLongMap.create[String]() + private val numEffectiveRunsMap = AtomicLongMap.create[String]() + private val timeEffectiveRunsMap = AtomicLongMap.create[String]() + + /** Resets statistics about time spent running specific rules */ + def resetMetrics(): Unit = { + timeMap.clear() + numRunsMap.clear() + numEffectiveRunsMap.clear() + timeEffectiveRunsMap.clear() + } + + def totalTime: Long = { + timeMap.sum() + } + + def totalNumRuns: Long = { + numRunsMap.sum() + } + + def incExecutionTimeBy(ruleName: String, delta: Long): Unit = { + timeMap.addAndGet(ruleName, delta) + } + + def incTimeEffectiveExecutionBy(ruleName: String, delta: Long): Unit = { + timeEffectiveRunsMap.addAndGet(ruleName, delta) + } + + def incNumEffectiveExecution(ruleName: String): Unit = { + numEffectiveRunsMap.incrementAndGet(ruleName) + } + + def incNumExecution(ruleName: String): Unit = { + numRunsMap.incrementAndGet(ruleName) + } + + /** Dump statistics about time spent running specific rules. */ + def dumpTimeSpent(): String = { + val map = timeMap.asMap().asScala + val maxLengthRuleNames = map.keys.map(_.toString.length).max + + val colRuleName = "Rule".padTo(maxLengthRuleNames, " ").mkString + val colRunTime = "Effective Time / Total Time".padTo(len = 47, " ").mkString + val colNumRuns = "Effective Runs / Total Runs".padTo(len = 47, " ").mkString + + val ruleMetrics = map.toSeq.sortBy(_._2).reverseMap { case (name, time) => + val timeEffectiveRun = timeEffectiveRunsMap.get(name) + val numRuns = numRunsMap.get(name) + val numEffectiveRun = numEffectiveRunsMap.get(name) + + val ruleName = name.padTo(maxLengthRuleNames, " ").mkString + val runtimeValue = s"$timeEffectiveRun / $time".padTo(len = 47, " ").mkString + val numRunValue = s"$numEffectiveRun / $numRuns".padTo(len = 47, " ").mkString + s"$ruleName $runtimeValue $numRunValue" + }.mkString("\n", "\n", "") + + s""" + |=== Metrics of Analyzer/Optimizer Rules === + |Total number of runs: $totalNumRuns + |Total time: ${totalTime / 1000000000D} seconds + | + |$colRuleName $colRunTime $colNumRuns + |$ruleMetrics + """.stripMargin + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 7e4b784033bfc..dccb44ddebfa4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.catalyst.rules -import scala.collection.JavaConverters._ - -import com.google.common.util.concurrent.AtomicLongMap - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees.TreeNode @@ -28,18 +24,16 @@ import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.util.Utils object RuleExecutor { - protected val timeMap = AtomicLongMap.create[String]() - - /** Resets statistics about time spent running specific rules */ - def resetTime(): Unit = timeMap.clear() + protected val queryExecutionMeter = QueryExecutionMetering() /** Dump statistics about time spent running specific rules. */ def dumpTimeSpent(): String = { - val map = timeMap.asMap().asScala - val maxSize = map.keys.map(_.toString.length).max - map.toSeq.sortBy(_._2).reverseMap { case (k, v) => - s"${k.padTo(maxSize, " ").mkString} $v" - }.mkString("\n", "\n", "") + queryExecutionMeter.dumpTimeSpent() + } + + /** Resets statistics about time spent running specific rules */ + def resetMetrics(): Unit = { + queryExecutionMeter.resetMetrics() } } @@ -77,6 +71,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { */ def execute(plan: TreeType): TreeType = { var curPlan = plan + val queryExecutionMetrics = RuleExecutor.queryExecutionMeter batches.foreach { batch => val batchStartPlan = curPlan @@ -91,15 +86,18 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { val startTime = System.nanoTime() val result = rule(plan) val runTime = System.nanoTime() - startTime - RuleExecutor.timeMap.addAndGet(rule.ruleName, runTime) if (!result.fastEquals(plan)) { + queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName) + queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime) logTrace( s""" |=== Applying Rule ${rule.ruleName} === |${sideBySide(plan.treeString, result.treeString).mkString("\n")} """.stripMargin) } + queryExecutionMetrics.incExecutionTimeBy(rule.ruleName, runTime) + queryExecutionMetrics.incNumExecution(rule.ruleName) // Run the structural integrity checker against the plan after each rule. if (!isPlanIntegral(result)) { @@ -135,9 +133,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (!batchStartPlan.fastEquals(curPlan)) { logDebug( s""" - |=== Result of Batch ${batch.name} === - |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")} - """.stripMargin) + |=== Result of Batch ${batch.name} === + |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")} + """.stripMargin) } else { logTrace(s"Batch ${batch.name} has no effect.") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala index 7037749f14478..e51aad021fcbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala @@ -46,7 +46,7 @@ abstract class BenchmarkQueryTest extends QueryTest with SharedSQLContext with B override def beforeAll() { super.beforeAll() - RuleExecutor.resetTime() + RuleExecutor.resetMetrics() } protected def checkGeneratedCode(plan: SparkPlan): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index e3901af4b9988..054ada56d99ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -291,7 +291,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) - RuleExecutor.resetTime() + RuleExecutor.resetMetrics() } override def afterAll(): Unit = { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 45791c69b4cb7..cebaad5b4ad9b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -62,7 +62,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") - RuleExecutor.resetTime() + RuleExecutor.resetMetrics() } override def afterAll() { From 896e45af5fea264683b1d7d20a1711f33908a06f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 22 Jan 2018 04:32:59 -0800 Subject: [PATCH 170/774] [MINOR][SQL][TEST] Test case cleanups for recent PRs ## What changes were proposed in this pull request? Revert the unneeded test case changes we made in SPARK-23000 Also fixes the test suites that do not call `super.afterAll()` in the local `afterAll`. The `afterAll()` of `TestHiveSingleton` actually reset the environments. ## How was this patch tested? N/A Author: gatorsmile Closes #20341 from gatorsmile/testRelated. --- .../apache/spark/sql/DataFrameJoinSuite.scala | 21 ++++++----- .../apache/spark/sql/hive/test/TestHive.scala | 3 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 26 +++++++------- .../sql/hive/execution/HiveUDAFSuite.scala | 8 +++-- .../hive/execution/Hive_2_1_DDLSuite.scala | 6 +++- .../execution/ObjectHashAggregateSuite.scala | 6 +++- .../apache/spark/sql/hive/parquetSuites.scala | 35 +++++++++++-------- 7 files changed, 60 insertions(+), 45 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 1656f290ee19c..0d9eeabb397a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class DataFrameJoinSuite extends QueryTest with SharedSQLContext { @@ -276,16 +277,14 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { test("SPARK-23087: don't throw Analysis Exception in CheckCartesianProduct when join condition " + "is false or null") { - val df = spark.range(10) - val dfNull = spark.range(10).select(lit(null).as("b")) - val planNull = df.join(dfNull, $"id" === $"b", "left").queryExecution.analyzed - - spark.sessionState.executePlan(planNull).optimizedPlan - - val dfOne = df.select(lit(1).as("a")) - val dfTwo = spark.range(10).select(lit(2).as("b")) - val planFalse = dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.analyzed - - spark.sessionState.executePlan(planFalse).optimizedPlan + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") { + val df = spark.range(10) + val dfNull = spark.range(10).select(lit(null).as("b")) + df.join(dfNull, $"id" === $"b", "left").queryExecution.optimizedPlan + + val dfOne = df.select(lit(1).as("a")) + val dfTwo = spark.range(10).select(lit(2).as("b")) + dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.optimizedPlan + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index c84131fc3212a..7287e20d55bbe 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -492,8 +492,7 @@ private[hive] class TestHiveSparkSession( protected val originalUDFs: JavaSet[String] = FunctionRegistry.getFunctionNames /** - * Resets the test instance by deleting any tables that have been created. - * TODO: also clear out UDFs, views, etc. + * Resets the test instance by deleting any table, view, temp view, and UDF that have been created */ def reset() { try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 83b4c862e2546..ba9b944e4a055 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -166,13 +166,13 @@ class DataSourceWithHiveMetastoreCatalogSuite )) ).foreach { case (provider, (inputFormat, outputFormat, serde)) => test(s"Persist non-partitioned $provider relation into metastore as managed table") { - withTable("default.t") { + withTable("t") { withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { testDF .write .mode(SaveMode.Overwrite) .format(provider) - .saveAsTable("default.t") + .saveAsTable("t") } val hiveTable = sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) @@ -187,15 +187,14 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) - checkAnswer(table("default.t"), testDF) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === - Seq("1.1\t1", "2.1\t2")) + checkAnswer(table("t"), testDF) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } test(s"Persist non-partitioned $provider relation into metastore as external table") { withTempPath { dir => - withTable("default.t") { + withTable("t") { val path = dir.getCanonicalFile withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { @@ -204,7 +203,7 @@ class DataSourceWithHiveMetastoreCatalogSuite .mode(SaveMode.Overwrite) .format(provider) .option("path", path.toString) - .saveAsTable("default.t") + .saveAsTable("t") } val hiveTable = @@ -220,8 +219,8 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) - checkAnswer(table("default.t"), testDF) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === + checkAnswer(table("t"), testDF) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -229,9 +228,9 @@ class DataSourceWithHiveMetastoreCatalogSuite test(s"Persist non-partitioned $provider relation into metastore as managed table using CTAS") { withTempPath { dir => - withTable("default.t") { + withTable("t") { sql( - s"""CREATE TABLE default.t USING $provider + s"""CREATE TABLE t USING $provider |OPTIONS (path '${dir.toURI}') |AS SELECT 1 AS d1, "val_1" AS d2 """.stripMargin) @@ -249,9 +248,8 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(IntegerType, StringType)) - checkAnswer(table("default.t"), Row(1, "val_1")) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === - Seq("1\tval_1")) + checkAnswer(table("t"), Row(1, "val_1")) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index 8986fb58c6460..7402c9626873c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -49,8 +49,12 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } protected override def afterAll(): Unit = { - sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock") - sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + try { + sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock") + sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + } finally { + super.afterAll() + } } test("built-in Hive UDAF") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala index bc828877e35ec..eaedac1fa95d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala @@ -74,7 +74,11 @@ class Hive_2_1_DDLSuite extends SparkFunSuite with TestHiveSingleton with Before } override def afterAll(): Unit = { - catalog = null + try { + catalog = null + } finally { + super.afterAll() + } } test("SPARK-21617: ALTER TABLE for non-compatible DataSource tables") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index 9eaf44c043c71..8dbcd24cd78de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -47,7 +47,11 @@ class ObjectHashAggregateSuite } protected override def afterAll(): Unit = { - sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + try { + sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + } finally { + super.afterAll() + } } test("typed_count without grouping keys") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 740e0837350cc..2327d83a1b4f6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -180,15 +180,18 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } override def afterAll(): Unit = { - dropTables("partitioned_parquet", - "partitioned_parquet_with_key", - "partitioned_parquet_with_complextypes", - "partitioned_parquet_with_key_and_complextypes", - "normal_parquet", - "jt", - "jt_array", - "test_parquet") - super.afterAll() + try { + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") + } finally { + super.afterAll() + } } test(s"conversion is working") { @@ -931,11 +934,15 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with } override protected def afterAll(): Unit = { - partitionedTableDir.delete() - normalTableDir.delete() - partitionedTableDirWithKey.delete() - partitionedTableDirWithComplexTypes.delete() - partitionedTableDirWithKeyAndComplexTypes.delete() + try { + partitionedTableDir.delete() + normalTableDir.delete() + partitionedTableDirWithKey.delete() + partitionedTableDirWithComplexTypes.delete() + partitionedTableDirWithKeyAndComplexTypes.delete() + } finally { + super.afterAll() + } } /** From 5d680cae486c77cdb12dbe9e043710e49e8d51e4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 22 Jan 2018 20:56:38 +0800 Subject: [PATCH 171/774] [SPARK-23090][SQL] polish ColumnVector ## What changes were proposed in this pull request? Several improvements: * provide a default implementation for the batch get methods * rename `getChildColumn` to `getChild`, which is more concise * remove `getStruct(int, int)`, it's only used to simplify the codegen, which is an internal thing, we should not add a public API for this purpose. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #20277 from cloud-fan/column-vector. --- .../expressions/codegen/CodeGenerator.scala | 18 ++-- .../datasources/orc/OrcColumnVector.java | 65 +----------- .../orc/OrcColumnarBatchReader.java | 23 ++--- .../vectorized/ColumnVectorUtils.java | 10 +- .../vectorized/MutableColumnarRow.java | 4 +- .../vectorized/WritableColumnVector.java | 10 +- .../sql/vectorized/ArrowColumnVector.java | 99 +------------------ .../spark/sql/vectorized/ColumnVector.java | 79 +++++++++++---- .../spark/sql/vectorized/ColumnarArray.java | 4 +- .../spark/sql/vectorized/ColumnarRow.java | 46 ++++----- .../sql/execution/ColumnarBatchScan.scala | 2 +- .../VectorizedHashMapGenerator.scala | 4 +- .../execution/arrow/ArrowWriterSuite.scala | 14 +-- .../vectorized/ArrowColumnVectorSuite.scala | 12 +-- .../vectorized/ColumnVectorSuite.scala | 12 +-- .../vectorized/ColumnarBatchBenchmark.scala | 38 +++---- .../vectorized/ColumnarBatchSuite.scala | 20 ++-- 17 files changed, 164 insertions(+), 296 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2c714c228e6c9..f96ed7628fda1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -688,17 +688,13 @@ class CodegenContext { /** * Returns the specialized code to access a value from a column vector for a given `DataType`. */ - def getValue(vector: String, rowId: String, dataType: DataType): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => - s"$vector.get${primitiveTypeName(jt)}($rowId)" - case t: DecimalType => - s"$vector.getDecimal($rowId, ${t.precision}, ${t.scale})" - case StringType => - s"$vector.getUTF8String($rowId)" - case _ => - throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") + def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { + if (dataType.isInstanceOf[StructType]) { + // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an + // `ordinal` parameter. + s"$vector.getStruct($rowId)" + } else { + getValue(vector, dataType, rowId) } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index b6e792274da11..aaf2a380034a9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -110,57 +110,21 @@ public boolean getBoolean(int rowId) { return longData.vector[getRowIndex(rowId)] == 1; } - @Override - public boolean[] getBooleans(int rowId, int count) { - boolean[] res = new boolean[count]; - for (int i = 0; i < count; i++) { - res[i] = getBoolean(rowId + i); - } - return res; - } - @Override public byte getByte(int rowId) { return (byte) longData.vector[getRowIndex(rowId)]; } - @Override - public byte[] getBytes(int rowId, int count) { - byte[] res = new byte[count]; - for (int i = 0; i < count; i++) { - res[i] = getByte(rowId + i); - } - return res; - } - @Override public short getShort(int rowId) { return (short) longData.vector[getRowIndex(rowId)]; } - @Override - public short[] getShorts(int rowId, int count) { - short[] res = new short[count]; - for (int i = 0; i < count; i++) { - res[i] = getShort(rowId + i); - } - return res; - } - @Override public int getInt(int rowId) { return (int) longData.vector[getRowIndex(rowId)]; } - @Override - public int[] getInts(int rowId, int count) { - int[] res = new int[count]; - for (int i = 0; i < count; i++) { - res[i] = getInt(rowId + i); - } - return res; - } - @Override public long getLong(int rowId) { int index = getRowIndex(rowId); @@ -171,43 +135,16 @@ public long getLong(int rowId) { } } - @Override - public long[] getLongs(int rowId, int count) { - long[] res = new long[count]; - for (int i = 0; i < count; i++) { - res[i] = getLong(rowId + i); - } - return res; - } - @Override public float getFloat(int rowId) { return (float) doubleData.vector[getRowIndex(rowId)]; } - @Override - public float[] getFloats(int rowId, int count) { - float[] res = new float[count]; - for (int i = 0; i < count; i++) { - res[i] = getFloat(rowId + i); - } - return res; - } - @Override public double getDouble(int rowId) { return doubleData.vector[getRowIndex(rowId)]; } - @Override - public double[] getDoubles(int rowId, int count) { - double[] res = new double[count]; - for (int i = 0; i < count; i++) { - res[i] = getDouble(rowId + i); - } - return res; - } - @Override public int getArrayLength(int rowId) { throw new UnsupportedOperationException(); @@ -245,7 +182,7 @@ public org.apache.spark.sql.vectorized.ColumnVector arrayData() { } @Override - public org.apache.spark.sql.vectorized.ColumnVector getChildColumn(int ordinal) { + public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { throw new UnsupportedOperationException(); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 89bae4326e93b..5e7cad470e1d1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -289,10 +289,9 @@ private void putRepeatingValues( toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]); } else if (type instanceof StringType || type instanceof BinaryType) { BytesColumnVector data = (BytesColumnVector)fromColumn; - WritableColumnVector arrayData = toColumn.getChildColumn(0); int size = data.vector[0].length; - arrayData.reserve(size); - arrayData.putBytes(0, size, data.vector[0], 0); + toColumn.arrayData().reserve(size); + toColumn.arrayData().putBytes(0, size, data.vector[0], 0); for (int index = 0; index < batchSize; index++) { toColumn.putArray(index, 0, size); } @@ -352,7 +351,7 @@ private void putNonNullValues( toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0); } else if (type instanceof StringType || type instanceof BinaryType) { BytesColumnVector data = ((BytesColumnVector)fromColumn); - WritableColumnVector arrayData = toColumn.getChildColumn(0); + WritableColumnVector arrayData = toColumn.arrayData(); int totalNumBytes = IntStream.of(data.length).sum(); arrayData.reserve(totalNumBytes); for (int index = 0, pos = 0; index < batchSize; pos += data.length[index], index++) { @@ -363,8 +362,7 @@ private void putNonNullValues( DecimalType decimalType = (DecimalType)type; DecimalColumnVector data = ((DecimalColumnVector)fromColumn); if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.reserve(batchSize * 16); + toColumn.arrayData().reserve(batchSize * 16); } for (int index = 0; index < batchSize; index++) { putDecimalWritable( @@ -459,7 +457,7 @@ private void putValues( } } else if (type instanceof StringType || type instanceof BinaryType) { BytesColumnVector vector = (BytesColumnVector)fromColumn; - WritableColumnVector arrayData = toColumn.getChildColumn(0); + WritableColumnVector arrayData = toColumn.arrayData(); int totalNumBytes = IntStream.of(vector.length).sum(); arrayData.reserve(totalNumBytes); for (int index = 0, pos = 0; index < batchSize; pos += vector.length[index], index++) { @@ -474,8 +472,7 @@ private void putValues( DecimalType decimalType = (DecimalType)type; HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector; if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.reserve(batchSize * 16); + toColumn.arrayData().reserve(batchSize * 16); } for (int index = 0; index < batchSize; index++) { if (fromColumn.isNull[index]) { @@ -521,8 +518,7 @@ private static void putDecimalWritable( toColumn.putLong(index, value.toUnscaledLong()); } else { byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.putBytes(index * 16, bytes.length, bytes, 0); + toColumn.arrayData().putBytes(index * 16, bytes.length, bytes, 0); toColumn.putArray(index, index * 16, bytes.length); } } @@ -547,9 +543,8 @@ private static void putDecimalWritables( toColumn.putLongs(0, size, value.toUnscaledLong()); } else { byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.reserve(bytes.length); - arrayData.putBytes(0, bytes.length, bytes, 0); + toColumn.arrayData().reserve(bytes.length); + toColumn.arrayData().putBytes(0, bytes.length, bytes, 0); for (int index = 0; index < size; index++) { toColumn.putArray(index, 0, bytes.length); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 5ee8cc8da2309..a2853bbadc92b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -85,8 +85,8 @@ public static void populate(WritableColumnVector col, InternalRow row, int field } } else if (t instanceof CalendarIntervalType) { CalendarInterval c = (CalendarInterval)row.get(fieldIdx, t); - col.getChildColumn(0).putInts(0, capacity, c.months); - col.getChildColumn(1).putLongs(0, capacity, c.microseconds); + col.getChild(0).putInts(0, capacity, c.months); + col.getChild(1).putLongs(0, capacity, c.microseconds); } else if (t instanceof DateType) { col.putInts(0, capacity, row.getInt(fieldIdx)); } else if (t instanceof TimestampType) { @@ -149,8 +149,8 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o) } else if (t instanceof CalendarIntervalType) { CalendarInterval c = (CalendarInterval)o; dst.appendStruct(false); - dst.getChildColumn(0).appendInt(c.months); - dst.getChildColumn(1).appendLong(c.microseconds); + dst.getChild(0).appendInt(c.months); + dst.getChild(1).appendLong(c.microseconds); } else if (t instanceof DateType) { dst.appendInt(DateTimeUtils.fromJavaDate((Date)o)); } else { @@ -179,7 +179,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Row src, i dst.appendStruct(false); Row c = src.getStruct(fieldIdx); for (int i = 0; i < st.fields().length; i++) { - appendValue(dst.getChildColumn(i), st.fields()[i].dataType(), c, i); + appendValue(dst.getChild(i), st.fields()[i].dataType(), c, i); } } } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 70057a9def6c0..2bab095d4d951 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -146,8 +146,8 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { if (columns[ordinal].isNullAt(rowId)) return null; - final int months = columns[ordinal].getChildColumn(0).getInt(rowId); - final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + final int months = columns[ordinal].getChild(0).getInt(rowId); + final long microseconds = columns[ordinal].getChild(1).getLong(rowId); return new CalendarInterval(months, microseconds); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index d2ae32b06f83b..ca4f00985c2a3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -599,17 +599,13 @@ public final int appendStruct(boolean isNull) { return elementsAppended; } - /** - * Returns the data for the underlying array. - */ + // `WritableColumnVector` puts the data of array in the first child column vector, and puts the + // array offsets and lengths in the current column vector. @Override public WritableColumnVector arrayData() { return childColumns[0]; } - /** - * Returns the ordinal's child data column. - */ @Override - public WritableColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + public WritableColumnVector getChild(int ordinal) { return childColumns[ordinal]; } /** * Returns the elements appended. diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index bfd1b4cb0ef12..ca7a4751450d4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -33,18 +33,6 @@ public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; private ArrowColumnVector[] childColumns; - private void ensureAccessible(int index) { - ensureAccessible(index, 1); - } - - private void ensureAccessible(int index, int count) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index + count > valueCount) { - throw new IndexOutOfBoundsException( - String.format("index range: [%d, %d), valueCount: %d", index, index + count, valueCount)); - } - } - @Override public int numNulls() { return accessor.getNullCount(); @@ -55,156 +43,75 @@ public void close() { if (childColumns != null) { for (int i = 0; i < childColumns.length; i++) { childColumns[i].close(); + childColumns[i] = null; } + childColumns = null; } accessor.close(); } @Override public boolean isNullAt(int rowId) { - ensureAccessible(rowId); return accessor.isNullAt(rowId); } @Override public boolean getBoolean(int rowId) { - ensureAccessible(rowId); return accessor.getBoolean(rowId); } - @Override - public boolean[] getBooleans(int rowId, int count) { - ensureAccessible(rowId, count); - boolean[] array = new boolean[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getBoolean(rowId + i); - } - return array; - } - @Override public byte getByte(int rowId) { - ensureAccessible(rowId); return accessor.getByte(rowId); } - @Override - public byte[] getBytes(int rowId, int count) { - ensureAccessible(rowId, count); - byte[] array = new byte[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getByte(rowId + i); - } - return array; - } - @Override public short getShort(int rowId) { - ensureAccessible(rowId); return accessor.getShort(rowId); } - @Override - public short[] getShorts(int rowId, int count) { - ensureAccessible(rowId, count); - short[] array = new short[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getShort(rowId + i); - } - return array; - } - @Override public int getInt(int rowId) { - ensureAccessible(rowId); return accessor.getInt(rowId); } - @Override - public int[] getInts(int rowId, int count) { - ensureAccessible(rowId, count); - int[] array = new int[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getInt(rowId + i); - } - return array; - } - @Override public long getLong(int rowId) { - ensureAccessible(rowId); return accessor.getLong(rowId); } - @Override - public long[] getLongs(int rowId, int count) { - ensureAccessible(rowId, count); - long[] array = new long[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getLong(rowId + i); - } - return array; - } - @Override public float getFloat(int rowId) { - ensureAccessible(rowId); return accessor.getFloat(rowId); } - @Override - public float[] getFloats(int rowId, int count) { - ensureAccessible(rowId, count); - float[] array = new float[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getFloat(rowId + i); - } - return array; - } - @Override public double getDouble(int rowId) { - ensureAccessible(rowId); return accessor.getDouble(rowId); } - @Override - public double[] getDoubles(int rowId, int count) { - ensureAccessible(rowId, count); - double[] array = new double[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getDouble(rowId + i); - } - return array; - } - @Override public int getArrayLength(int rowId) { - ensureAccessible(rowId); return accessor.getArrayLength(rowId); } @Override public int getArrayOffset(int rowId) { - ensureAccessible(rowId); return accessor.getArrayOffset(rowId); } @Override public Decimal getDecimal(int rowId, int precision, int scale) { - ensureAccessible(rowId); return accessor.getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int rowId) { - ensureAccessible(rowId); return accessor.getUTF8String(rowId); } @Override public byte[] getBinary(int rowId) { - ensureAccessible(rowId); return accessor.getBinary(rowId); } @@ -212,7 +119,7 @@ public byte[] getBinary(int rowId) { public ArrowColumnVector arrayData() { return childColumns[0]; } @Override - public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + public ArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } public ArrowColumnVector(ValueVector vector) { super(ArrowUtils.fromArrowField(vector.getField())); diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index d1196e1299fee..f9936214035b6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -51,12 +51,16 @@ public abstract class ColumnVector implements AutoCloseable { public final DataType dataType() { return type; } /** - * Cleans up memory for this column. The column is not usable after this. + * Cleans up memory for this column vector. The column vector is not usable after this. + * + * This overwrites `AutoCloseable.close` to remove the `throws` clause, as column vector is + * in-memory and we don't expect any exception to happen during closing. */ + @Override public abstract void close(); /** - * Returns the number of nulls in this column. + * Returns the number of nulls in this column vector. */ public abstract int numNulls(); @@ -73,7 +77,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract boolean[] getBooleans(int rowId, int count); + public boolean[] getBooleans(int rowId, int count) { + boolean[] res = new boolean[count]; + for (int i = 0; i < count; i++) { + res[i] = getBoolean(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -83,7 +93,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract byte[] getBytes(int rowId, int count); + public byte[] getBytes(int rowId, int count) { + byte[] res = new byte[count]; + for (int i = 0; i < count; i++) { + res[i] = getByte(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -93,7 +109,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract short[] getShorts(int rowId, int count); + public short[] getShorts(int rowId, int count) { + short[] res = new short[count]; + for (int i = 0; i < count; i++) { + res[i] = getShort(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -103,7 +125,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract int[] getInts(int rowId, int count); + public int[] getInts(int rowId, int count) { + int[] res = new int[count]; + for (int i = 0; i < count; i++) { + res[i] = getInt(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -113,7 +141,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract long[] getLongs(int rowId, int count); + public long[] getLongs(int rowId, int count) { + long[] res = new long[count]; + for (int i = 0; i < count; i++) { + res[i] = getLong(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -123,7 +157,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract float[] getFloats(int rowId, int count); + public float[] getFloats(int rowId, int count) { + float[] res = new float[count]; + for (int i = 0; i < count; i++) { + res[i] = getFloat(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -133,7 +173,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract double[] getDoubles(int rowId, int count); + public double[] getDoubles(int rowId, int count) { + double[] res = new double[count]; + for (int i = 0; i < count; i++) { + res[i] = getDouble(rowId + i); + } + return res; + } /** * Returns the length of the array for rowId. @@ -152,14 +198,6 @@ public final ColumnarRow getStruct(int rowId) { return new ColumnarRow(this, rowId); } - /** - * A special version of {@link #getStruct(int)}, which is only used as an adapter for Spark - * codegen framework, the second parameter is totally ignored. - */ - public final ColumnarRow getStruct(int rowId, int size) { - return getStruct(rowId); - } - /** * Returns the array for rowId. */ @@ -196,9 +234,9 @@ public MapData getMap(int ordinal) { public abstract ColumnVector arrayData(); /** - * Returns the ordinal's child data column. + * Returns the ordinal's child column vector. */ - public abstract ColumnVector getChildColumn(int ordinal); + public abstract ColumnVector getChild(int ordinal); /** * Data type for this column. @@ -206,8 +244,7 @@ public MapData getMap(int ordinal) { protected DataType type; /** - * Sets up the common state and also handles creating the child columns if this is a nested - * type. + * Sets up the data type of this column vector. */ protected ColumnVector(DataType type) { this.type = type; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 0d89a52e7a4fe..522c39580389f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -133,8 +133,8 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { - int month = data.getChildColumn(0).getInt(offset + ordinal); - long microseconds = data.getChildColumn(1).getLong(offset + ordinal); + int month = data.getChild(0).getInt(offset + ordinal); + long microseconds = data.getChild(1).getLong(offset + ordinal); return new CalendarInterval(month, microseconds); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 3c6656dec77cd..2e59085a82768 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -28,7 +28,7 @@ */ public final class ColumnarRow extends InternalRow { // The data for this row. - // E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`. + // E.g. the value of 3rd int field is `data.getChild(3).getInt(rowId)`. private final ColumnVector data; private final int rowId; private final int numFields; @@ -53,7 +53,7 @@ public InternalRow copy() { if (isNullAt(i)) { row.setNullAt(i); } else { - DataType dt = data.getChildColumn(i).dataType(); + DataType dt = data.getChild(i).dataType(); if (dt instanceof BooleanType) { row.setBoolean(i, getBoolean(i)); } else if (dt instanceof ByteType) { @@ -93,65 +93,65 @@ public boolean anyNull() { } @Override - public boolean isNullAt(int ordinal) { return data.getChildColumn(ordinal).isNullAt(rowId); } + public boolean isNullAt(int ordinal) { return data.getChild(ordinal).isNullAt(rowId); } @Override - public boolean getBoolean(int ordinal) { return data.getChildColumn(ordinal).getBoolean(rowId); } + public boolean getBoolean(int ordinal) { return data.getChild(ordinal).getBoolean(rowId); } @Override - public byte getByte(int ordinal) { return data.getChildColumn(ordinal).getByte(rowId); } + public byte getByte(int ordinal) { return data.getChild(ordinal).getByte(rowId); } @Override - public short getShort(int ordinal) { return data.getChildColumn(ordinal).getShort(rowId); } + public short getShort(int ordinal) { return data.getChild(ordinal).getShort(rowId); } @Override - public int getInt(int ordinal) { return data.getChildColumn(ordinal).getInt(rowId); } + public int getInt(int ordinal) { return data.getChild(ordinal).getInt(rowId); } @Override - public long getLong(int ordinal) { return data.getChildColumn(ordinal).getLong(rowId); } + public long getLong(int ordinal) { return data.getChild(ordinal).getLong(rowId); } @Override - public float getFloat(int ordinal) { return data.getChildColumn(ordinal).getFloat(rowId); } + public float getFloat(int ordinal) { return data.getChild(ordinal).getFloat(rowId); } @Override - public double getDouble(int ordinal) { return data.getChildColumn(ordinal).getDouble(rowId); } + public double getDouble(int ordinal) { return data.getChild(ordinal).getDouble(rowId); } @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getDecimal(rowId, precision, scale); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getUTF8String(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getBinary(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - final int months = data.getChildColumn(ordinal).getChildColumn(0).getInt(rowId); - final long microseconds = data.getChildColumn(ordinal).getChildColumn(1).getLong(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + final int months = data.getChild(ordinal).getChild(0).getInt(rowId); + final long microseconds = data.getChild(ordinal).getChild(1).getLong(rowId); return new CalendarInterval(months, microseconds); } @Override public ColumnarRow getStruct(int ordinal, int numFields) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getStruct(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getStruct(rowId); } @Override public ColumnarArray getArray(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getArray(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getArray(rowId); } @Override diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index dd68df9686691..04f2619ed7541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -50,7 +50,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { dataType: DataType, nullable: Boolean): ExprCode = { val javaType = ctx.javaType(dataType) - val value = ctx.getValue(columnVar, dataType, ordinal) + val value = ctx.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index eb48584d0c1ee..633eeac180974 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -127,8 +127,8 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"vectors[$ordinal]", "buckets[idx]", - key.dataType), key.name)})""" + val value = ctx.getValueFromVector(s"vectors[$ordinal]", key.dataType, "buckets[idx]") + s"(${ctx.genEqual(key.dataType, value, key.name)})" }.mkString(" && ") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index c42bc60a59d67..92506032ab2e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -217,21 +217,21 @@ class ArrowWriterSuite extends SparkFunSuite { val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - val struct0 = reader.getStruct(0, 2) + val struct0 = reader.getStruct(0) assert(struct0.getInt(0) === 1) assert(struct0.getUTF8String(1) === UTF8String.fromString("str1")) - val struct1 = reader.getStruct(1, 2) + val struct1 = reader.getStruct(1) assert(struct1.isNullAt(0)) assert(struct1.isNullAt(1)) assert(reader.isNullAt(2)) - val struct3 = reader.getStruct(3, 2) + val struct3 = reader.getStruct(3) assert(struct3.getInt(0) === 4) assert(struct3.isNullAt(1)) - val struct4 = reader.getStruct(4, 2) + val struct4 = reader.getStruct(4) assert(struct4.isNullAt(0)) assert(struct4.getUTF8String(1) === UTF8String.fromString("str5")) @@ -252,15 +252,15 @@ class ArrowWriterSuite extends SparkFunSuite { val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - val struct00 = reader.getStruct(0, 1).getStruct(0, 2) + val struct00 = reader.getStruct(0).getStruct(0, 2) assert(struct00.getInt(0) === 1) assert(struct00.getUTF8String(1) === UTF8String.fromString("str1")) - val struct10 = reader.getStruct(1, 1).getStruct(0, 2) + val struct10 = reader.getStruct(1).getStruct(0, 2) assert(struct10.isNullAt(0)) assert(struct10.isNullAt(1)) - val struct2 = reader.getStruct(2, 1) + val struct2 = reader.getStruct(2) assert(struct2.isNullAt(0)) assert(reader.isNullAt(3)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 53432669e215d..e794f50781ff2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -346,11 +346,11 @@ class ArrowColumnVectorSuite extends SparkFunSuite { assert(columnVector.dataType === schema) assert(columnVector.numNulls === 0) - val row0 = columnVector.getStruct(0, 2) + val row0 = columnVector.getStruct(0) assert(row0.getInt(0) === 1) assert(row0.getLong(1) === 1L) - val row1 = columnVector.getStruct(1, 2) + val row1 = columnVector.getStruct(1) assert(row1.getInt(0) === 2) assert(row1.isNullAt(1)) @@ -398,21 +398,21 @@ class ArrowColumnVectorSuite extends SparkFunSuite { assert(columnVector.dataType === schema) assert(columnVector.numNulls === 1) - val row0 = columnVector.getStruct(0, 2) + val row0 = columnVector.getStruct(0) assert(row0.getInt(0) === 1) assert(row0.getLong(1) === 1L) - val row1 = columnVector.getStruct(1, 2) + val row1 = columnVector.getStruct(1) assert(row1.getInt(0) === 2) assert(row1.isNullAt(1)) - val row2 = columnVector.getStruct(2, 2) + val row2 = columnVector.getStruct(2) assert(row2.isNullAt(0)) assert(row2.getLong(1) === 3L) assert(columnVector.isNullAt(3)) - val row4 = columnVector.getStruct(4, 2) + val row4 = columnVector.getStruct(4) assert(row4.getInt(0) === 5) assert(row4.getLong(1) === 5L) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 944240f3bade5..2d1ad4b456783 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -199,17 +199,17 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType) testVectors("struct", 10, structType) { testVector => - val c1 = testVector.getChildColumn(0) - val c2 = testVector.getChildColumn(1) + val c1 = testVector.getChild(0) + val c2 = testVector.getChild(1) c1.putInt(0, 123) c2.putDouble(0, 3.45) c1.putInt(1, 456) c2.putDouble(1, 5.67) - assert(testVector.getStruct(0, structType.length).get(0, IntegerType) === 123) - assert(testVector.getStruct(0, structType.length).get(1, DoubleType) === 3.45) - assert(testVector.getStruct(1, structType.length).get(0, IntegerType) === 456) - assert(testVector.getStruct(1, structType.length).get(1, DoubleType) === 5.67) + assert(testVector.getStruct(0).get(0, IntegerType) === 123) + assert(testVector.getStruct(0).get(1, DoubleType) === 3.45) + assert(testVector.getStruct(1).get(0, IntegerType) === 456) + assert(testVector.getStruct(1).get(1, DoubleType) === 5.67) } test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 38ea2e47fdef8..ad74fb99b0c73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -268,17 +268,17 @@ object ColumnarBatchBenchmark { Int Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Java Array 177 / 181 1856.4 0.5 1.0X - ByteBuffer Unsafe 318 / 322 1032.0 1.0 0.6X - ByteBuffer API 1411 / 1418 232.2 4.3 0.1X - DirectByteBuffer 467 / 474 701.8 1.4 0.4X - Unsafe Buffer 178 / 185 1843.6 0.5 1.0X - Column(on heap) 178 / 184 1840.8 0.5 1.0X - Column(off heap) 341 / 344 961.8 1.0 0.5X - Column(off heap direct) 178 / 184 1845.4 0.5 1.0X - UnsafeRow (on heap) 378 / 389 866.3 1.2 0.5X - UnsafeRow (off heap) 393 / 402 834.0 1.2 0.4X - Column On Heap Append 309 / 318 1059.1 0.9 0.6X + Java Array 177 / 183 1851.1 0.5 1.0X + ByteBuffer Unsafe 314 / 330 1043.7 1.0 0.6X + ByteBuffer API 1298 / 1307 252.4 4.0 0.1X + DirectByteBuffer 465 / 483 704.2 1.4 0.4X + Unsafe Buffer 179 / 183 1835.5 0.5 1.0X + Column(on heap) 181 / 186 1815.2 0.6 1.0X + Column(off heap) 344 / 349 951.7 1.1 0.5X + Column(off heap direct) 178 / 186 1838.6 0.5 1.0X + UnsafeRow (on heap) 388 / 394 844.8 1.2 0.5X + UnsafeRow (off heap) 400 / 403 819.4 1.2 0.4X + Column On Heap Append 315 / 325 1041.8 1.0 0.6X */ val benchmark = new Benchmark("Int Read/Write", count * iters) benchmark.addCase("Java Array")(javaArray) @@ -337,8 +337,8 @@ object ColumnarBatchBenchmark { Boolean Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Bitset 726 / 727 462.4 2.2 1.0X - Byte Array 530 / 542 632.7 1.6 1.4X + Bitset 741 / 747 452.6 2.2 1.0X + Byte Array 531 / 542 631.6 1.6 1.4X */ benchmark.run() } @@ -394,8 +394,8 @@ object ColumnarBatchBenchmark { String Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - On Heap 332 / 338 49.3 20.3 1.0X - Off Heap 466 / 467 35.2 28.4 0.7X + On Heap 351 / 362 46.6 21.4 1.0X + Off Heap 456 / 466 35.9 27.8 0.8X */ val benchmark = new Benchmark("String Read/Write", count * iters) benchmark.addCase("On Heap")(column(MemoryMode.ON_HEAP)) @@ -479,10 +479,10 @@ object ColumnarBatchBenchmark { Array Vector Read: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - On Heap Read Size Only 415 / 422 394.7 2.5 1.0X - Off Heap Read Size Only 394 / 402 415.9 2.4 1.1X - On Heap Read Elements 2558 / 2593 64.0 15.6 0.2X - Off Heap Read Elements 3316 / 3317 49.4 20.2 0.1X + On Heap Read Size Only 416 / 423 393.5 2.5 1.0X + Off Heap Read Size Only 396 / 404 413.6 2.4 1.1X + On Heap Read Elements 2569 / 2590 63.8 15.7 0.2X + Off Heap Read Elements 3302 / 3333 49.6 20.2 0.1X */ benchmark.run } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index cd90681ecabc6..1873c24ab063c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -732,8 +732,8 @@ class ColumnarBatchSuite extends SparkFunSuite { "Struct Column", 10, new StructType().add("int", IntegerType).add("double", DoubleType)) { column => - val c1 = column.getChildColumn(0) - val c2 = column.getChildColumn(1) + val c1 = column.getChild(0) + val c2 = column.getChild(1) assert(c1.dataType() == IntegerType) assert(c2.dataType() == DoubleType) @@ -787,8 +787,8 @@ class ColumnarBatchSuite extends SparkFunSuite { 10, new ArrayType(structType, true)) { column => val data = column.arrayData() - val c0 = data.getChildColumn(0) - val c1 = data.getChildColumn(1) + val c0 = data.getChild(0) + val c1 = data.getChild(1) // Structs in child column: (0, 0), (1, 10), (2, 20), (3, 30), (4, 40), (5, 50) (0 until 6).foreach { i => c0.putInt(i, i) @@ -815,8 +815,8 @@ class ColumnarBatchSuite extends SparkFunSuite { new StructType() .add("int", IntegerType) .add("array", new ArrayType(IntegerType, true))) { column => - val c0 = column.getChildColumn(0) - val c1 = column.getChildColumn(1) + val c0 = column.getChild(0) + val c1 = column.getChild(1) c0.putInt(0, 0) c0.putInt(1, 1) c0.putInt(2, 2) @@ -844,13 +844,13 @@ class ColumnarBatchSuite extends SparkFunSuite { "Nest Struct in Struct", 10, new StructType().add("int", IntegerType).add("struct", subSchema)) { column => - val c0 = column.getChildColumn(0) - val c1 = column.getChildColumn(1) + val c0 = column.getChild(0) + val c1 = column.getChild(1) c0.putInt(0, 0) c0.putInt(1, 1) c0.putInt(2, 2) - val c1c0 = c1.getChildColumn(0) - val c1c1 = c1.getChildColumn(1) + val c1c0 = c1.getChild(0) + val c1c1 = c1.getChild(1) // Structs in c1: (7, 70), (8, 80), (9, 90) c1c0.putInt(0, 7) c1c0.putInt(1, 8) From 87ffe7adddf517541aac0d1e8536b02ad8881606 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 22 Jan 2018 22:12:50 +0900 Subject: [PATCH 172/774] [SPARK-7721][PYTHON][TESTS] Adds PySpark coverage generation script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Note that this PR was made based on the top of https://github.com/apache/spark/pull/20151. So, it almost leaves the main codes intact. This PR proposes to add a script for the preparation of automatic PySpark coverage generation. Now, it's difficult to check the actual coverage in case of PySpark. With this script, it allows to run tests by the way we did via `run-tests` script before. The usage is exactly the same with `run-tests` script as this basically wraps it. This script and PR alone should also be useful. I was asked about how to run this before, and seems some reviewers (including me) need this. It would be also useful to run it manually. It usually requires a small diff in normal Python projects but PySpark cases are a bit different because apparently we are unable to track the coverage after it's forked. So, here, I made a custom worker that forces the coverage, based on the top of https://github.com/apache/spark/pull/20151. I made a simple demo. Please take a look - https://spark-test.github.io/pyspark-coverage-site. To show up the structure, this PR adds the files as below: ``` python ├── .coveragerc # Runtime configuration when we run the script. ├── run-tests-with-coverage # The script that has coverage support and wraps run-tests script. └── test_coverage # Directories that have files required when running coverage. ├── conf │   └── spark-defaults.conf # Having the configuration 'spark.python.daemon.module'. ├── coverage_daemon.py # A daemon having custom fix and wrapping our daemon.py └── sitecustomize.py # Initiate coverage with COVERAGE_PROCESS_START ``` Note that this PR has a minor nit: [This scope](https://github.com/apache/spark/blob/04e44b37cc04f62fbf9e08c7076349e0a4d12ea8/python/pyspark/daemon.py#L148-L169) in `daemon.py` is not in the coverage results as basically I am producing the coverage results in `worker.py` separately and then merging it. I believe it's not a big deal. In a followup, I might have a site that has a single up-to-date PySpark coverage from the master branch as the fallback / default, or have a site that has multiple PySpark coverages and the site link will be left to each pull request. ## How was this patch tested? Manually tested. Usage is the same with the existing Python test script - `./python/run-tests`. For example, ``` sh run-tests-with-coverage --python-executables=python3 --modules=pyspark-sql ``` Running this will generate HTMLs under `./python/test_coverage/htmlcov`. Console output example: ``` sh run-tests-with-coverage --python-executables=python3,python --modules=pyspark-core Running PySpark tests. Output is in /.../spark/python/unit-tests.log Will test against the following Python executables: ['python3', 'python'] Will test the following Python modules: ['pyspark-core'] Starting test(python): pyspark.tests Starting test(python3): pyspark.tests ... Tests passed in 231 seconds Combining collected coverage data under /.../spark/python/test_coverage/coverage_data Reporting the coverage data at /...spark/python/test_coverage/coverage_data/coverage Name Stmts Miss Branch BrPart Cover -------------------------------------------------------------- pyspark/__init__.py 41 0 8 2 96% ... pyspark/profiler.py 74 11 22 5 83% pyspark/rdd.py 871 40 303 32 93% pyspark/rddsampler.py 68 10 32 2 82% ... -------------------------------------------------------------- TOTAL 8521 3077 2748 191 59% Generating HTML files for PySpark coverage under /.../spark/python/test_coverage/htmlcov ``` Author: hyukjinkwon Closes #20204 from HyukjinKwon/python-coverage. --- .gitignore | 2 + python/.coveragerc | 21 ++++++ python/run-tests-with-coverage | 69 +++++++++++++++++++ python/run-tests.py | 5 +- python/test_coverage/conf/spark-defaults.conf | 21 ++++++ python/test_coverage/coverage_daemon.py | 45 ++++++++++++ python/test_coverage/sitecustomize.py | 23 +++++++ 7 files changed, 185 insertions(+), 1 deletion(-) create mode 100644 python/.coveragerc create mode 100755 python/run-tests-with-coverage create mode 100644 python/test_coverage/conf/spark-defaults.conf create mode 100644 python/test_coverage/coverage_daemon.py create mode 100644 python/test_coverage/sitecustomize.py diff --git a/.gitignore b/.gitignore index 903297db96901..39085904e324c 100644 --- a/.gitignore +++ b/.gitignore @@ -62,6 +62,8 @@ project/plugins/src_managed/ project/plugins/target/ python/lib/pyspark.zip python/deps +python/test_coverage/coverage_data +python/test_coverage/htmlcov python/pyspark/python reports/ scalastyle-on-compile.generated.xml diff --git a/python/.coveragerc b/python/.coveragerc new file mode 100644 index 0000000000000..b3339cd356a6e --- /dev/null +++ b/python/.coveragerc @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +[run] +branch = true +parallel = true +data_file = ${COVERAGE_DIR}/coverage_data/coverage diff --git a/python/run-tests-with-coverage b/python/run-tests-with-coverage new file mode 100755 index 0000000000000..6d74b563e9140 --- /dev/null +++ b/python/run-tests-with-coverage @@ -0,0 +1,69 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -o pipefail +set -e + +# This variable indicates which coverage executable to run to combine coverages +# and generate HTMLs, for example, 'coverage3' in Python 3. +COV_EXEC="${COV_EXEC:-coverage}" +FWDIR="$(cd "`dirname $0`"; pwd)" +pushd "$FWDIR" > /dev/null + +# Ensure that coverage executable is installed. +if ! hash $COV_EXEC 2>/dev/null; then + echo "Missing coverage executable in your path, skipping PySpark coverage" + exit 1 +fi + +# Set up the directories for coverage results. +export COVERAGE_DIR="$FWDIR/test_coverage" +rm -fr "$COVERAGE_DIR/coverage_data" +rm -fr "$COVERAGE_DIR/htmlcov" +mkdir -p "$COVERAGE_DIR/coverage_data" + +# Current directory are added in the python path so that it doesn't refer our built +# pyspark zip library first. +export PYTHONPATH="$FWDIR:$PYTHONPATH" +# Also, our sitecustomize.py and coverage_daemon.py are included in the path. +export PYTHONPATH="$COVERAGE_DIR:$PYTHONPATH" + +# We use 'spark.python.daemon.module' configuration to insert the coverage supported workers. +export SPARK_CONF_DIR="$COVERAGE_DIR/conf" + +# This environment variable enables the coverage. +export COVERAGE_PROCESS_START="$FWDIR/.coveragerc" + +# If you'd like to run a specific unittest class, you could do such as +# SPARK_TESTING=1 ../bin/pyspark pyspark.sql.tests VectorizedUDFTests +./run-tests "$@" + +# Don't run coverage for the coverage command itself +unset COVERAGE_PROCESS_START + +# Coverage could generate empty coverage data files. Remove it to get rid of warnings when combining. +find $COVERAGE_DIR/coverage_data -size 0 -print0 | xargs -0 rm +echo "Combining collected coverage data under $COVERAGE_DIR/coverage_data" +$COV_EXEC combine +echo "Reporting the coverage data at $COVERAGE_DIR/coverage_data/coverage" +$COV_EXEC report --include "pyspark/*" +echo "Generating HTML files for PySpark coverage under $COVERAGE_DIR/htmlcov" +$COV_EXEC html --ignore-errors --include "pyspark/*" --directory "$COVERAGE_DIR/htmlcov" + +popd diff --git a/python/run-tests.py b/python/run-tests.py index 1341086f02db0..f03284c334285 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -38,7 +38,7 @@ from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings) -from sparktestsupport.shellutils import which, subprocess_check_output # noqa +from sparktestsupport.shellutils import which, subprocess_check_output, run_cmd # noqa from sparktestsupport.modules import all_modules # noqa @@ -175,6 +175,9 @@ def main(): task_queue = Queue.PriorityQueue() for python_exec in python_execs: + if "COVERAGE_PROCESS_START" in os.environ: + # Make sure if coverage is installed. + run_cmd([python_exec, "-c", "import coverage"]) python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], universal_newlines=True).strip() diff --git a/python/test_coverage/conf/spark-defaults.conf b/python/test_coverage/conf/spark-defaults.conf new file mode 100644 index 0000000000000..bf44ea6e7cfec --- /dev/null +++ b/python/test_coverage/conf/spark-defaults.conf @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This is used to generate PySpark coverage results. Seems there's no way to +# add a configuration when SPARK_TESTING environment variable is set because +# we will directly execute modules by python -m. +spark.python.daemon.module coverage_daemon diff --git a/python/test_coverage/coverage_daemon.py b/python/test_coverage/coverage_daemon.py new file mode 100644 index 0000000000000..c87366a1ac23b --- /dev/null +++ b/python/test_coverage/coverage_daemon.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import imp + + +# This is a hack to always refer the main code rather than built zip. +main_code_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +daemon = imp.load_source("daemon", "%s/pyspark/daemon.py" % main_code_dir) + +if "COVERAGE_PROCESS_START" in os.environ: + worker = imp.load_source("worker", "%s/pyspark/worker.py" % main_code_dir) + + def _cov_wrapped(*args, **kwargs): + import coverage + cov = coverage.coverage( + config_file=os.environ["COVERAGE_PROCESS_START"]) + cov.start() + try: + worker.main(*args, **kwargs) + finally: + cov.stop() + cov.save() + daemon.worker_main = _cov_wrapped +else: + raise RuntimeError("COVERAGE_PROCESS_START environment variable is not set, exiting.") + + +if __name__ == '__main__': + daemon.manager() diff --git a/python/test_coverage/sitecustomize.py b/python/test_coverage/sitecustomize.py new file mode 100644 index 0000000000000..630237a518126 --- /dev/null +++ b/python/test_coverage/sitecustomize.py @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Note that this 'sitecustomize' module is a built-in feature in Python. +# If this module is defined, it's executed when the Python session begins. +# `coverage.process_startup()` seeks if COVERAGE_PROCESS_START environment +# variable is set or not. If set, it starts to run the coverage. +import coverage +coverage.process_startup() From 4327ccf289b5a0dc51f6294113d01af6eb52eea0 Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Mon, 22 Jan 2018 08:36:17 -0600 Subject: [PATCH 173/774] [SPARK-11630][CORE] ClosureCleaner moved from warning to debug ## What changes were proposed in this pull request? ClosureCleaner moved from warning to debug ## How was this patch tested? Existing tests Author: Rekha Joshi Author: rjoshi2 Closes #20337 from rekhajoshm/SPARK-11630-1. --- core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 40616421b5bca..ad0c0639521f6 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -207,7 +207,7 @@ private[spark] object ClosureCleaner extends Logging { accessedFields: Map[Class[_], Set[String]]): Unit = { if (!isClosure(func.getClass)) { - logWarning("Expected a closure; got " + func.getClass.getName) + logDebug(s"Expected a closure; got ${func.getClass.getName}") return } From 446948af1d8dbc080a26a6eec6f743d338f1d12b Mon Sep 17 00:00:00 2001 From: Sandor Murakozi Date: Mon, 22 Jan 2018 10:36:28 -0800 Subject: [PATCH 174/774] [SPARK-23121][CORE] Fix for ui becoming unaccessible for long running streaming apps ## What changes were proposed in this pull request? The allJobs and the job pages attempt to use stage attempt and DAG visualization from the store, but for long running jobs they are not guaranteed to be retained, leading to exceptions when these pages are rendered. To fix it `store.lastStageAttempt(stageId)` and `store.operationGraphForJob(jobId)` are wrapped in `store.asOption` and default values are used if the info is missing. ## How was this patch tested? Manual testing of the UI, also using the test command reported in SPARK-23121: ./bin/spark-submit --class org.apache.spark.examples.streaming.HdfsWordCount ./examples/jars/spark-examples_2.11-2.4.0-SNAPSHOT.jar /spark Closes #20287 Author: Sandor Murakozi Closes #20330 from smurakozi/SPARK-23121. --- .../apache/spark/ui/jobs/AllJobsPage.scala | 24 ++++++++++--------- .../org/apache/spark/ui/jobs/JobPage.scala | 10 ++++++-- .../org/apache/spark/ui/jobs/StagePage.scala | 9 ++++--- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index e3b72f1f34859..2b0f4acbac72a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -36,6 +36,9 @@ import org.apache.spark.util.Utils /** Page showing list of all ongoing and recently finished jobs */ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends WebUIPage("") { + + import ApiHelper._ + private val JOBS_LEGEND =
val jobId = job.jobId val status = job.status - val jobDescription = store.lastStageAttempt(job.stageIds.max).description - val displayJobDescription = jobDescription - .map(UIUtils.makeDescription(_, "", plainText = true).text) - .getOrElse("") + val (_, lastStageDescription) = lastStageNameAndDescription(store, job) + val jobDescription = UIUtils.makeDescription(lastStageDescription, "", plainText = true).text + val submissionTime = job.submissionTime.get.getTime() val completionTime = job.completionTime.map(_.getTime()).getOrElse(System.currentTimeMillis()) val classNameByStatus = status match { @@ -80,7 +82,7 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We // The timeline library treats contents as HTML, so we have to escape them. We need to add // extra layers of escaping in order to embed this in a Javascript string literal. - val escapedDesc = Utility.escape(displayJobDescription) + val escapedDesc = Utility.escape(jobDescription) val jsEscapedDesc = StringEscapeUtils.escapeEcmaScript(escapedDesc) val jobEventJsonAsStr = s""" @@ -430,6 +432,8 @@ private[ui] class JobDataSource( sortColumn: String, desc: Boolean) extends PagedDataSource[JobTableRowData](pageSize) { + import ApiHelper._ + // Convert JobUIData to JobTableRowData which contains the final contents to show in the table // so that we can avoid creating duplicate contents during sorting the data private val data = jobs.map(jobRow).sorted(ordering(sortColumn, desc)) @@ -454,23 +458,21 @@ private[ui] class JobDataSource( val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val submissionTime = jobData.submissionTime val formattedSubmissionTime = submissionTime.map(UIUtils.formatDate).getOrElse("Unknown") - val lastStageAttempt = store.lastStageAttempt(jobData.stageIds.max) - val lastStageDescription = lastStageAttempt.description.getOrElse("") + val (lastStageName, lastStageDescription) = lastStageNameAndDescription(store, jobData) - val formattedJobDescription = - UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) + val jobDescription = UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) val detailUrl = "%s/jobs/job?id=%s".format(basePath, jobData.jobId) new JobTableRowData( jobData, - lastStageAttempt.name, + lastStageName, lastStageDescription, duration.getOrElse(-1), formattedDuration, submissionTime.map(_.getTime()).getOrElse(-1L), formattedSubmissionTime, - formattedJobDescription, + jobDescription, detailUrl ) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index c27f30c21a843..46f2a76cc651b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -336,8 +336,14 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP content ++= makeTimeline(activeStages ++ completedStages ++ failedStages, store.executorList(false), appStartTime) - content ++= UIUtils.showDagVizForJob( - jobId, store.operationGraphForJob(jobId)) + val operationGraphContent = store.asOption(store.operationGraphForJob(jobId)) match { + case Some(operationGraph) => UIUtils.showDagVizForJob(jobId, operationGraph) + case None => +
+

No DAG visualization information to display for job {jobId}

+
+ } + content ++= operationGraphContent if (shouldShowActiveStages) { content ++= diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 0eb3190205c3e..5c2b0c3a19996 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -23,12 +23,10 @@ import java.util.concurrent.TimeUnit import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{HashMap, HashSet} -import scala.xml.{Elem, Node, Unparsed} +import scala.xml.{Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.SparkConf -import org.apache.spark.internal.config._ import org.apache.spark.scheduler.TaskLocality import org.apache.spark.status._ import org.apache.spark.status.api.v1._ @@ -1020,4 +1018,9 @@ private object ApiHelper { } } + def lastStageNameAndDescription(store: AppStatusStore, job: JobData): (String, String) = { + val stage = store.asOption(store.lastStageAttempt(job.stageIds.max)) + (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) + } + } From 76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 22 Jan 2018 13:55:14 -0600 Subject: [PATCH 175/774] [MINOR] Typo fixes ## What changes were proposed in this pull request? Typo fixes ## How was this patch tested? Local build / Doc-only changes Author: Jacek Laskowski Closes #20344 from jaceklaskowski/typo-fixes. --- .../main/scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/sql/kafka010/KafkaSourceProvider.scala | 4 ++-- .../apache/spark/sql/kafka010/KafkaWriteTask.scala | 2 +- .../org/apache/spark/sql/streaming/OutputMode.java | 2 +- .../spark/sql/catalyst/analysis/Analyzer.scala | 8 ++++---- .../spark/sql/catalyst/analysis/unresolved.scala | 2 +- .../catalyst/expressions/aggregate/interfaces.scala | 12 +++++------- .../catalyst/plans/logical/LogicalPlanVisitor.scala | 2 +- .../statsEstimation/BasicStatsPlanVisitor.scala | 2 +- .../SizeInBytesOnlyStatsPlanVisitor.scala | 4 ++-- .../org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../apache/spark/sql/catalyst/plans/PlanTest.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameWriter.scala | 2 +- .../apache/spark/sql/execution/SparkSqlParser.scala | 2 +- .../spark/sql/execution/WholeStageCodegenExec.scala | 2 +- .../spark/sql/execution/command/SetCommand.scala | 4 ++-- .../spark/sql/execution/datasources/rules.scala | 2 +- .../sql/execution/streaming/HDFSMetadataLog.scala | 2 +- .../spark/sql/execution/streaming/OffsetSeq.scala | 2 +- .../spark/sql/execution/streaming/OffsetSeqLog.scala | 2 +- .../execution/streaming/StreamingQueryWrapper.scala | 2 +- .../sql/execution/streaming/state/StateStore.scala | 2 +- .../spark/sql/execution/ui/ExecutionPage.scala | 2 +- .../spark/sql/expressions/UserDefinedFunction.scala | 4 ++-- .../spark/sql/internal/BaseSessionStateBuilder.scala | 4 ++-- .../spark/sql/streaming/DataStreamReader.scala | 6 +++--- .../results/columnresolution-negative.sql.out | 2 +- .../sql-tests/results/columnresolution-views.sql.out | 2 +- .../sql-tests/results/columnresolution.sql.out | 6 +++--- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 ++-- .../apache/spark/sql/execution/SQLViewSuite.scala | 2 +- .../apache/spark/sql/hive/HiveExternalCatalog.scala | 4 ++-- 32 files changed, 50 insertions(+), 52 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 31f3cb9dfa0ae..3828d4f703247 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2276,7 +2276,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Clean a closure to make it ready to be serialized and send to tasks + * Clean a closure to make it ready to be serialized and sent to tasks * (removes unreferenced variables in $outer's, updates REPL variables) * If checkSerializable is set, clean will also proactively * check to see if f is serializable and throw a SparkException diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3914370a96595..62a998fbfb30b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -307,7 +307,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) { throw new IllegalArgumentException( s"Kafka option '${ConsumerConfig.GROUP_ID_CONFIG}' is not supported as " + - s"user-specified consumer groups is not used to track offsets.") + s"user-specified consumer groups are not used to track offsets.") } if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}")) { @@ -335,7 +335,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister { throw new IllegalArgumentException( s"Kafka option '${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame " + + "values are deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame " + "operations to explicitly deserialize the values.") } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index baa60febf661d..d90630a8adc93 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Unsa import org.apache.spark.sql.types.{BinaryType, StringType} /** - * A simple trait for writing out data in a single Spark task, without any concerns about how + * Writes out data in a single Spark task, without any concerns about how * to commit or abort tasks. Exceptions thrown by the implementation of this class will * automatically trigger task aborts. */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java index 2800b3068f87b..470c128ee6c3d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes; /** - * OutputMode is used to what data will be written to a streaming sink when there is + * OutputMode describes what data will be written to a streaming sink when there is * new data available in a streaming DataFrame/Dataset. * * @since 2.0.0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 35b35110e491f..2b14c8220d43b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -611,8 +611,8 @@ class Analyzer( if (AnalysisContext.get.nestedViewDepth > conf.maxNestedViewDepth) { view.failAnalysis(s"The depth of view ${view.desc.identifier} exceeds the maximum " + s"view resolution depth (${conf.maxNestedViewDepth}). Analysis is aborted to " + - "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + - "aroud this.") + s"avoid errors. Increase the value of ${SQLConf.MAX_NESTED_VIEW_DEPTH.key} to work " + + "around this.") } executeSameContext(child) } @@ -653,7 +653,7 @@ class Analyzer( // Note that if the database is not defined, it is possible we are looking up a temp view. case e: NoSuchDatabaseException => u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " + - s"database ${e.db} doesn't exsits.") + s"database ${e.db} doesn't exist.") } } @@ -1524,7 +1524,7 @@ class Analyzer( } /** - * Extracts [[Generator]] from the projectList of a [[Project]] operator and create [[Generate]] + * Extracts [[Generator]] from the projectList of a [[Project]] operator and creates [[Generate]] * operator under [[Project]]. * * This rule will throw [[AnalysisException]] for following cases: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index d336f801d0770..a65f58fa61ff4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -294,7 +294,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu } else { val from = input.inputSet.map(_.name).mkString(", ") val targetString = target.get.mkString(".") - throw new AnalysisException(s"cannot resolve '$targetString.*' give input columns '$from'") + throw new AnalysisException(s"cannot resolve '$targetString.*' given input columns '$from'") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 19abce01a26cf..e1d16a2cd38b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -190,17 +190,15 @@ abstract class AggregateFunction extends Expression { def defaultResult: Option[Literal] = None /** - * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] because - * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, - * and the flag indicating if this aggregation is distinct aggregation or not. - * An [[AggregateFunction]] should not be used without being wrapped in - * an [[AggregateExpression]]. + * Creates [[AggregateExpression]] with `isDistinct` flag disabled. + * + * @see `toAggregateExpression(isDistinct: Boolean)` for detailed description */ def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false) /** - * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and set isDistinct - * field of the [[AggregateExpression]] to the given value because + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets `isDistinct` + * flag of the [[AggregateExpression]] to the given value because * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, * and the flag indicating if this aggregation is distinct aggregation or not. * An [[AggregateFunction]] should not be used without being wrapped in diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala index e0748043c46e2..2c248d74869ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical /** - * A visitor pattern for traversing a [[LogicalPlan]] tree and compute some properties. + * A visitor pattern for traversing a [[LogicalPlan]] tree and computing some properties. */ trait LogicalPlanVisitor[T] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index ca0775a2e8408..b6c16079d1984 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import org.apache.spark.sql.catalyst.plans.logical._ /** - * An [[LogicalPlanVisitor]] that computes a the statistics used in a cost-based optimizer. + * A [[LogicalPlanVisitor]] that computes the statistics for the cost-based optimizer. */ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 5e1c4e0bd6069..85f67c7d66075 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -48,8 +48,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { } /** - * For leaf nodes, use its computeStats. For other nodes, we assume the size in bytes is the - * sum of all of the children's. + * For leaf nodes, use its `computeStats`. For other nodes, we assume the size in bytes is the + * product of all of the children's `computeStats`. */ override def default(p: LogicalPlan): Statistics = p match { case p: LeafNode => p.computeStats() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index cc4f4bf332459..1cef09a5bf053 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -894,7 +894,7 @@ object SQLConf { .internal() .doc("The number of bins when generating histograms.") .intConf - .checkValue(num => num > 1, "The number of bins must be large than 1.") + .checkValue(num => num > 1, "The number of bins must be larger than 1.") .createWithDefault(254) val PERCENTILE_ACCURACY = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 82c5307d54360..6241d5cbb1d25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -154,7 +154,7 @@ trait PlanTestBase extends PredicateHelper { self: Suite => } /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL * configurations. */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 97f12ff625c42..5f3d4448e4e54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -311,7 +311,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { if (partitioningColumns.isDefined) { throw new AnalysisException( "insertInto() can't be used together with partitionBy(). " + - "Partition columns have already be defined for the table. " + + "Partition columns have already been defined for the table. " + "It is not necessary to use partitionBy()." ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index d3cfd2a1ffbf2..4828fa60a7b58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -327,7 +327,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } /** - * Create a [[DescribeTableCommand]] logical plan. + * Create a [[DescribeColumnCommand]] or [[DescribeTableCommand]] logical commands. */ override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) { val isExtended = ctx.EXTENDED != null || ctx.FORMATTED != null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 065954559e487..6102937852347 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -58,7 +58,7 @@ trait CodegenSupport extends SparkPlan { } /** - * Whether this SparkPlan support whole stage codegen or not. + * Whether this SparkPlan supports whole stage codegen or not. */ def supportCodegen: Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index 7477d025dfe89..3c900be839aa9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -91,8 +91,8 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm if (sparkSession.conf.get(CATALOG_IMPLEMENTATION.key).equals("hive") && key.startsWith("hive.")) { logWarning(s"'SET $key=$value' might not work, since Spark doesn't support changing " + - "the Hive config dynamically. Please passing the Hive-specific config by adding the " + - s"prefix spark.hadoop (e.g., spark.hadoop.$key) when starting a Spark application. " + + "the Hive config dynamically. Please pass the Hive-specific config by adding the " + + s"prefix spark.hadoop (e.g. spark.hadoop.$key) when starting a Spark application. " + "For details, see the link: https://spark.apache.org/docs/latest/configuration.html#" + "dynamically-loading-spark-properties.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index f64e079539c4f..5dbcf4a915cbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{AtomicType, StructType} import org.apache.spark.sql.util.SchemaUtils /** - * Try to replaces [[UnresolvedRelation]]s if the plan is for direct query on files. + * Replaces [[UnresolvedRelation]]s if the plan is for direct query on files. */ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { private def maybeSQLFile(u: UnresolvedRelation): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 6e8154d58d4c6..00bc215a5dc8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -330,7 +330,7 @@ object HDFSMetadataLog { /** A simple trait to abstract out the file management operations needed by HDFSMetadataLog. */ trait FileManager { - /** List the files in a path that matches a filter. */ + /** List the files in a path that match a filter. */ def list(path: Path, filter: PathFilter): Array[FileStatus] /** Make directory at the give path and all its parent directories as needed. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index a1b63a6de3823..73945b39b8967 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PR case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMetadata] = None) { /** - * Unpacks an offset into [[StreamProgress]] by associating each offset with the order list of + * Unpacks an offset into [[StreamProgress]] by associating each offset with the ordered list of * sources. * * This method is typically used to associate a serialized offset with actual sources (which diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index e3f4abcf9f1dc..2c8d7c7b0f3c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession /** * This class is used to log offsets to persistent files in HDFS. * Each file corresponds to a specific batch of offsets. The file - * format contain a version string in the first line, followed + * format contains a version string in the first line, followed * by a the JSON string representation of the offsets separated * by a newline character. If a source offset is missing, then * that line will contain a string value defined in the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala index 020c9cb4a7304..3f2cdadfbaeee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus} /** - * Wrap non-serializable StreamExecution to make the query serializable as it's easy to for it to + * Wrap non-serializable StreamExecution to make the query serializable as it's easy for it to * get captured with normal usage. It's safe to capture the query but not use it in executors. * However, if the user tries to call its methods, it will throw `IllegalStateException`. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 6fe632f958ffc..d1d9f95cb0977 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -94,7 +94,7 @@ trait StateStore { def abort(): Unit /** - * Return an iterator containing all the key-value pairs in the SateStore. Implementations must + * Return an iterator containing all the key-value pairs in the StateStore. Implementations must * ensure that updates (puts, removes) can be made while iterating over this iterator. */ def iterator(): Iterator[UnsafeRowPair] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index f29e135ac357f..e0554f0c4d337 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -80,7 +80,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging planVisualization(metrics, graph) ++ physicalPlanDescription(executionUIData.physicalPlanDescription) }.getOrElse { -
No information to display for Plan {executionId}
+
No information to display for query {executionId}
} UIUtils.headerSparkPage(s"Details for Query $executionId", content, parent, Some(5000)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 40a058d2cadd2..bdc4bb4422ae7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.types.DataType * * As an example: * {{{ - * // Defined a UDF that returns true or false based on some numeric score. - * val predict = udf((score: Double) => if (score > 0.5) true else false) + * // Define a UDF that returns true or false based on some numeric score. + * val predict = udf((score: Double) => score > 0.5) * * // Projects a column that adds a prediction column based on the score column. * df.select( predict(df("score")) ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2867b4cd7da5e..007f8760edf82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -206,7 +206,7 @@ abstract class BaseSessionStateBuilder( /** * Logical query plan optimizer. * - * Note: this depends on the `conf`, `catalog` and `experimentalMethods` fields. + * Note: this depends on `catalog` and `experimentalMethods` fields. */ protected def optimizer: Optimizer = { new SparkOptimizer(catalog, experimentalMethods) { @@ -263,7 +263,7 @@ abstract class BaseSessionStateBuilder( * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s * that listen for execution metrics. * - * This gets cloned from parent if available, otherwise is a new instance is created. + * This gets cloned from parent if available, otherwise a new instance is created. */ protected def listenerManager: ExecutionListenerManager = { parentState.map(_.listenerManager.clone()).getOrElse( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 52f2e2639cd86..9f5ca9f914284 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -118,7 +118,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * You can set the following option(s): *
    *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps in the JSON/CSV datasources or partition values.
  • + * to be used to parse timestamps in the JSON/CSV data sources or partition values. *
* * @since 2.0.0 @@ -129,12 +129,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } /** - * Adds input options for the underlying data source. + * (Java-specific) Adds input options for the underlying data source. * * You can set the following option(s): *
    *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps in the JSON/CSV datasources or partition values.
  • + * to be used to parse timestamps in the JSON/CSV data sources or partition values. *
* * @since 2.0.0 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out index b5a4f5c2bf654..539f673c9d679 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -195,7 +195,7 @@ SELECT t1.x.y.* FROM t1 struct<> -- !query 22 output org.apache.spark.sql.AnalysisException -cannot resolve 't1.x.y.*' give input columns 'i1'; +cannot resolve 't1.x.y.*' given input columns 'i1'; -- !query 23 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out index 7c451c2aa5b5c..2092119600954 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out @@ -88,7 +88,7 @@ SELECT global_temp.view1.* FROM global_temp.view1 struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve 'global_temp.view1.*' give input columns 'i1'; +cannot resolve 'global_temp.view1.*' given input columns 'i1'; -- !query 11 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out index d3ca4443cce55..e10f516ad6e5b 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out @@ -179,7 +179,7 @@ SELECT mydb1.t1.* FROM mydb1.t1 struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t1.*' give input columns 'i1'; +cannot resolve 'mydb1.t1.*' given input columns 'i1'; -- !query 22 @@ -212,7 +212,7 @@ SELECT mydb1.t1.* FROM mydb1.t1 struct<> -- !query 25 output org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t1.*' give input columns 'i1'; +cannot resolve 'mydb1.t1.*' given input columns 'i1'; -- !query 26 @@ -420,7 +420,7 @@ SELECT mydb1.t5.* FROM mydb1.t5 struct<> -- !query 50 output org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t5.*' give input columns 'i1, t5'; +cannot resolve 'mydb1.t5.*' given input columns 'i1, t5'; -- !query 51 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 083a0c0b1b9a0..a79ab47f0197e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1896,12 +1896,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { var e = intercept[AnalysisException] { sql("SELECT a.* FROM temp_table_no_cols a") }.getMessage - assert(e.contains("cannot resolve 'a.*' give input columns ''")) + assert(e.contains("cannot resolve 'a.*' given input columns ''")) e = intercept[AnalysisException] { dfNoCols.select($"b.*") }.getMessage - assert(e.contains("cannot resolve 'b.*' give input columns ''")) + assert(e.contains("cannot resolve 'b.*' given input columns ''")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 14082197ba0bd..ce8fde28a941c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -663,7 +663,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }.getMessage assert(e.contains("The depth of view `default`.`view0` exceeds the maximum view " + "resolution depth (10). Analysis is aborted to avoid errors. Increase the value " + - "of spark.sql.view.maxNestedViewDepth to work aroud this.")) + "of spark.sql.view.maxNestedViewDepth to work around this.")) } val e = intercept[IllegalArgumentException] { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 632e3e0c4c3f9..3b8a8ca301c27 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -109,8 +109,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** - * Get the raw table metadata from hive metastore directly. The raw table metadata may contains - * special data source properties and should not be exposed outside of `HiveExternalCatalog`. We + * Get the raw table metadata from hive metastore directly. The raw table metadata may contain + * special data source properties that should not be exposed outside of `HiveExternalCatalog`. We * should interpret these special data source properties and restore the original table metadata * before returning it. */ From 51eb750263dd710434ddb60311571fa3dcec66eb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 22 Jan 2018 15:21:09 -0800 Subject: [PATCH 176/774] [SPARK-22389][SQL] data source v2 partitioning reporting interface ## What changes were proposed in this pull request? a new interface which allows data source to report partitioning and avoid shuffle at Spark side. The design is pretty like the internal distribution/partitioing framework. Spark defines a `Distribution` interfaces and several concrete implementations, and ask the data source to report a `Partitioning`, the `Partitioning` should tell Spark if it can satisfy a `Distribution` or not. ## How was this patch tested? new test Author: Wenchen Fan Closes #20201 from cloud-fan/partition-reporting. --- .../plans/physical/partitioning.scala | 2 +- .../v2/reader/ClusteredDistribution.java | 38 ++++++ .../sql/sources/v2/reader/Distribution.java | 39 +++++++ .../sql/sources/v2/reader/Partitioning.java | 46 ++++++++ .../v2/reader/SupportsReportPartitioning.java | 33 ++++++ .../v2/DataSourcePartitioning.scala | 56 +++++++++ .../datasources/v2/DataSourceV2ScanExec.scala | 9 ++ .../v2/JavaPartitionAwareDataSource.java | 110 ++++++++++++++++++ .../sql/sources/v2/DataSourceV2Suite.scala | 79 +++++++++++++ 9 files changed, 411 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 0189bd73c56bf..4d9a9925fe3ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -153,7 +153,7 @@ case class BroadcastDistribution(mode: BroadcastMode) extends Distribution { * 1. number of partitions. * 2. if it can satisfy a given distribution. */ -sealed trait Partitioning { +trait Partitioning { /** Returns the number of partitions that the data is split across */ val numPartitions: Int diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java new file mode 100644 index 0000000000000..7346500de45b6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A concrete implementation of {@link Distribution}. Represents a distribution where records that + * share the same values for the {@link #clusteredColumns} will be produced by the same + * {@link ReadTask}. + */ +@InterfaceStability.Evolving +public class ClusteredDistribution implements Distribution { + + /** + * The names of the clustered columns. Note that they are order insensitive. + */ + public final String[] clusteredColumns; + + public ClusteredDistribution(String[] clusteredColumns) { + this.clusteredColumns = clusteredColumns; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java new file mode 100644 index 0000000000000..a6201a222f541 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An interface to represent data distribution requirement, which specifies how the records should + * be distributed among the {@link ReadTask}s that are returned by + * {@link DataSourceV2Reader#createReadTasks()}. Note that this interface has nothing to do with + * the data ordering inside one partition(the output records of a single {@link ReadTask}). + * + * The instance of this interface is created and provided by Spark, then consumed by + * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to + * implement this interface, but need to catch as more concrete implementations of this interface + * as possible in {@link Partitioning#satisfy(Distribution)}. + * + * Concrete implementations until now: + *
    + *
  • {@link ClusteredDistribution}
  • + *
+ */ +@InterfaceStability.Evolving +public interface Distribution {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java new file mode 100644 index 0000000000000..199e45d4a02ab --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An interface to represent the output data partitioning for a data source, which is returned by + * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a + * snapshot. Once created, it should be deterministic and always report the same number of + * partitions and the same "satisfy" result for a certain distribution. + */ +@InterfaceStability.Evolving +public interface Partitioning { + + /** + * Returns the number of partitions(i.e., {@link ReadTask}s) the data source outputs. + */ + int numPartitions(); + + /** + * Returns true if this partitioning can satisfy the given distribution, which means Spark does + * not need to shuffle the output data of this data source for some certain operations. + * + * Note that, Spark may add new concrete implementations of {@link Distribution} in new releases. + * This method should be aware of it and always return false for unrecognized distributions. It's + * recommended to check every Spark new release and support new distributions if possible, to + * avoid shuffle at Spark side for more cases. + */ + boolean satisfy(Distribution distribution); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java new file mode 100644 index 0000000000000..f786472ccf345 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to report data partitioning and try to avoid shuffle at Spark side. + */ +@InterfaceStability.Evolving +public interface SupportsReportPartitioning { + + /** + * Returns the output data partitioning that this reader guarantees. + */ + Partitioning outputPartitioning(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala new file mode 100644 index 0000000000000..943d0100aca56 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression} +import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.sources.v2.reader.{ClusteredDistribution, Partitioning} + +/** + * An adapter from public data source partitioning to catalyst internal `Partitioning`. + */ +class DataSourcePartitioning( + partitioning: Partitioning, + colNames: AttributeMap[String]) extends physical.Partitioning { + + override val numPartitions: Int = partitioning.numPartitions() + + override def satisfies(required: physical.Distribution): Boolean = { + super.satisfies(required) || { + required match { + case d: physical.ClusteredDistribution if isCandidate(d.clustering) => + val attrs = d.clustering.map(_.asInstanceOf[Attribute]) + partitioning.satisfy( + new ClusteredDistribution(attrs.map { a => + val name = colNames.get(a) + assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output") + name.get + }.toArray)) + + case _ => false + } + } + } + + private def isCandidate(clustering: Seq[Expression]): Boolean = { + clustering.forall { + case a: Attribute => colNames.contains(a) + case _ => false + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index beb66738732be..69d871df3e1dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ @@ -42,6 +43,14 @@ case class DataSourceV2ScanExec( override def producedAttributes: AttributeSet = AttributeSet(fullOutput) + override def outputPartitioning: physical.Partitioning = reader match { + case s: SupportsReportPartitioning => + new DataSourcePartitioning( + s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) + + case _ => super.outputPartitioning + } + private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() case _ => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java new file mode 100644 index 0000000000000..806d0bcd93f18 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.StructType; + +public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader, SupportsReportPartitioning { + private final StructType schema = new StructType().add("a", "int").add("b", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createReadTasks() { + return java.util.Arrays.asList( + new SpecificReadTask(new int[]{1, 1, 3}, new int[]{4, 4, 6}), + new SpecificReadTask(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + } + + @Override + public Partitioning outputPartitioning() { + return new MyPartitioning(); + } + } + + static class MyPartitioning implements Partitioning { + + @Override + public int numPartitions() { + return 2; + } + + @Override + public boolean satisfy(Distribution distribution) { + if (distribution instanceof ClusteredDistribution) { + String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns; + return Arrays.asList(clusteredCols).contains("a"); + } + + return false; + } + } + + static class SpecificReadTask implements ReadTask, DataReader { + private int[] i; + private int[] j; + private int current = -1; + + SpecificReadTask(int[] i, int[] j) { + assert i.length == j.length; + this.i = i; + this.j = j; + } + + @Override + public boolean next() throws IOException { + current += 1; + return current < i.length; + } + + @Override + public Row get() { + return new GenericRow(new Object[] {i[current], j[current]}); + } + + @Override + public void close() throws IOException { + + } + + @Override + public DataReader createDataReader() { + return this; + } + } + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 0ca29524c6d05..0620693b35d16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,6 +24,7 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ @@ -95,6 +96,40 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("partitioning reporting") { + import org.apache.spark.sql.functions.{count, sum} + Seq(classOf[PartitionAwareDataSource], classOf[JavaPartitionAwareDataSource]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) + + val groupByColA = df.groupBy('a).agg(sum('b)) + checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) + assert(groupByColA.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isEmpty) + + val groupByColAB = df.groupBy('a, 'b).agg(count("*")) + checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) + assert(groupByColAB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isEmpty) + + val groupByColB = df.groupBy('b).agg(sum('a)) + checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) + assert(groupByColB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isDefined) + + val groupByAPlusB = df.groupBy('a + 'b).agg(count("*")) + checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) + assert(groupByAPlusB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isDefined) + } + } + } + test("simple writable data source") { // TODO: java implementation. Seq(classOf[SimpleWritableDataSource]).foreach { cls => @@ -365,3 +400,47 @@ class BatchReadTask(start: Int, end: Int) override def close(): Unit = batch.close() } + +class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader with SupportsReportPartitioning { + override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") + + override def createReadTasks(): JList[ReadTask[Row]] = { + // Note that we don't have same value of column `a` across partitions. + java.util.Arrays.asList( + new SpecificReadTask(Array(1, 1, 3), Array(4, 4, 6)), + new SpecificReadTask(Array(2, 4, 4), Array(6, 2, 2))) + } + + override def outputPartitioning(): Partitioning = new MyPartitioning + } + + class MyPartitioning extends Partitioning { + override def numPartitions(): Int = 2 + + override def satisfy(distribution: Distribution): Boolean = distribution match { + case c: ClusteredDistribution => c.clusteredColumns.contains("a") + case _ => false + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class SpecificReadTask(i: Array[Int], j: Array[Int]) extends ReadTask[Row] with DataReader[Row] { + assert(i.length == j.length) + + private var current = -1 + + override def createDataReader(): DataReader[Row] = this + + override def next(): Boolean = { + current += 1 + current < i.length + } + + override def get(): Row = Row(i(current), j(current)) + + override def close(): Unit = {} +} From b2ce17b4c9fea58140a57ca1846b2689b15c0d61 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 23 Jan 2018 14:11:30 +0900 Subject: [PATCH 177/774] [SPARK-22274][PYTHON][SQL] User-defined aggregation functions with pandas udf (full shuffle) ## What changes were proposed in this pull request? Add support for using pandas UDFs with groupby().agg(). This PR introduces a new type of pandas UDF - group aggregate pandas UDF. This type of UDF defines a transformation of multiple pandas Series -> a scalar value. Group aggregate pandas UDFs can be used with groupby().agg(). Note group aggregate pandas UDF doesn't support partial aggregation, i.e., a full shuffle is required. This PR doesn't support group aggregate pandas UDFs that return ArrayType, StructType or MapType. Support for these types is left for future PR. ## How was this patch tested? GroupbyAggPandasUDFTests Author: Li Jin Closes #19872 from icexelloss/SPARK-22274-groupby-agg. --- .../spark/api/python/PythonRunner.scala | 2 + python/pyspark/rdd.py | 1 + python/pyspark/sql/functions.py | 36 +- python/pyspark/sql/group.py | 33 +- python/pyspark/sql/tests.py | 486 +++++++++++++++++- python/pyspark/sql/udf.py | 13 +- python/pyspark/worker.py | 22 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 14 +- .../sql/catalyst/expressions}/PythonUDF.scala | 31 +- .../sql/catalyst/planning/patterns.scala | 12 +- .../spark/sql/RelationalGroupedDataset.scala | 1 - .../spark/sql/execution/SparkStrategies.scala | 29 +- .../python/AggregateInPandasExec.scala | 155 ++++++ .../execution/python/ExtractPythonUDFs.scala | 16 +- .../python/UserDefinedPythonFunction.scala | 2 +- 15 files changed, 792 insertions(+), 61 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/python => catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions}/PythonUDF.scala (60%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 1ec0e717fac29..29148a7ee558b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -39,12 +39,14 @@ private[spark] object PythonEvalType { val SQL_PANDAS_SCALAR_UDF = 200 val SQL_PANDAS_GROUP_MAP_UDF = 201 + val SQL_PANDAS_GROUP_AGG_UDF = 202 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" case SQL_BATCHED_UDF => "SQL_BATCHED_UDF" case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF" case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF" + case SQL_PANDAS_GROUP_AGG_UDF => "SQL_PANDAS_GROUP_AGG_UDF" } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1b3915548fb14..6b018c3a38444 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -70,6 +70,7 @@ class PythonEvalType(object): SQL_PANDAS_SCALAR_UDF = 200 SQL_PANDAS_GROUP_MAP_UDF = 201 + SQL_PANDAS_GROUP_AGG_UDF = 202 def portable_hash(x): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 961b3267b44cf..a291c9b71913f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2089,6 +2089,8 @@ class PandasUDFType(object): GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF + GROUP_AGG = PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + @since(1.3) def udf(f=None, returnType=StringType()): @@ -2159,7 +2161,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): 1. SCALAR A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. - The returnType should be a primitive data type, e.g., `DoubleType()`. + The returnType should be a primitive data type, e.g., :class:`DoubleType`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and @@ -2221,6 +2223,35 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` + 3. GROUP_AGG + + A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar + The `returnType` should be a primitive data type, e.g., :class:`DoubleType`. + The returned scalar can be either a python primitive type, e.g., `int` or `float` + or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. + + :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as + output types. + + Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf("double", PandasUDFType.GROUP_AGG) # doctest: +SKIP + ... def mean_udf(v): + ... return v.mean() + >>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP + +---+-----------+ + | id|mean_udf(v)| + +---+-----------+ + | 1| 1.5| + | 2| 6.0| + +---+-----------+ + + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` + .. note:: The user-defined functions are considered deterministic by default. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. If your function is not deterministic, call @@ -2267,7 +2298,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): raise ValueError("Invalid returnType: returnType can not be None") if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF]: + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF]: raise ValueError("Invalid functionType: " "functionType must be one the values from PandasUDFType") diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 22061b83eb78c..f90a909d7c2b1 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -65,13 +65,27 @@ def __init__(self, jgd, df): def agg(self, *exprs): """Compute aggregates and returns the result as a :class:`DataFrame`. - The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + The available aggregate functions can be: + + 1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count` + + 2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf` + + .. note:: There is no partial aggregation with group aggregate UDFs, i.e., + a full shuffle is required. Also, all the data of a group will be loaded into + memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. + + .. seealso:: :func:`pyspark.sql.functions.pandas_udf` If ``exprs`` is a single :class:`dict` mapping from string to string, then the key is the column to perform aggregation on, and the value is the aggregate function. Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + .. note:: Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed + in a single call to this function. + :param exprs: a dict mapping from column name (string) to aggregate functions (string), or a list of :class:`Column`. @@ -82,6 +96,13 @@ def agg(self, *exprs): >>> from pyspark.sql import functions as F >>> sorted(gdf.agg(F.min(df.age)).collect()) [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf('int', PandasUDFType.GROUP_AGG) # doctest: +SKIP + ... def min_udf(v): + ... return v.min() + >>> sorted(gdf.agg(min_udf(df.age)).collect()) # doctest: +SKIP + [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): @@ -204,16 +225,18 @@ def apply(self, udf): The user-defined function should take a `pandas.DataFrame` and return another `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` - to the user-function and the returned `pandas.DataFrame`s are combined as a + to the user-function and the returned `pandas.DataFrame` are combined as a :class:`DataFrame`. + The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the returnType of the pandas udf. - This function does not support partial aggregation, and requires shuffling all the data in - the :class:`DataFrame`. + .. note:: This function requires a full shuffle. all the data of a group will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. :param udf: a group map user-defined function returned by - :meth:`pyspark.sql.functions.pandas_udf`. + :func:`pyspark.sql.functions.pandas_udf`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4fee2ecde391b..84e8eec71dd8a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -197,6 +197,12 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() cls.spark.stop() + def assertPandasEqual(self, expected, result): + msg = ("DataFrames are not equal: " + + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + + "\n\nResult:\n%s\n%s" % (result, result.dtypes)) + self.assertTrue(expected.equals(result), msg=msg) + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 @@ -3371,12 +3377,6 @@ def tearDownClass(cls): time.tzset() ReusedSQLTestCase.tearDownClass() - def assertFramesEqual(self, df_with_arrow, df_without): - msg = ("DataFrame from Arrow is not equal" + - ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + - ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) - self.assertTrue(df_without.equals(df_with_arrow), msg=msg) - def create_pandas_data_frame(self): import pandas as pd import numpy as np @@ -3414,7 +3414,7 @@ def _toPandas_arrow_toggle(self, df): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) @@ -3425,11 +3425,11 @@ def test_toPandas_respect_session_timezone(self): self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") try: pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow_la, pdf_la) + self.assertPandasEqual(pdf_arrow_la, pdf_la) finally: self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow_ny, pdf_ny) + self.assertPandasEqual(pdf_arrow_ny, pdf_ny) self.assertFalse(pdf_ny.equals(pdf_la)) @@ -3439,7 +3439,7 @@ def test_toPandas_respect_session_timezone(self): if isinstance(field.dataType, TimestampType): pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( pdf_la_corrected[field.name], timezone) - self.assertFramesEqual(pdf_ny, pdf_la_corrected) + self.assertPandasEqual(pdf_ny, pdf_la_corrected) finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) @@ -3447,7 +3447,7 @@ def test_pandas_round_trip(self): pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(self.data, schema=self.schema) pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_filtered_frame(self): df = self.spark.range(3).toDF("i") @@ -3505,7 +3505,7 @@ def test_createDataFrame_with_schema(self): df = self.spark.createDataFrame(pdf, schema=self.schema) self.assertEquals(self.schema, df.schema) pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() @@ -3717,7 +3717,7 @@ def foo(k, v): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class VectorizedUDFTests(ReusedSQLTestCase): +class ScalarPandasUDF(ReusedSQLTestCase): @classmethod def setUpClass(cls): @@ -4196,13 +4196,7 @@ def test_register_vectorized_udf_basic(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyApplyTests(ReusedSQLTestCase): - - def assertFramesEqual(self, expected, result): - msg = ("DataFrames are not equal: " + - ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) + - ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) - self.assertTrue(expected.equals(result), msg=msg) +class GroupbyApplyPandasUDFTests(ReusedSQLTestCase): @property def data(self): @@ -4227,7 +4221,7 @@ def test_simple(self): result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_register_group_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4251,7 +4245,7 @@ def foo(pdf): result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_coerce(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4266,7 +4260,7 @@ def test_coerce(self): result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) expected = expected.assign(v=expected.v.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_complex_groupby(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType @@ -4285,7 +4279,7 @@ def normalize(pdf): expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_empty_groupby(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType @@ -4304,7 +4298,7 @@ def normalize(pdf): expected = normalize.func(pdf) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_datatype_string(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4318,7 +4312,7 @@ def test_datatype_string(self): result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4370,6 +4364,446 @@ def test_unsupported_types(self): df.groupby('id').apply(f).collect() +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class GroupbyAggPandasUDFTests(ReusedSQLTestCase): + + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + @property + def python_plus_one(self): + from pyspark.sql.functions import udf + + @udf('double') + def plus_one(v): + assert isinstance(v, (int, float)) + return v + 1 + return plus_one + + @property + def pandas_scalar_plus_two(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.SCALAR) + def plus_two(v): + assert isinstance(v, pd.Series) + return v + 2 + return plus_two + + @property + def pandas_agg_mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def avg(v): + return v.mean() + return avg + + @property + def pandas_agg_sum_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def sum(v): + return v.sum() + return sum + + @property + def pandas_agg_weighted_mean_udf(self): + import numpy as np + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def weighted_mean(v, w): + return np.average(v, weights=w) + return weighted_mean + + def test_manual(self): + df = self.data + sum_udf = self.pandas_agg_sum_udf + mean_udf = self.pandas_agg_mean_udf + + result1 = df.groupby('id').agg(sum_udf(df.v), mean_udf(df.v)).sort('id') + expected1 = self.spark.createDataFrame( + [[0, 245.0, 24.5], + [1, 255.0, 25.5], + [2, 265.0, 26.5], + [3, 275.0, 27.5], + [4, 285.0, 28.5], + [5, 295.0, 29.5], + [6, 305.0, 30.5], + [7, 315.0, 31.5], + [8, 325.0, 32.5], + [9, 335.0, 33.5]], + ['id', 'sum(v)', 'avg(v)']) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_basic(self): + from pyspark.sql.functions import col, lit, sum, mean + + df = self.data + weighted_mean_udf = self.pandas_agg_weighted_mean_udf + + # Groupby one column and aggregate one UDF with literal + result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id') + expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id') + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + # Groupby one expression and aggregate one UDF with literal + result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\ + .sort(df.id + 1) + expected2 = df.groupby((col('id') + 1))\ + .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + # Groupby one column and aggregate one UDF without literal + result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id') + expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id') + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + + # Groupby one expression and aggregate one UDF without literal + result4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(weighted_mean_udf(df.v, df.w))\ + .sort('id') + expected4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(mean(df.v).alias('weighted_mean(v, w)'))\ + .sort('id') + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_unsupported_types(self): + from pyspark.sql.types import ArrayType, DoubleType, MapType + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): + @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return [v.mean(), v.std()] + + with QuietTest(self.sc): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): + @pandas_udf('mean double, std double', PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return v.mean(), v.std() + + with QuietTest(self.sc): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): + @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return {v.mean(): v.std()} + + def test_alias(self): + from pyspark.sql.functions import mean + + df = self.data + mean_udf = self.pandas_agg_mean_udf + + result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias')) + expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_sql(self): + """ + Test mixing group aggregate pandas UDF with sql expression. + """ + from pyspark.sql.functions import sum, mean + + df = self.data + sum_udf = self.pandas_agg_sum_udf + + # Mix group aggregate pandas UDF with sql expression + result1 = (df.groupby('id') + .agg(sum_udf(df.v) + 1) + .sort('id')) + expected1 = (df.groupby('id') + .agg(sum(df.v) + 1) + .sort('id')) + + # Mix group aggregate pandas UDF with sql expression (order swapped) + result2 = (df.groupby('id') + .agg(sum_udf(df.v + 1)) + .sort('id')) + + expected2 = (df.groupby('id') + .agg(sum(df.v + 1)) + .sort('id')) + + # Wrap group aggregate pandas UDF with two sql expressions + result3 = (df.groupby('id') + .agg(sum_udf(df.v + 1) + 2) + .sort('id')) + expected3 = (df.groupby('id') + .agg(sum(df.v + 1) + 2) + .sort('id')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + + def test_mixed_udfs(self): + """ + Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF. + """ + from pyspark.sql.functions import sum, mean + + df = self.data + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + sum_udf = self.pandas_agg_sum_udf + + # Mix group aggregate pandas UDF and python UDF + result1 = (df.groupby('id') + .agg(plus_one(sum_udf(df.v))) + .sort('id')) + expected1 = (df.groupby('id') + .agg(plus_one(sum(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and python UDF (order swapped) + result2 = (df.groupby('id') + .agg(sum_udf(plus_one(df.v))) + .sort('id')) + expected2 = (df.groupby('id') + .agg(sum(plus_one(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and scalar pandas UDF + result3 = (df.groupby('id') + .agg(sum_udf(plus_two(df.v))) + .sort('id')) + expected3 = (df.groupby('id') + .agg(sum(plus_two(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and scalar pandas UDF (order swapped) + result4 = (df.groupby('id') + .agg(plus_two(sum_udf(df.v))) + .sort('id')) + expected4 = (df.groupby('id') + .agg(plus_two(sum(df.v))) + .sort('id')) + + # Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby + result5 = (df.groupby(plus_one(df.id)) + .agg(plus_one(sum_udf(plus_one(df.v)))) + .sort('plus_one(id)')) + expected5 = (df.groupby(plus_one(df.id)) + .agg(plus_one(sum(plus_one(df.v)))) + .sort('plus_one(id)')) + + # Wrap group aggregate pandas UDF with two scala pandas UDF and user scala pandas UDF in + # groupby + result6 = (df.groupby(plus_two(df.id)) + .agg(plus_two(sum_udf(plus_two(df.v)))) + .sort('plus_two(id)')) + expected6 = (df.groupby(plus_two(df.id)) + .agg(plus_two(sum(plus_two(df.v)))) + .sort('plus_two(id)')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) + self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) + + def test_multiple_udfs(self): + """ + Test multiple group aggregate pandas UDFs in one agg function. + """ + from pyspark.sql.functions import col, lit, sum, mean + + df = self.data + mean_udf = self.pandas_agg_mean_udf + sum_udf = self.pandas_agg_sum_udf + weighted_mean_udf = self.pandas_agg_weighted_mean_udf + + result1 = (df.groupBy('id') + .agg(mean_udf(df.v), + sum_udf(df.v), + weighted_mean_udf(df.v, df.w)) + .sort('id') + .toPandas()) + expected1 = (df.groupBy('id') + .agg(mean(df.v), + sum(df.v), + mean(df.v).alias('weighted_mean(v, w)')) + .sort('id') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + + def test_complex_groupby(self): + from pyspark.sql.functions import lit, sum + + df = self.data + sum_udf = self.pandas_agg_sum_udf + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + + # groupby one expression + result1 = df.groupby(df.v % 2).agg(sum_udf(df.v)) + expected1 = df.groupby(df.v % 2).agg(sum(df.v)) + + # empty groupby + result2 = df.groupby().agg(sum_udf(df.v)) + expected2 = df.groupby().agg(sum(df.v)) + + # groupby one column and one sql expression + result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)) + expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)) + + # groupby one python UDF + result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) + expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v)) + + # groupby one scalar pandas UDF + result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)) + expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v)) + + # groupby one expression and one python UDF + result6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum_udf(df.v)) + expected6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v)) + + # groupby one expression and one scalar pandas UDF + result7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)') + expected7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort('sum(v)') + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) + self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) + self.assertPandasEqual(expected7.toPandas(), result7.toPandas()) + + def test_complex_expressions(self): + from pyspark.sql.functions import col, sum + + df = self.data + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + sum_udf = self.pandas_agg_sum_udf + + # Test complex expressions with sql expression, python UDF and + # group aggregate pandas UDF + result1 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_one(sum_udf(col('v1'))), + sum_udf(plus_one(col('v2')))) + .sort('id') + .toPandas()) + + expected1 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum(col('v')), + sum(col('v1') + 3), + sum(col('v2')) + 5, + plus_one(sum(col('v1'))), + sum(plus_one(col('v2')))) + .sort('id') + .toPandas()) + + # Test complex expressions with sql expression, scala pandas UDF and + # group aggregate pandas UDF + result2 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_two(sum_udf(col('v1'))), + sum_udf(plus_two(col('v2')))) + .sort('id') + .toPandas()) + + expected2 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum(col('v')), + sum(col('v1') + 3), + sum(col('v2')) + 5, + plus_two(sum(col('v1'))), + sum(plus_two(col('v2')))) + .sort('id') + .toPandas()) + + # Test sequential groupby aggregate + result3 = (df.groupby('id') + .agg(sum_udf(df.v).alias('v')) + .groupby('id') + .agg(sum_udf(col('v'))) + .sort('id') + .toPandas()) + + expected3 = (df.groupby('id') + .agg(sum(df.v).alias('v')) + .groupby('id') + .agg(sum(col('v'))) + .sort('id') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) + + def test_retain_group_columns(self): + from pyspark.sql.functions import sum, lit, col + orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None) + self.spark.conf.set("spark.sql.retainGroupColumns", False) + try: + df = self.data + sum_udf = self.pandas_agg_sum_udf + + result1 = df.groupby(df.id).agg(sum_udf(df.v)) + expected1 = df.groupby(df.id).agg(sum(df.v)) + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + finally: + if orig_value is None: + self.spark.conf.unset("spark.sql.retainGroupColumns") + else: + self.spark.conf.set("spark.sql.retainGroupColumns", orig_value) + + def test_invalid_args(self): + from pyspark.sql.functions import mean + + df = self.data + plus_one = self.python_plus_one + mean_udf = self.pandas_agg_mean_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'nor.*aggregate function'): + df.groupby(df.id).agg(plus_one(df.v)).collect() + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'aggregate function.*argument.*aggregate function'): + df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect() + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'mixture.*aggregate function.*group aggregate pandas UDF'): + df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 134badb8485f5..de96846c5c774 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -22,7 +22,8 @@ from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq -from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string +from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ + _parse_datatype_string __all__ = ["UDFRegistration"] @@ -36,8 +37,10 @@ def _wrap_function(sc, func, returnType): def _create_udf(f, returnType, evalType): - if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \ - evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + if evalType in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): + import inspect from pyspark.sql.utils import require_minimum_pyarrow_version @@ -113,6 +116,10 @@ def returnType(self): and not isinstance(self._returnType_placeholder, StructType): raise ValueError("Invalid returnType: returnType must be a StructType for " "pandas_udf with function type GROUP_MAP") + elif self.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF \ + and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)): + raise NotImplementedError( + "ArrayType, StructType and MapType are not supported with PandasUDFType.GROUP_AGG") return self._returnType_placeholder diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index e6737ae1c1285..173d8fb2856fa 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -110,6 +110,17 @@ def wrapped(*series): return wrapped +def wrap_pandas_group_agg_udf(f, return_type): + arrow_return_type = to_arrow_type(return_type) + + def wrapped(*series): + import pandas as pd + result = f(*series) + return pd.Series(result) + + return lambda *a: (wrapped(*a), arrow_return_type) + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -126,8 +137,12 @@ def read_single_udf(pickleSer, infile, eval_type): return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) - else: + elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF: + return arg_offsets, wrap_pandas_group_agg_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(row_func, return_type) + else: + raise ValueError("Unknown eval type: {}".format(eval_type)) def read_udfs(pickleSer, infile, eval_type): @@ -148,8 +163,9 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \ - or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + if eval_type in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): timezone = utf8_deserializer.loads(infile) ser = ArrowStreamPandasSerializer(timezone) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index bbcec5627bd49..ef91d79f3302c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -153,11 +153,19 @@ trait CheckAnalysis extends PredicateHelper { s"of type ${condition.dataType.simpleString} is not a boolean.") case Aggregate(groupingExprs, aggregateExprs, child) => + def isAggregateExpression(expr: Expression) = { + expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr) + } + def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case aggExpr: AggregateExpression => - aggExpr.aggregateFunction.children.foreach { child => + case expr: Expression if isAggregateExpression(expr) => + val aggFunction = expr match { + case agg: AggregateExpression => agg.aggregateFunction + case udf: PythonUDF => udf + } + aggFunction.children.foreach { child => child.foreach { - case agg: AggregateExpression => + case expr: Expression if isAggregateExpression(expr) => failAnalysis( s"It is not allowed to use an aggregate function in the argument of " + s"another aggregate function. Please use the inner aggregate function " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala similarity index 60% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index d3f743d9eb61e..4ba8ff6e3802f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -15,12 +15,31 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.python +package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.api.python.PythonFunction -import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable, UserDefinedExpression} +import org.apache.spark.api.python.{PythonEvalType, PythonFunction} +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types.DataType +/** + * Helper functions for [[PythonUDF]] + */ +object PythonUDF { + private[this] val SCALAR_TYPES = Set( + PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF + ) + + def isScalarPythonUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType) + } + + def isGroupAggPandasUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && + e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + } +} + /** * A serialized version of a Python lambda function. */ @@ -30,12 +49,16 @@ case class PythonUDF( dataType: DataType, children: Seq[Expression], evalType: Int, - udfDeterministic: Boolean) + udfDeterministic: Boolean, + resultId: ExprId = NamedExpression.newExprId) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = s"$name(${children.mkString(", ")})" + lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)( + exprId = resultId) + override def nullable: Boolean = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index cc391aae55787..132241061d510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.planning +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -199,7 +200,7 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { object PhysicalAggregation { // groupingExpressions, aggregateExpressions, resultExpressions, child type ReturnType = - (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan) + (Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) def unapply(a: Any): Option[ReturnType] = a match { case logical.Aggregate(groupingExpressions, resultExpressions, child) => @@ -213,7 +214,10 @@ object PhysicalAggregation { expr.collect { // addExpr() always returns false for non-deterministic expressions and do not add them. case agg: AggregateExpression - if (!equivalentAggregateExpressions.addExpr(agg)) => agg + if !equivalentAggregateExpressions.addExpr(agg) => agg + case udf: PythonUDF + if PythonUDF.isGroupAggPandasUDF(udf) && + !equivalentAggregateExpressions.addExpr(udf) => udf } } @@ -241,6 +245,10 @@ object PhysicalAggregation { // so replace each aggregate expression by its corresponding attribute in the set: equivalentAggregateExpressions.getEquivalentExprs(ae).headOption .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute + // Similar to AggregateExpression + case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) => + equivalentAggregateExpressions.getEquivalentExprs(ue).headOption + .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression => // Since we're using `namedGroupingAttributes` to extract the grouping key // columns, we need to replace grouping key expressions with their corresponding diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index a009c00b0abc5..d320c1c359411 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 910294853c318..ce512bc46563a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.{execution, AnalysisException, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -288,9 +289,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => + if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) { + throw new AnalysisException( + "Streaming aggregation doesn't support group aggregate pandas UDF") + } + aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, - aggregateExpressions, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, planLater(child)) @@ -333,8 +339,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalAggregation( - groupingExpressions, aggregateExpressions, resultExpressions, child) => + case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) + if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) => + val aggregateExpressions = aggExpressions.map(expr => + expr.asInstanceOf[AggregateExpression]) val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) @@ -363,6 +371,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregateOperator + case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) + if aggExpressions.forall(expr => expr.isInstanceOf[PythonUDF]) => + val udfExpressions = aggExpressions.map(expr => expr.asInstanceOf[PythonUDF]) + + Seq(execution.python.AggregateInPandasExec( + groupingExpressions, + udfExpressions, + resultExpressions, + planLater(child))) + + case PhysicalAggregation(_, _, _, _) => + // If cannot match the two cases above, then it's an error + throw new AnalysisException( + "Cannot use a mixture of aggregate function and group aggregate pandas UDF") + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala new file mode 100644 index 0000000000000..18e5f8605c60d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.util.Utils + +/** + * Physical node for aggregation with group aggregate Pandas UDF. + * + * This plan works by sending the necessary (projected) input grouped data as Arrow record batches + * to the python worker, the python worker invokes the UDF and sends the results to the executor, + * finally the executor evaluates any post-aggregation expressions and join the result with the + * grouped key. + */ +case class AggregateInPandasExec( + groupingExpressions: Seq[NamedExpression], + udfExpressions: Seq[PythonUDF], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryExecNode { + + override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingExpressions.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + + val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + + // Schema of input rows to the python runner + val aggInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }) + + inputRDD.mapPartitionsInternal { iter => + val prunedProj = UnsafeProjection.create(allInputs, child.output) + + val grouped = if (groupingExpressions.isEmpty) { + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) + } else { + GroupedIterator(iter, groupingExpressions, child.output) + }.map { case (key, rows) => + (key, rows.map(prunedProj)) + } + + val context = TaskContext.get() + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) + context.addTaskCompletionListener { _ => + queue.close() + } + + // Add rows to queue to join later with the result. + val projectedRowIter = grouped.map { case (groupingKey, rows) => + queue.add(groupingKey.asInstanceOf[UnsafeRow]) + rows + } + + val columnarBatchIter = new ArrowPythonRunner( + pyFuncs, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, argOffsets, aggInputSchema, + sessionLocalTimeZone, pandasRespectSessionTimeZone) + .compute(projectedRowIter, context.partitionId(), context) + + val joinedAttributes = + groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes) + + columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, aggOutputRow) + resultProj(joinedRow) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 2f53fe788c7d0..1862e3f6e12ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -39,12 +39,13 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { e.isInstanceOf[AggregateExpression] || + PythonUDF.isGroupAggPandasUDF(e) || agg.groupingExpressions.exists(_.semanticEquals(e)) } private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = { expr.find { - e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined + e => PythonUDF.isScalarPythonUDF(e) && e.find(belongAggregate(_, agg)).isDefined }.isDefined } @@ -93,7 +94,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { private def hasPythonUDF(e: Expression): Boolean = { - e.find(_.isInstanceOf[PythonUDF]).isDefined + e.find(PythonUDF.isScalarPythonUDF).isDefined } private def canEvaluateInPython(e: PythonUDF): Boolean = { @@ -106,12 +107,12 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { - case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf) + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) case e => e.children.flatMap(collectEvaluatableUDF) } def apply(plan: SparkPlan): SparkPlan = plan transformUp { - // FlatMapGroupsInPandas can be evaluated directly in python worker + // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker // Therefore we don't need to extract the UDFs case plan: FlatMapGroupsInPandasExec => plan case plan: SparkPlan => extract(plan) @@ -149,10 +150,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { - require(validUdfs.forall(udf => - udf.evalType == PythonEvalType.SQL_BATCHED_UDF || - udf.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF - ), "Can only extract scalar vectorized udf or sql batch udf") + require( + validUdfs.forall(PythonUDF.isScalarPythonUDF), + "Can only extract scalar vectorized udf or sql batch udf") val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 50dca32cb7861..f4c2d02ee9420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF} import org.apache.spark.sql.types.DataType /** From 96cb60bc33936c1aaf728a1738781073891480ff Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 23 Jan 2018 04:08:32 -0800 Subject: [PATCH 178/774] [SPARK-22465][FOLLOWUP] Update the number of partitions of default partitioner when defaultParallelism is set ## What changes were proposed in this pull request? #20002 purposed a way to safe check the default partitioner, however, if `spark.default.parallelism` is set, the defaultParallelism still could be smaller than the proper number of partitions for upstreams RDDs. This PR tries to extend the approach to address the condition when `spark.default.parallelism` is set. The requirements where the PR helps with are : - Max partitioner is not eligible since it is atleast an order smaller, and - User has explicitly set 'spark.default.parallelism', and - Value of 'spark.default.parallelism' is lower than max partitioner - Since max partitioner was discarded due to being at least an order smaller, default parallelism is worse - even though user specified. Under the rest cases, the changes should be no-op. ## How was this patch tested? Add corresponding test cases in `PairRDDFunctionsSuite` and `PartitioningSuite`. Author: Xingbo Jiang Closes #20091 from jiangxb1987/partitioner. --- .../scala/org/apache/spark/Partitioner.scala | 51 ++++++++++--------- .../org/apache/spark/PartitioningSuite.scala | 44 +++++++++++++--- .../spark/rdd/PairRDDFunctionsSuite.scala | 45 +++++++++++++++- 3 files changed, 108 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 437bbaae1968b..c940cb25d478b 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -43,17 +43,19 @@ object Partitioner { /** * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. * - * If any of the RDDs already has a partitioner, and the number of partitions of the - * partitioner is either greater than or is less than and within a single order of - * magnitude of the max number of upstream partitions, choose that one. + * If spark.default.parallelism is set, we'll use the value of SparkContext defaultParallelism + * as the default partitions number, otherwise we'll use the max number of upstream partitions. * - * Otherwise, we use a default HashPartitioner. For the number of partitions, if - * spark.default.parallelism is set, then we'll use the value from SparkContext - * defaultParallelism, otherwise we'll use the max number of upstream partitions. + * When available, we choose the partitioner from rdds with maximum number of partitions. If this + * partitioner is eligible (number of partitions within an order of maximum number of partitions + * in rdds), or has partition number higher than default partitions number - we use this + * partitioner. * - * Unless spark.default.parallelism is set, the number of partitions will be the - * same as the number of partitions in the largest upstream RDD, as this should - * be least likely to cause out-of-memory errors. + * Otherwise, we'll use a new HashPartitioner with the default partitions number. + * + * Unless spark.default.parallelism is set, the number of partitions will be the same as the + * number of partitions in the largest upstream RDD, as this should be least likely to cause + * out-of-memory errors. * * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD. */ @@ -67,31 +69,32 @@ object Partitioner { None } - if (isEligiblePartitioner(hasMaxPartitioner, rdds)) { + val defaultNumPartitions = if (rdd.context.conf.contains("spark.default.parallelism")) { + rdd.context.defaultParallelism + } else { + rdds.map(_.partitions.length).max + } + + // If the existing max partitioner is an eligible one, or its partitions number is larger + // than the default number of partitions, use the existing partitioner. + if (hasMaxPartitioner.nonEmpty && (isEligiblePartitioner(hasMaxPartitioner.get, rdds) || + defaultNumPartitions < hasMaxPartitioner.get.getNumPartitions)) { hasMaxPartitioner.get.partitioner.get } else { - if (rdd.context.conf.contains("spark.default.parallelism")) { - new HashPartitioner(rdd.context.defaultParallelism) - } else { - new HashPartitioner(rdds.map(_.partitions.length).max) - } + new HashPartitioner(defaultNumPartitions) } } /** - * Returns true if the number of partitions of the RDD is either greater - * than or is less than and within a single order of magnitude of the - * max number of upstream partitions; - * otherwise, returns false + * Returns true if the number of partitions of the RDD is either greater than or is less than and + * within a single order of magnitude of the max number of upstream partitions, otherwise returns + * false. */ private def isEligiblePartitioner( - hasMaxPartitioner: Option[RDD[_]], + hasMaxPartitioner: RDD[_], rdds: Seq[RDD[_]]): Boolean = { - if (hasMaxPartitioner.isEmpty) { - return false - } val maxPartitions = rdds.map(_.partitions.length).max - log10(maxPartitions) - log10(hasMaxPartitioner.get.getNumPartitions) < 1 + log10(maxPartitions) - log10(hasMaxPartitioner.getNumPartitions) < 1 } } diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 155ca17db726b..9206b5debf4f3 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -262,14 +262,11 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva test("defaultPartitioner") { val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150) - val rdd2 = sc - .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + val rdd2 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) .partitionBy(new HashPartitioner(10)) - val rdd3 = sc - .parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) + val rdd3 = sc.parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) .partitionBy(new HashPartitioner(100)) - val rdd4 = sc - .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + val rdd4 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) .partitionBy(new HashPartitioner(9)) val rdd5 = sc.parallelize((1 to 10).map(x => (x, x)), 11) @@ -284,7 +281,42 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva assert(partitioner3.numPartitions == rdd3.getNumPartitions) assert(partitioner4.numPartitions == rdd3.getNumPartitions) assert(partitioner5.numPartitions == rdd4.getNumPartitions) + } + test("defaultPartitioner when defaultParallelism is set") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150) + val rdd2 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(10)) + val rdd3 = sc.parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) + .partitionBy(new HashPartitioner(100)) + val rdd4 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(9)) + val rdd5 = sc.parallelize((1 to 10).map(x => (x, x)), 11) + val rdd6 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(3)) + + val partitioner1 = Partitioner.defaultPartitioner(rdd1, rdd2) + val partitioner2 = Partitioner.defaultPartitioner(rdd2, rdd3) + val partitioner3 = Partitioner.defaultPartitioner(rdd3, rdd1) + val partitioner4 = Partitioner.defaultPartitioner(rdd1, rdd2, rdd3) + val partitioner5 = Partitioner.defaultPartitioner(rdd4, rdd5) + val partitioner6 = Partitioner.defaultPartitioner(rdd5, rdd5) + val partitioner7 = Partitioner.defaultPartitioner(rdd1, rdd6) + + assert(partitioner1.numPartitions == rdd2.getNumPartitions) + assert(partitioner2.numPartitions == rdd3.getNumPartitions) + assert(partitioner3.numPartitions == rdd3.getNumPartitions) + assert(partitioner4.numPartitions == rdd3.getNumPartitions) + assert(partitioner5.numPartitions == rdd4.getNumPartitions) + assert(partitioner6.numPartitions == sc.defaultParallelism) + assert(partitioner7.numPartitions == sc.defaultParallelism) + } finally { + sc.conf.remove("spark.default.parallelism") + } } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index a39e0469272fe..47af5c3320dd9 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -322,8 +322,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } // See SPARK-22465 - test("cogroup between multiple RDD" + - " with number of partitions similar in order of magnitude") { + test("cogroup between multiple RDD with number of partitions similar in order of magnitude") { val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) val rdd2 = sc .parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) @@ -332,6 +331,48 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(joined.getNumPartitions == rdd2.getNumPartitions) } + test("cogroup between multiple RDD when defaultParallelism is set without proper partitioner") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) + val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)), 10) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == sc.defaultParallelism) + } finally { + sc.conf.remove("spark.default.parallelism") + } + } + + test("cogroup between multiple RDD when defaultParallelism is set with proper partitioner") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) + val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + .partitionBy(new HashPartitioner(10)) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == rdd2.getNumPartitions) + } finally { + sc.conf.remove("spark.default.parallelism") + } + } + + test("cogroup between multiple RDD when defaultParallelism is set; with huge number of " + + "partitions in upstream RDDs") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 1000) + val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + .partitionBy(new HashPartitioner(10)) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == rdd2.getNumPartitions) + } finally { + sc.conf.remove("spark.default.parallelism") + } + } + test("rightOuterJoin") { val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) From ee572ba8c1339d21c592001ec4f7f270005ff1cf Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Jan 2018 21:36:20 +0900 Subject: [PATCH 179/774] [SPARK-20749][SQL][FOLLOW-UP] Override prettyName for bit_length and octet_length ## What changes were proposed in this pull request? We need to override the prettyName for bit_length and octet_length for getting the expected auto-generated alias name. ## How was this patch tested? The existing tests Author: gatorsmile Closes #20358 from gatorsmile/test2.3More. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../expressions/stringExpressions.scala | 4 ++ .../sql-tests/results/operators.sql.out | 4 +- .../scalar-subquery-predicate.sql.out | 45 ++++++++++--------- 4 files changed, 30 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 39d5e4ed56628..5fa75fe348e68 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -141,7 +141,7 @@ statement (LIKE? pattern=STRING)? #showTables | SHOW TABLE EXTENDED ((FROM | IN) db=identifier)? LIKE pattern=STRING partitionSpec? #showTable - | SHOW DATABASES (LIKE? pattern=STRING)? #showDatabases + | SHOW DATABASES (LIKE? pattern=STRING)? #showDatabases | SHOW TBLPROPERTIES table=tableIdentifier ('(' key=tablePropertyKey ')')? #showTblProperties | SHOW COLUMNS (FROM | IN) tableIdentifier diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e004bfc6af473..5cf783f1a5979 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1708,6 +1708,8 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length * 8") } } + + override def prettyName: String = "bit_length" } /** @@ -1735,6 +1737,8 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitC case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") } } + + override def prettyName: String = "octet_length" } /** diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 237b618a8b904..840655b7a6447 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -425,7 +425,7 @@ struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NUL -- !query 51 select BIT_LENGTH('abc') -- !query 51 schema -struct +struct -- !query 51 output 24 @@ -449,7 +449,7 @@ struct -- !query 54 select OCTET_LENGTH('abc') -- !query 54 schema -struct +struct -- !query 54 output 3 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index a2b86db3e4f4c..dd82efba0dde1 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 29 +-- Number of queries: 27 -- !query 0 @@ -307,7 +307,8 @@ struct val1c val1d --- !query 22 + +-- !query 20 SELECT count(t1a) FROM t1 RIGHT JOIN t2 ON t1d = t2d @@ -315,13 +316,13 @@ WHERE t1a < (SELECT max(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 22 schema +-- !query 20 schema struct --- !query 22 output +-- !query 20 output 7 --- !query 23 +-- !query 21 SELECT t1a FROM t1 WHERE t1b <= (SELECT max(t2b) @@ -332,14 +333,14 @@ AND t1b >= (SELECT min(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 23 schema +-- !query 21 schema struct --- !query 23 output +-- !query 21 output val1b val1c --- !query 24 +-- !query 22 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -353,14 +354,14 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 24 schema +-- !query 22 schema struct --- !query 24 output +-- !query 22 output val1b val1c --- !query 25 +-- !query 23 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -374,9 +375,9 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 25 schema +-- !query 23 schema struct --- !query 25 output +-- !query 23 output val1a val1a val1b @@ -387,7 +388,7 @@ val1d val1d --- !query 26 +-- !query 24 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -401,16 +402,16 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 26 schema +-- !query 24 schema struct --- !query 26 output +-- !query 24 output val1a val1b val1c val1d --- !query 27 +-- !query 25 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -424,13 +425,13 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 27 schema +-- !query 25 schema struct --- !query 27 output +-- !query 25 output val1a --- !query 28 +-- !query 26 SELECT t1a FROM t1 GROUP BY t1a, t1c @@ -438,8 +439,8 @@ HAVING max(t1b) <= (SELECT max(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 28 schema +-- !query 26 schema struct --- !query 28 output +-- !query 26 output val1b val1c From bdebb8e48eafcca0382d1a3173b2f3ce969abab3 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 23 Jan 2018 10:12:13 -0800 Subject: [PATCH 180/774] [SPARK-20664][SPARK-23103][CORE] Follow-up: remove workaround for . Author: Marcelo Vanzin Closes #20353 from vanzin/SPARK-20664. --- .../apache/spark/deploy/history/FsHistoryProviderSuite.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 787de59edf465..fde5f25bce456 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -716,9 +716,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } test("SPARK-21571: clean up removes invalid history files") { - // TODO: "maxTime" becoming negative in cleanLogs() causes this test to fail, so avoid that - // until we figure out what's causing the problem. - val clock = new ManualClock(TimeUnit.DAYS.toMillis(120)) + val clock = new ManualClock() val conf = createTestConf().set(MAX_LOG_AGE_S.key, s"2d") val provider = new FsHistoryProvider(conf, clock) { override def getNewLastScanTime(): Long = clock.getTimeMillis() From dc4761fd8f0eec1d001e53837e65f7c5fe4e248d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 23 Jan 2018 12:51:40 -0800 Subject: [PATCH 181/774] [SPARK-17088][HIVE] Fix 'sharesHadoopClasses' option when creating client. Because the call to the constructor of HiveClientImpl crosses class loader boundaries, different versions of the same class (Configuration in this case) were loaded, and that caused a runtime error when instantiating the client. By using a safer type in the signature of the constructor, it's possible to avoid the problem. I considered removing 'sharesHadoopClasses', but it may still be desired (even though there are 0 users of it since it was not working). When Spark starts to support Hadoop 3, it may be necessary to use that option to load clients for older Hive metastore versions that don't know about Hadoop 3. Tested with added unit test. Author: Marcelo Vanzin Closes #20169 from vanzin/SPARK-17088. --- .../spark/sql/hive/client/HiveClientImpl.scala | 8 +++++--- .../sql/hive/client/IsolatedClientLoader.scala | 16 ++++++++++------ .../sql/hive/client/HiveClientBuilder.scala | 6 ++++-- .../spark/sql/hive/client/HiveClientSuite.scala | 4 ++++ .../spark/sql/hive/client/HiveVersionSuite.scala | 11 ++++++++--- 5 files changed, 31 insertions(+), 14 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 4b923f5235a90..39d839059be75 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.hive.client import java.io.{File, PrintStream} -import java.util.Locale +import java.lang.{Iterable => JIterable} +import java.util.{Locale, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -82,8 +83,9 @@ import org.apache.spark.util.{CircularBuffer, Utils} */ private[hive] class HiveClientImpl( override val version: HiveVersion, + warehouseDir: Option[String], sparkConf: SparkConf, - hadoopConf: Configuration, + hadoopConf: JIterable[JMap.Entry[String, String]], extraConfig: Map[String, String], initClassLoader: ClassLoader, val clientLoader: IsolatedClientLoader) @@ -130,7 +132,7 @@ private[hive] class HiveClientImpl( if (ret != null) { // hive.metastore.warehouse.dir is determined in SharedState after the CliSessionState // instance constructed, we need to follow that change here. - Option(hadoopConf.get(ConfVars.METASTOREWAREHOUSE.varname)).foreach { dir => + warehouseDir.foreach { dir => ret.getConf.setVar(ConfVars.METASTOREWAREHOUSE, dir) } ret diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 7a76fd3fd2eb3..dac0e333b63bc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -26,6 +26,7 @@ import scala.util.Try import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkSubmitUtils @@ -48,11 +49,12 @@ private[hive] object IsolatedClientLoader extends Logging { config: Map[String, String] = Map.empty, ivyPath: Option[String] = None, sharedPrefixes: Seq[String] = Seq.empty, - barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { + barrierPrefixes: Seq[String] = Seq.empty, + sharesHadoopClasses: Boolean = true): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(hiveMetastoreVersion) // We will first try to share Hadoop classes. If we cannot resolve the Hadoop artifact // with the given version, we will use Hadoop 2.6 and then will not share Hadoop classes. - var sharesHadoopClasses = true + var _sharesHadoopClasses = sharesHadoopClasses val files = if (resolvedVersions.contains((resolvedVersion, hadoopVersion))) { resolvedVersions((resolvedVersion, hadoopVersion)) } else { @@ -68,7 +70,7 @@ private[hive] object IsolatedClientLoader extends Logging { "Hadoop classes will not be shared between Spark and Hive metastore client. " + "It is recommended to set jars used by Hive metastore client through " + "spark.sql.hive.metastore.jars in the production environment.") - sharesHadoopClasses = false + _sharesHadoopClasses = false (downloadVersion(resolvedVersion, "2.6.5", ivyPath), "2.6.5") } resolvedVersions.put((resolvedVersion, actualHadoopVersion), downloadedFiles) @@ -81,7 +83,7 @@ private[hive] object IsolatedClientLoader extends Logging { execJars = files, hadoopConf = hadoopConf, config = config, - sharesHadoopClasses = sharesHadoopClasses, + sharesHadoopClasses = _sharesHadoopClasses, sharedPrefixes = sharedPrefixes, barrierPrefixes = barrierPrefixes) } @@ -249,8 +251,10 @@ private[hive] class IsolatedClientLoader( /** The isolated client interface to Hive. */ private[hive] def createClient(): HiveClient = synchronized { + val warehouseDir = Option(hadoopConf.get(ConfVars.METASTOREWAREHOUSE.varname)) if (!isolationOn) { - return new HiveClientImpl(version, sparkConf, hadoopConf, config, baseClassLoader, this) + return new HiveClientImpl(version, warehouseDir, sparkConf, hadoopConf, config, + baseClassLoader, this) } // Pre-reflective instantiation setup. logDebug("Initializing the logger to avoid disaster...") @@ -261,7 +265,7 @@ private[hive] class IsolatedClientLoader( classLoader .loadClass(classOf[HiveClientImpl].getName) .getConstructors.head - .newInstance(version, sparkConf, hadoopConf, config, classLoader, this) + .newInstance(version, warehouseDir, sparkConf, hadoopConf, config, classLoader, this) .asInstanceOf[HiveClient] } catch { case e: InvocationTargetException => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala index ae804ce7c7b07..ab73f668c6ca6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala @@ -46,13 +46,15 @@ private[client] object HiveClientBuilder { def buildClient( version: String, hadoopConf: Configuration, - extraConf: Map[String, String] = Map.empty): HiveClient = { + extraConf: Map[String, String] = Map.empty, + sharesHadoopClasses: Boolean = true): HiveClient = { IsolatedClientLoader.forVersion( hiveMetastoreVersion = version, hadoopVersion = VersionInfo.getVersion, sparkConf = new SparkConf(), hadoopConf = hadoopConf, config = buildConf(extraConf), - ivyPath = ivyPath).createClient() + ivyPath = ivyPath, + sharesHadoopClasses = sharesHadoopClasses).createClient() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index a5dfd89b3a574..f991352b207d4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -202,6 +202,10 @@ class HiveClientSuite(version: String) day1 :: day2 :: Nil) } + test("create client with sharesHadoopClasses = false") { + buildClient(new Configuration(), sharesHadoopClasses = false) + } + private def testMetastorePartitionFiltering( filterString: String, expectedDs: Seq[Int], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index bb8a4697b0a13..a70fb6464cc1d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -28,7 +28,9 @@ private[client] abstract class HiveVersionSuite(version: String) extends SparkFu override protected val enableAutoThreadAudit = false protected var client: HiveClient = null - protected def buildClient(hadoopConf: Configuration): HiveClient = { + protected def buildClient( + hadoopConf: Configuration, + sharesHadoopClasses: Boolean = true): HiveClient = { // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and // hive.metastore.schema.verification from false to true since 2.0 // For details, see the JIRA HIVE-6113 and HIVE-12463 @@ -36,8 +38,11 @@ private[client] abstract class HiveVersionSuite(version: String) extends SparkFu hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } - HiveClientBuilder - .buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) + HiveClientBuilder.buildClient( + version, + hadoopConf, + HiveUtils.formatTimeVarsForHiveClient(hadoopConf), + sharesHadoopClasses = sharesHadoopClasses) } override def suiteName: String = s"${super.suiteName}($version)" From 05839d164836e544af79c13de25802552eadd636 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Tue, 23 Jan 2018 14:11:23 -0800 Subject: [PATCH 182/774] [SPARK-22735][ML][DOC] Added VectorSizeHint docs and examples. ## What changes were proposed in this pull request? Added documentation for new transformer. Author: Bago Amirbekian Closes #20285 from MrBago/sizeHintDocs. --- docs/ml-features.md | 51 ++++++++++++ .../ml/JavaVectorSizeHintExample.java | 79 +++++++++++++++++++ .../python/ml/vector_size_hint_example.py | 57 +++++++++++++ .../examples/ml/VectorSizeHintExample.scala | 63 +++++++++++++++ 4 files changed, 250 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java create mode 100644 examples/src/main/python/ml/vector_size_hint_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 466a8fbe99cf6..3370eb3893272 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1291,6 +1291,57 @@ for more details on the API.
+## VectorSizeHint + +It can sometimes be useful to explicitly specify the size of the vectors for a column of +`VectorType`. For example, `VectorAssembler` uses size information from its input columns to +produce size information and metadata for its output column. While in some cases this information +can be obtained by inspecting the contents of the column, in a streaming dataframe the contents are +not available until the stream is started. `VectorSizeHint` allows a user to explicitly specify the +vector size for a column so that `VectorAssembler`, or other transformers that might +need to know vector size, can use that column as an input. + +To use `VectorSizeHint` a user must set the `inputCol` and `size` parameters. Applying this +transformer to a dataframe produces a new dataframe with updated metadata for `inputCol` specifying +the vector size. Downstream operations on the resulting dataframe can get this size using the +meatadata. + +`VectorSizeHint` can also take an optional `handleInvalid` parameter which controls its +behaviour when the vector column contains nulls or vectors of the wrong size. By default +`handleInvalid` is set to "error", indicating an exception should be thrown. This parameter can +also be set to "skip", indicating that rows containing invalid values should be filtered out from +the resulting dataframe, or "optimistic", indicating that the column should not be checked for +invalid values and all rows should be kept. Note that the use of "optimistic" can cause the +resulting dataframe to be in an inconsistent state, me:aning the metadata for the column +`VectorSizeHint` was applied to does not match the contents of that column. Users should take care +to avoid this kind of inconsistent state. + +
+
+ +Refer to the [VectorSizeHint Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSizeHint) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala %} +
+ +
+ +Refer to the [VectorSizeHint Java docs](api/java/org/apache/spark/ml/feature/VectorSizeHint.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java %} +
+ +
+ +Refer to the [VectorSizeHint Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorSizeHint) +for more details on the API. + +{% include_example python/ml/vector_size_hint_example.py %} +
+
+ ## QuantileDiscretizer `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java new file mode 100644 index 0000000000000..d649a2ccbaa72 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.SparkSession; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.ml.feature.VectorSizeHint; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaVectorSizeHintExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaVectorSizeHintExample") + .getOrCreate(); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) + }); + Row row0 = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); + Row row1 = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0), 0.0); + Dataset dataset = spark.createDataFrame(Arrays.asList(row0, row1), schema); + + VectorSizeHint sizeHint = new VectorSizeHint() + .setInputCol("userFeatures") + .setHandleInvalid("skip") + .setSize(3); + + Dataset datasetWithSize = sizeHint.transform(dataset); + System.out.println("Rows where 'userFeatures' is not the right size are filtered out"); + datasetWithSize.show(false); + + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + + // This dataframe can be used by downstream transformers as before + Dataset output = assembler.transform(datasetWithSize); + System.out.println("Assembled columns 'hour', 'mobile', 'userFeatures' to vector column " + + "'features'"); + output.select("features", "clicked").show(false); + // $example off$ + + spark.stop(); + } +} + diff --git a/examples/src/main/python/ml/vector_size_hint_example.py b/examples/src/main/python/ml/vector_size_hint_example.py new file mode 100644 index 0000000000000..fb77dacec629d --- /dev/null +++ b/examples/src/main/python/ml/vector_size_hint_example.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +# $example on$ +from pyspark.ml.linalg import Vectors +from pyspark.ml.feature import (VectorSizeHint, VectorAssembler) +# $example off$ +from pyspark.sql import SparkSession + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("VectorSizeHintExample")\ + .getOrCreate() + + # $example on$ + dataset = spark.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0), + (0, 18, 1.0, Vectors.dense([0.0, 10.0]), 0.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) + + sizeHint = VectorSizeHint( + inputCol="userFeatures", + handleInvalid="skip", + size=3) + + datasetWithSize = sizeHint.transform(dataset) + print("Rows where 'userFeatures' is not the right size are filtered out") + datasetWithSize.show(truncate=False) + + assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") + + # This dataframe can be used by downstream transformers as before + output = assembler.transform(datasetWithSize) + print("Assembled columns 'hour', 'mobile', 'userFeatures' to vector column 'features'") + output.select("features", "clicked").show(truncate=False) + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala new file mode 100644 index 0000000000000..688731a791f35 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint} +import org.apache.spark.ml.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SparkSession + +object VectorSizeHintExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("VectorSizeHintExample") + .getOrCreate() + + // $example on$ + val dataset = spark.createDataFrame( + Seq( + (0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0), + (0, 18, 1.0, Vectors.dense(0.0, 10.0), 0.0)) + ).toDF("id", "hour", "mobile", "userFeatures", "clicked") + + val sizeHint = new VectorSizeHint() + .setInputCol("userFeatures") + .setHandleInvalid("skip") + .setSize(3) + + val datasetWithSize = sizeHint.transform(dataset) + println("Rows where 'userFeatures' is not the right size are filtered out") + datasetWithSize.show(false) + + val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") + + // This dataframe can be used by downstream transformers as before + val output = assembler.transform(datasetWithSize) + println("Assembled columns 'hour', 'mobile', 'userFeatures' to vector column 'features'") + output.select("features", "clicked").show(false) + // $example off$ + + spark.stop() + } +} +// scalastyle:on println From 613c290336e3826111164c24319f66774b1f65a3 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Jan 2018 14:56:28 -0800 Subject: [PATCH 183/774] [SPARK-23192][SQL] Keep the Hint after Using Cached Data ## What changes were proposed in this pull request? The hint of the plan segment is lost, if the plan segment is replaced by the cached data. ```Scala val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") df2.cache() val df3 = df1.join(broadcast(df2), Seq("key"), "inner") ``` This PR is to fix it. ## How was this patch tested? Added a test Author: gatorsmile Closes #20365 from gatorsmile/fixBroadcastHintloss. --- .../apache/spark/sql/execution/CacheManager.scala | 12 ++++++++---- .../sql/execution/joins/BroadcastJoinSuite.scala | 13 +++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index b05fe49a6ac3b..432eb59d6fe57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.storage.StorageLevel @@ -170,9 +170,13 @@ class CacheManager extends Logging { def useCachedData(plan: LogicalPlan): LogicalPlan = { val newPlan = plan transformDown { case currentFragment => - lookupCachedData(currentFragment) - .map(_.cachedRepresentation.withOutput(currentFragment.output)) - .getOrElse(currentFragment) + lookupCachedData(currentFragment).map { cached => + val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output) + currentFragment match { + case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints) + case _ => cachedPlan + } + }.getOrElse(currentFragment) } newPlan transformAllExpressions { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 0bcd54e1fceab..1704bc8376f0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -109,6 +109,19 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } + test("broadcast hint is retained after using the cached data") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + df2.cache() + val df3 = df1.join(broadcast(df2), Seq("key"), "inner") + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + assert(numBroadCastHashJoin === 1) + } + } + test("broadcast hint isn't propagated after a join") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") From 44cc4daf3a03f1a220eef8ce3c86867745db9ab7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Jan 2018 16:17:09 -0800 Subject: [PATCH 184/774] [SPARK-23195][SQL] Keep the Hint of Cached Data ## What changes were proposed in this pull request? The broadcast hint of the cached plan is lost if we cache the plan. This PR is to correct it. ```Scala val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") broadcast(df2).cache() df2.collect() val df3 = df1.join(df2, Seq("key"), "inner") ``` ## How was this patch tested? Added a test. Author: gatorsmile Closes #20368 from gatorsmile/cachedBroadcastHint. --- .../execution/columnar/InMemoryRelation.scala | 4 ++-- .../sql/execution/joins/BroadcastJoinSuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 51928d914841e..5945808c4abfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -63,7 +63,7 @@ case class InMemoryRelation( tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics = null) + statsOfPlanToCache: Statistics) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -77,7 +77,7 @@ case class InMemoryRelation( // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache statsOfPlanToCache } else { - Statistics(sizeInBytes = batchStats.value.longValue) + Statistics(sizeInBytes = batchStats.value.longValue, hints = statsOfPlanToCache.hints) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 1704bc8376f0d..889cab0489534 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -139,6 +139,22 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } + test("broadcast hint is retained in a cached plan") { + Seq(true, false).foreach { materialized => + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + broadcast(df2).cache() + if (materialized) df2.collect() + val df3 = df1.join(df2, Seq("key"), "inner") + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + assert(numBroadCastHashJoin === 1) + } + } + } + private def assertBroadcastJoin(df : Dataset[Row]) : Unit = { val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val joined = df1.join(df, Seq("key"), "inner") From 15adcc8273e73352e5e1c3fc9915c0b004ec4836 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 23 Jan 2018 16:24:20 -0800 Subject: [PATCH 185/774] [SPARK-23197][DSTREAMS] Increased timeouts to resolve flakiness ## What changes were proposed in this pull request? Increased timeout from 50 ms to 300 ms (50 ms was really too low). ## How was this patch tested? Multiple rounds of tests. Author: Tathagata Das Closes #20371 from tdas/SPARK-23197. --- .../scala/org/apache/spark/streaming/ReceiverSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 145c48e5a9a72..fc6218a33f741 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -105,13 +105,13 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { assert(executor.errors.head.eq(exception)) // Verify restarting actually stops and starts the receiver - receiver.restart("restarting", null, 100) - eventually(timeout(50 millis), interval(10 millis)) { + receiver.restart("restarting", null, 600) + eventually(timeout(300 millis), interval(10 millis)) { // receiver will be stopped async assert(receiver.isStopped) assert(receiver.onStopCalled) } - eventually(timeout(1000 millis), interval(100 millis)) { + eventually(timeout(1000 millis), interval(10 millis)) { // receiver will be started async assert(receiver.onStartCalled) assert(executor.isReceiverStarted) From a3911cf896de6e9386042ae4d93632cba69eef0f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 Jan 2018 11:43:48 +0900 Subject: [PATCH 186/774] [SPARK-23177][SQL][PYSPARK] Extract zero-parameter UDFs from aggregate ## What changes were proposed in this pull request? We extract Python UDFs in logical aggregate which depends on aggregate expression or grouping key in ExtractPythonUDFFromAggregate rule. But Python UDFs which don't depend on above expressions should also be extracted to avoid the issue reported in the JIRA. A small code snippet to reproduce that issue looks like: ```python import pyspark.sql.functions as f df = spark.createDataFrame([(1,2), (3,4)]) f_udf = f.udf(lambda: str("const_str")) df2 = df.distinct().withColumn("a", f_udf()) df2.show() ``` Error exception is raised as: ``` : org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: pythonUDF0#50 at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:91) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:90) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70) at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:266) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306) at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187) at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304) at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272) at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:256) at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:90) at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$38.apply(HashAggregateExec.scala:514) at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$38.apply(HashAggregateExec.scala:513) ``` This exception raises because `HashAggregateExec` tries to bind the aliased Python UDF expression (e.g., `pythonUDF0#50 AS a#44`) to grouping key. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20360 from viirya/SPARK-23177. --- python/pyspark/sql/tests.py | 8 ++++++++ .../spark/sql/execution/python/ExtractPythonUDFs.scala | 5 +++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 84e8eec71dd8a..a466ab87d882d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1106,6 +1106,14 @@ def myudf(x): rows = [r[0] for r in df.selectExpr("udf(id)").take(2)] self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)]) + def test_nonparam_udf_with_aggregate(self): + import pyspark.sql.functions as f + + df = self.spark.createDataFrame([(1, 2), (1, 2)]) + f_udf = f.udf(lambda: "const_str") + rows = df.distinct().withColumn("a", f_udf()).collect() + self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')]) + def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 1862e3f6e12ca..4ae4e164830be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} /** * Extracts all the Python UDFs in logical aggregate, which depends on aggregate expression or - * grouping key, evaluate them after aggregate. + * grouping key, or doesn't depend on any above expressions, evaluate them after aggregate. */ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { @@ -45,7 +45,8 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = { expr.find { - e => PythonUDF.isScalarPythonUDF(e) && e.find(belongAggregate(_, agg)).isDefined + e => PythonUDF.isScalarPythonUDF(e) && + (e.references.isEmpty || e.find(belongAggregate(_, agg)).isDefined) }.isDefined } From f54b65c15a732540f7a41a9083eeb7a08feca125 Mon Sep 17 00:00:00 2001 From: neilalex Date: Tue, 23 Jan 2018 22:31:14 -0800 Subject: [PATCH 187/774] [SPARK-21727][R] Allow multi-element atomic vector as column type in SparkR DataFrame MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? A fix to https://issues.apache.org/jira/browse/SPARK-21727, "Operating on an ArrayType in a SparkR DataFrame throws error" ## How was this patch tested? - Ran tests at R\pkg\tests\run-all.R (see below attached results) - Tested the following lines in SparkR, which now seem to execute without error: ``` indices <- 1:4 myDf <- data.frame(indices) myDf$data <- list(rep(0, 20)) mySparkDf <- as.DataFrame(myDf) collect(mySparkDf) ``` [2018-01-22 SPARK-21727 Test Results.txt](https://github.com/apache/spark/files/1653535/2018-01-22.SPARK-21727.Test.Results.txt) felixcheung yanboliang sun-rui shivaram _The contribution is my original work and I license the work to the project under the project’s open source license_ Author: neilalex Closes #20352 from neilalex/neilalex-sparkr-arraytype. --- R/pkg/R/serialize.R | 11 +++---- R/pkg/tests/fulltests/test_Serde.R | 47 ++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 3bbf60d9b668c..263b9b576c0c5 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -30,14 +30,17 @@ # POSIXct,POSIXlt -> Time # # list[T] -> Array[T], where T is one of above mentioned types +# Multi-element vector of any of the above (except raw) -> Array[T] # environment -> Map[String, T], where T is a native type # jobj -> Object, where jobj is an object created in the backend # nolint end getSerdeType <- function(object) { type <- class(object)[[1]] - if (type != "list") { - type + if (is.atomic(object) & !is.raw(object) & length(object) > 1) { + "array" + } else if (type != "list") { + type } else { # Check if all elements are of same type elemType <- unique(sapply(object, function(elem) { getSerdeType(elem) })) @@ -50,9 +53,7 @@ getSerdeType <- function(object) { } writeObject <- function(con, object, writeType = TRUE) { - # NOTE: In R vectors have same type as objects. So we don't support - # passing in vectors as arrays and instead require arrays to be passed - # as lists. + # NOTE: In R vectors have same type as objects type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") # Checking types is needed here, since 'is.na' only handles atomic vectors, # lists and pairlists diff --git a/R/pkg/tests/fulltests/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R index 6bbd201bf1d82..3577929323b8b 100644 --- a/R/pkg/tests/fulltests/test_Serde.R +++ b/R/pkg/tests/fulltests/test_Serde.R @@ -37,6 +37,53 @@ test_that("SerDe of primitive types", { expect_equal(class(x), "character") }) +test_that("SerDe of multi-element primitive vectors inside R data.frame", { + # vector of integers embedded in R data.frame + indices <- 1L:3L + myDf <- data.frame(indices) + myDf$data <- list(rep(0L, 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep(0L, 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "integer") + + # vector of numeric embedded in R data.frame + myDf <- data.frame(indices) + myDf$data <- list(rep(0, 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep(0, 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "numeric") + + # vector of logical embedded in R data.frame + myDf <- data.frame(indices) + myDf$data <- list(rep(TRUE, 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep(TRUE, 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "logical") + + # vector of character embedded in R data.frame + myDf <- data.frame(indices) + myDf$data <- list(rep("abc", 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep("abc", 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "character") +}) + test_that("SerDe of list of primitive types", { x <- list(1L, 2L, 3L) y <- callJStatic("SparkRHandler", "echo", x) From 4e7b49041aceca0beafec20f697b63a473a2b42f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Jan 2018 22:38:20 -0800 Subject: [PATCH 188/774] Revert "[SPARK-23195][SQL] Keep the Hint of Cached Data" This reverts commit 44cc4daf3a03f1a220eef8ce3c86867745db9ab7. --- .../execution/columnar/InMemoryRelation.scala | 4 ++-- .../sql/execution/joins/BroadcastJoinSuite.scala | 16 ---------------- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 5945808c4abfb..51928d914841e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -63,7 +63,7 @@ case class InMemoryRelation( tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics) + statsOfPlanToCache: Statistics = null) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -77,7 +77,7 @@ case class InMemoryRelation( // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache statsOfPlanToCache } else { - Statistics(sizeInBytes = batchStats.value.longValue, hints = statsOfPlanToCache.hints) + Statistics(sizeInBytes = batchStats.value.longValue) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 889cab0489534..1704bc8376f0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -139,22 +139,6 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } - test("broadcast hint is retained in a cached plan") { - Seq(true, false).foreach { materialized => - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") - broadcast(df2).cache() - if (materialized) df2.collect() - val df3 = df1.join(df2, Seq("key"), "inner") - val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { - case b: BroadcastHashJoinExec => b - }.size - assert(numBroadCastHashJoin === 1) - } - } - } - private def assertBroadcastJoin(df : Dataset[Row]) : Unit = { val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val joined = df1.join(df, Seq("key"), "inner") From 7af1a325da57daa2e25c713472a320f4ccb43d71 Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Wed, 24 Jan 2018 21:13:47 +0900 Subject: [PATCH 189/774] [SPARK-23174][BUILD][PYTHON] python code style checker update ## What changes were proposed in this pull request? Referencing latest python code style checking from PyPi/pycodestyle Removed pending TODO For now, in tox.ini excluded the additional style error discovered on existing python due to latest style checker (will fallback on review comment to finalize exclusion or fix py) Any further code styling requirement needs to be part of pycodestyle, not in SPARK. ## How was this patch tested? ./dev/run-tests Author: Rekha Joshi Author: rjoshi2 Closes #20338 from rekhajoshm/SPARK-11222. --- dev/lint-python | 37 ++++++++++++++++++------------------- dev/run-tests.py | 5 ++++- dev/tox.ini | 4 ++-- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/dev/lint-python b/dev/lint-python index df8df037a5f69..e069cafa1b8c6 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -21,7 +21,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" # Exclude auto-generated configuration file. PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . -name "*.py" )" -PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" +PYCODESTYLE_REPORT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" SPHINXBUILD=${SPHINXBUILD:=sphinx-build} @@ -30,23 +30,22 @@ SPHINX_REPORT_PATH="$SPARK_ROOT_DIR/dev/sphinx-report.txt" cd "$SPARK_ROOT_DIR" # compileall: https://docs.python.org/2/library/compileall.html -python -B -m compileall -q -l $PATHS_TO_CHECK > "$PEP8_REPORT_PATH" +python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYCODESTYLE_REPORT_PATH" compile_status="${PIPESTATUS[0]}" -# Get pep8 at runtime so that we don't rely on it being installed on the build server. +# Get pycodestyle at runtime so that we don't rely on it being installed on the build server. #+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 -#+ TODOs: -#+ - Download pep8 from PyPI. It's more "official". -PEP8_VERSION="1.7.0" -PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8-$PEP8_VERSION.py" -PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/$PEP8_VERSION/pep8.py" +# Updated to latest official version for pep8. pep8 is formally renamed to pycodestyle. +PYCODESTYLE_VERSION="2.3.1" +PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$PYCODESTYLE_VERSION.py" +PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$PYCODESTYLE_VERSION/pycodestyle.py" -if [ ! -e "$PEP8_SCRIPT_PATH" ]; then - curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" +if [ ! -e "$PYCODESTYLE_SCRIPT_PATH" ]; then + curl --silent -o "$PYCODESTYLE_SCRIPT_PATH" "$PYCODESTYLE_SCRIPT_REMOTE_PATH" curl_status="$?" if [ "$curl_status" -ne 0 ]; then - echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"." + echo "Failed to download pycodestyle.py from \"$PYCODESTYLE_SCRIPT_REMOTE_PATH\"." exit "$curl_status" fi fi @@ -64,23 +63,23 @@ export "PATH=$PYTHONPATH:$PATH" #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python "$PEP8_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH" -pep8_status="${PIPESTATUS[0]}" +python "$PYCODESTYLE_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PYCODESTYLE_REPORT_PATH" +pycodestyle_status="${PIPESTATUS[0]}" -if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then +if [ "$compile_status" -eq 0 -a "$pycodestyle_status" -eq 0 ]; then lint_status=0 else lint_status=1 fi if [ "$lint_status" -ne 0 ]; then - echo "PEP8 checks failed." - cat "$PEP8_REPORT_PATH" - rm "$PEP8_REPORT_PATH" + echo "PYCODESTYLE checks failed." + cat "$PYCODESTYLE_REPORT_PATH" + rm "$PYCODESTYLE_REPORT_PATH" exit "$lint_status" else - echo "PEP8 checks passed." - rm "$PEP8_REPORT_PATH" + echo "pycodestyle checks passed." + rm "$PYCODESTYLE_REPORT_PATH" fi # Check that the documentation builds acceptably, skip check if sphinx is not installed. diff --git a/dev/run-tests.py b/dev/run-tests.py index fb270c4ee0508..fe75ef4411c8c 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -576,7 +576,10 @@ def main(): for f in changed_files): # run_java_style_checks() pass - if not changed_files or any(f.endswith(".py") for f in changed_files): + if not changed_files or any(f.endswith("lint-python") + or f.endswith("tox.ini") + or f.endswith(".py") + for f in changed_files): run_python_style_checks() if not changed_files or any(f.endswith(".R") or f.endswith("lint-r") diff --git a/dev/tox.ini b/dev/tox.ini index eb8b1eb2c2886..583c1eaaa966b 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -[pep8] -ignore=E402,E731,E241,W503,E226 +[pycodestyle] +ignore=E402,E731,E241,W503,E226,E722,E741,E305 max-line-length=100 exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/* From de36f65d3a819c00d6bf6979deef46c824203669 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Wed, 24 Jan 2018 21:19:09 +0900 Subject: [PATCH 190/774] [SPARK-23148][SQL] Allow pathnames with special characters for CSV / JSON / text MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …JSON / text ## What changes were proposed in this pull request? Fix for JSON and CSV data sources when file names include characters that would be changed by URL encoding. ## How was this patch tested? New unit tests for JSON, CSV and text suites Author: Henry Robinson Closes #20355 from henryr/spark-23148. --- .../execution/datasources/CodecStreams.scala | 6 +++--- .../datasources/csv/CSVDataSource.scala | 11 ++++++----- .../datasources/json/JsonDataSource.scala | 10 ++++++---- .../spark/sql/FileBasedDataSourceSuite.scala | 18 ++++++++++++++++-- 4 files changed, 31 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index 54549f698aca5..c0df6c779d7bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -45,11 +45,11 @@ object CodecStreams { } /** - * Creates an input stream from the string path and add a closure for the input stream to be + * Creates an input stream from the given path and add a closure for the input stream to be * closed on task completion. */ - def createInputStreamWithCloseResource(config: Configuration, path: String): InputStream = { - val inputStream = createInputStream(config, new Path(path)) + def createInputStreamWithCloseResource(config: Configuration, path: Path): InputStream = { + val inputStream = createInputStream(config, path) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) inputStream } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 2031381dd2e10..4870d75fc5f08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.datasources.csv +import java.net.URI import java.nio.charset.{Charset, StandardCharsets} import com.univocity.parsers.csv.CsvParser import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.Job @@ -32,7 +33,6 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -206,7 +206,7 @@ object MultiLineCSVDataSource extends CSVDataSource { parser: UnivocityParser, schema: StructType): Iterator[InternalRow] = { UnivocityParser.parseStream( - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), + CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))), parser.options.headerFlag, parser, schema) @@ -218,8 +218,9 @@ object MultiLineCSVDataSource extends CSVDataSource { parsedOptions: CSVOptions): StructType = { val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions) csv.flatMap { lines => + val path = new Path(lines.getPath()) UnivocityParser.tokenizeStream( - CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), + CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, path), shouldDropHeader = false, new CsvParser(parsedOptions.asParserSettings)) }.take(1).headOption match { @@ -230,7 +231,7 @@ object MultiLineCSVDataSource extends CSVDataSource { UnivocityParser.tokenizeStream( CodecStreams.createInputStreamWithCloseResource( lines.getConfiguration, - lines.getPath()), + new Path(lines.getPath())), parsedOptions.headerFlag, new CsvParser(parsedOptions.asParserSettings)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 8b7c2709afde1..77e7edc8e7a20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.execution.datasources.json import java.io.InputStream +import java.net.URI import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.FileInputFormat @@ -168,9 +169,10 @@ object MultiLineJsonDataSource extends JsonDataSource { } private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { + val path = new Path(record.getPath()) CreateJacksonParser.inputStream( jsonFactory, - CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath())) + CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path)) } override def readFile( @@ -180,7 +182,7 @@ object MultiLineJsonDataSource extends JsonDataSource { schema: StructType): Iterator[InternalRow] = { def partitionedFileString(ignored: Any): UTF8String = { Utils.tryWithResource { - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath) + CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))) } { inputStream => UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) } @@ -193,6 +195,6 @@ object MultiLineJsonDataSource extends JsonDataSource { parser.options.columnNameOfCorruptRecord) safeParser.parse( - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)) + CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath)))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 22fb496bc838e..c272c99ae45a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -23,6 +23,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { import testImplicits._ private val allFileBasedDataSources = Seq("orc", "parquet", "csv", "json", "text") + private val nameWithSpecialChars = "sp&cial%c hars" allFileBasedDataSources.foreach { format => test(s"Writing empty datasets should not fail - $format") { @@ -54,7 +55,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { // Only ORC/Parquet support this. `CSV` and `JSON` returns an empty schema. // `TEXT` data source always has a single column whose name is `value`. Seq("orc", "parquet").foreach { format => - test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { + test(s"SPARK-15474 Write and read back non-empty schema with empty dataframe - $format") { withTempPath { file => val path = file.getCanonicalPath val emptyDf = Seq((true, 1, "str")).toDF().limit(0) @@ -69,7 +70,6 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { allFileBasedDataSources.foreach { format => test(s"SPARK-22146 read files containing special characters using $format") { - val nameWithSpecialChars = s"sp&cial%chars" withTempDir { dir => val tmpFile = s"$dir/$nameWithSpecialChars" spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) @@ -78,4 +78,18 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { } } } + + // Separate test case for formats that support multiLine as an option. + Seq("json", "csv").foreach { format => + test("SPARK-23148 read files containing special characters " + + s"using $format with multiline enabled") { + withTempDir { dir => + val tmpFile = s"$dir/$nameWithSpecialChars" + spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) + val reader = spark.read.format(format).option("multiLine", true) + val fileContent = reader.load(tmpFile) + checkAnswer(fileContent, Seq(Row("a"), Row("b"))) + } + } + } } From 0ec95bb7df775be33fc8983f6c0983a67032d2c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Wed, 24 Jan 2018 11:34:59 -0600 Subject: [PATCH 191/774] [SPARK-22577][CORE] executor page blacklist status should update with TaskSet level blacklisting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In this PR stage blacklisting is propagated to UI by introducing a new Spark listener event (SparkListenerExecutorBlacklistedForStage) which indicates the executor is blacklisted for a stage. Either because of the number of failures are exceeded a limit given for an executor (spark.blacklist.stage.maxFailedTasksPerExecutor) or because of the whole node is blacklisted for a stage (spark.blacklist.stage.maxFailedExecutorsPerNode). In case of the node is blacklisting all executors will listed as blacklisted for the stage. Blacklisting state for a selected stage can be seen "Aggregated Metrics by Executor" table's blacklisting column, where after this change three possible labels could be found: - "for application": when the executor is blacklisted for the application (see the configuration spark.blacklist.application.maxFailedTasksPerExecutor for details) - "for stage": when the executor is **only** blacklisted for the stage - "false" : when the executor is not blacklisted at all ## How was this patch tested? It is tested both manually and with unit tests. #### Unit tests - HistoryServerSuite - TaskSetBlacklistSuite - AppStatusListenerSuite #### Manual test for executor blacklisting Running Spark as a local cluster: ``` $ bin/spark-shell --master "local-cluster[2,1,1024]" --conf "spark.blacklist.enabled=true" --conf "spark.blacklist.stage.maxFailedTasksPerExecutor=1" --conf "spark.blacklist.application.maxFailedTasksPerExecutor=10" --conf "spark.eventLog.enabled=true" ``` Executing: ``` scala import org.apache.spark.SparkEnv sc.parallelize(1 to 10, 10).map { x => if (SparkEnv.get.executorId == "0") throw new RuntimeException("Bad executor") else (x % 3, x) }.reduceByKey((a, b) => a + b).collect() ``` To see result check the "Aggregated Metrics by Executor" section at the bottom of picture: ![UI screenshot for stage level blacklisting executor](https://issues.apache.org/jira/secure/attachment/12905283/stage_blacklisting.png) #### Manual test for node blacklisting Running Spark as on a cluster: ``` bash ./bin/spark-shell --master yarn --deploy-mode client --executor-memory=2G --num-executors=8 --conf "spark.blacklist.enabled=true" --conf "spark.blacklist.stage.maxFailedTasksPerExecutor=1" --conf "spark.blacklist.stage.maxFailedExecutorsPerNode=1" --conf "spark.blacklist.application.maxFailedTasksPerExecutor=10" --conf "spark.eventLog.enabled=true" ``` And the job was: ``` scala import org.apache.spark.SparkEnv sc.parallelize(1 to 10000, 10).map { x => if (SparkEnv.get.executorId.toInt >= 4) throw new RuntimeException("Bad executor") else (x % 3, x) }.reduceByKey((a, b) => a + b).collect() ``` The result is: ![UI screenshot for stage level node blacklisting](https://issues.apache.org/jira/secure/attachment/12906833/node_blacklisting_for_stage.png) Here you can see apiros3.gce.test.com was node blacklisted for the stage because of failures on executor 4 and 5. As expected executor 3 is also blacklisted even it has no failures itself but sharing the node with 4 and 5. Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #20203 from attilapiros/SPARK-22577. --- .../apache/spark/SparkFirehoseListener.java | 12 + .../scheduler/EventLoggingListener.scala | 9 + .../spark/scheduler/SparkListener.scala | 35 + .../spark/scheduler/SparkListenerBus.scala | 4 + .../spark/scheduler/TaskSetBlacklist.scala | 19 +- .../spark/scheduler/TaskSetManager.scala | 2 +- .../spark/status/AppStatusListener.scala | 25 + .../org/apache/spark/status/LiveEntity.scala | 4 +- .../org/apache/spark/status/api/v1/api.scala | 3 +- .../apache/spark/ui/jobs/ExecutorTable.scala | 10 +- .../application_list_json_expectation.json | 70 +- .../blacklisting_for_stage_expectation.json | 639 ++++++++++++++ ...acklisting_node_for_stage_expectation.json | 783 ++++++++++++++++++ .../completed_app_list_json_expectation.json | 71 +- .../limit_app_list_json_expectation.json | 54 +- .../minDate_app_list_json_expectation.json | 62 +- .../minEndDate_app_list_json_expectation.json | 34 +- .../one_stage_attempt_json_expectation.json | 3 +- .../one_stage_json_expectation.json | 3 +- ...age_with_accumulable_json_expectation.json | 3 +- .../spark-events/app-20180109111548-0000 | 59 ++ .../application_1516285256255_0012 | 71 ++ .../deploy/history/HistoryServerSuite.scala | 2 + .../scheduler/BlacklistTrackerSuite.scala | 2 +- .../scheduler/TaskSetBlacklistSuite.scala | 119 ++- .../spark/status/AppStatusListenerSuite.scala | 43 + dev/.rat-excludes | 2 + 27 files changed, 2040 insertions(+), 103 deletions(-) create mode 100644 core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json create mode 100644 core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json create mode 100755 core/src/test/resources/spark-events/app-20180109111548-0000 create mode 100755 core/src/test/resources/spark-events/application_1516285256255_0012 diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 3583856d88998..94c5c11b61a50 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -118,6 +118,18 @@ public final void onExecutorBlacklisted(SparkListenerExecutorBlacklisted executo onEvent(executorBlacklisted); } + @Override + public void onExecutorBlacklistedForStage( + SparkListenerExecutorBlacklistedForStage executorBlacklistedForStage) { + onEvent(executorBlacklistedForStage); + } + + @Override + public void onNodeBlacklistedForStage( + SparkListenerNodeBlacklistedForStage nodeBlacklistedForStage) { + onEvent(nodeBlacklistedForStage); + } + @Override public final void onExecutorUnblacklisted( SparkListenerExecutorUnblacklisted executorUnblacklisted) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index b3a5b1f1e05b3..69bc51c1ecf90 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -207,6 +207,15 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) } + override def onExecutorBlacklistedForStage( + event: SparkListenerExecutorBlacklistedForStage): Unit = { + logEvent(event, flushLogger = true) + } + + override def onNodeBlacklistedForStage(event: SparkListenerNodeBlacklistedForStage): Unit = { + logEvent(event, flushLogger = true) + } + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { logEvent(event, flushLogger = true) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 3b677ca9657db..8a112f6a37b96 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -120,6 +120,24 @@ case class SparkListenerExecutorBlacklisted( taskFailures: Int) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerExecutorBlacklistedForStage( + time: Long, + executorId: String, + taskFailures: Int, + stageId: Int, + stageAttemptId: Int) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerNodeBlacklistedForStage( + time: Long, + hostId: String, + executorFailures: Int, + stageId: Int, + stageAttemptId: Int) + extends SparkListenerEvent + @DeveloperApi case class SparkListenerExecutorUnblacklisted(time: Long, executorId: String) extends SparkListenerEvent @@ -261,6 +279,17 @@ private[spark] trait SparkListenerInterface { */ def onExecutorBlacklisted(executorBlacklisted: SparkListenerExecutorBlacklisted): Unit + /** + * Called when the driver blacklists an executor for a stage. + */ + def onExecutorBlacklistedForStage( + executorBlacklistedForStage: SparkListenerExecutorBlacklistedForStage): Unit + + /** + * Called when the driver blacklists a node for a stage. + */ + def onNodeBlacklistedForStage(nodeBlacklistedForStage: SparkListenerNodeBlacklistedForStage): Unit + /** * Called when the driver re-enables a previously blacklisted executor. */ @@ -339,6 +368,12 @@ abstract class SparkListener extends SparkListenerInterface { override def onExecutorBlacklisted( executorBlacklisted: SparkListenerExecutorBlacklisted): Unit = { } + def onExecutorBlacklistedForStage( + executorBlacklistedForStage: SparkListenerExecutorBlacklistedForStage): Unit = { } + + def onNodeBlacklistedForStage( + nodeBlacklistedForStage: SparkListenerNodeBlacklistedForStage): Unit = { } + override def onExecutorUnblacklisted( executorUnblacklisted: SparkListenerExecutorUnblacklisted): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 056c0cbded435..ff19cc65552e0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,6 +61,10 @@ private[spark] trait SparkListenerBus listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => listener.onExecutorRemoved(executorRemoved) + case executorBlacklistedForStage: SparkListenerExecutorBlacklistedForStage => + listener.onExecutorBlacklistedForStage(executorBlacklistedForStage) + case nodeBlacklistedForStage: SparkListenerNodeBlacklistedForStage => + listener.onNodeBlacklistedForStage(nodeBlacklistedForStage) case executorBlacklisted: SparkListenerExecutorBlacklisted => listener.onExecutorBlacklisted(executorBlacklisted) case executorUnblacklisted: SparkListenerExecutorUnblacklisted => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala index 233781f3d9719..b680979a466a5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala @@ -36,8 +36,12 @@ import org.apache.spark.util.Clock * [[TaskSetManager]] this class is designed only to be called from code with a lock on the * TaskScheduler (e.g. its event handlers). It should not be called from other threads. */ -private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, val clock: Clock) - extends Logging { +private[scheduler] class TaskSetBlacklist( + private val listenerBus: LiveListenerBus, + val conf: SparkConf, + val stageId: Int, + val stageAttemptId: Int, + val clock: Clock) extends Logging { private val MAX_TASK_ATTEMPTS_PER_EXECUTOR = conf.get(config.MAX_TASK_ATTEMPTS_PER_EXECUTOR) private val MAX_TASK_ATTEMPTS_PER_NODE = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE) @@ -128,16 +132,23 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, } // Check if enough tasks have failed on the executor to blacklist it for the entire stage. - if (execFailures.numUniqueTasksWithFailures >= MAX_FAILURES_PER_EXEC_STAGE) { + val numFailures = execFailures.numUniqueTasksWithFailures + if (numFailures >= MAX_FAILURES_PER_EXEC_STAGE) { if (blacklistedExecs.add(exec)) { logInfo(s"Blacklisting executor ${exec} for stage $stageId") // This executor has been pushed into the blacklist for this stage. Let's check if it // pushes the whole node into the blacklist. val blacklistedExecutorsOnNode = execsWithFailuresOnNode.filter(blacklistedExecs.contains(_)) - if (blacklistedExecutorsOnNode.size >= MAX_FAILED_EXEC_PER_NODE_STAGE) { + val now = clock.getTimeMillis() + listenerBus.post( + SparkListenerExecutorBlacklistedForStage(now, exec, numFailures, stageId, stageAttemptId)) + val numFailExec = blacklistedExecutorsOnNode.size + if (numFailExec >= MAX_FAILED_EXEC_PER_NODE_STAGE) { if (blacklistedNodes.add(host)) { logInfo(s"Blacklisting ${host} for stage $stageId") + listenerBus.post( + SparkListenerNodeBlacklistedForStage(now, host, numFailExec, stageId, stageAttemptId)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index c3ed11bfe352a..886c2c99f1ff3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -102,7 +102,7 @@ private[spark] class TaskSetManager( private[scheduler] val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = { blacklistTracker.map { _ => - new TaskSetBlacklist(conf, stageId, clock) + new TaskSetBlacklist(sched.sc.listenerBus, conf, stageId, taskSet.stageAttemptId, clock) } } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index b4edcf23abc09..3e34bdc0c7b63 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -211,6 +211,31 @@ private[spark] class AppStatusListener( updateBlackListStatus(event.executorId, true) } + override def onExecutorBlacklistedForStage( + event: SparkListenerExecutorBlacklistedForStage): Unit = { + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => + val now = System.nanoTime() + val esummary = stage.executorSummary(event.executorId) + esummary.isBlacklisted = true + maybeUpdate(esummary, now) + } + } + + override def onNodeBlacklistedForStage(event: SparkListenerNodeBlacklistedForStage): Unit = { + val now = System.nanoTime() + + // Implicitly blacklist every available executor for the stage associated with this node + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => + liveExecutors.values.foreach { exec => + if (exec.hostname == event.hostId) { + val esummary = stage.executorSummary(exec.executorId) + esummary.isBlacklisted = true + maybeUpdate(esummary, now) + } + } + } + } + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { updateBlackListStatus(event.executorId, false) } diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 4295e664e131c..d5f9e19ffdcd0 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -316,6 +316,7 @@ private class LiveExecutorStageSummary( var succeededTasks = 0 var failedTasks = 0 var killedTasks = 0 + var isBlacklisted = false var metrics = createMetrics(default = 0L) @@ -334,7 +335,8 @@ private class LiveExecutorStageSummary( metrics.shuffleWriteMetrics.bytesWritten, metrics.shuffleWriteMetrics.recordsWritten, metrics.memoryBytesSpilled, - metrics.diskBytesSpilled) + metrics.diskBytesSpilled, + isBlacklisted) new ExecutorStageSummaryWrapper(stageId, attemptId, executorId, info) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 7d8e4de3c8efb..550eac3952bbb 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -68,7 +68,8 @@ class ExecutorStageSummary private[spark]( val shuffleWrite : Long, val shuffleWriteRecords : Long, val memoryBytesSpilled : Long, - val diskBytesSpilled : Long) + val diskBytesSpilled : Long, + val isBlacklistedForStage: Boolean) class ExecutorSummary private[spark]( val id: String, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 95c12b1e73653..0ff64f053f371 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -136,7 +136,15 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { {Utils.bytesToString(v.diskBytesSpilled)} }} -
+ { + if (executor.map(_.isBlacklisted).getOrElse(false)) { + + } else if (v.isBlacklistedForStage) { + + } else { + + } + } } } diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index f2c3ec5da8891..4fecf84db65a2 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,9 +39,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479335620587 } ] }, { "id" : "app-20161115172038-0000", @@ -24,9 +54,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479252138874 } ] }, { "id" : "local-1430917381534", @@ -39,9 +69,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917391398 } ] }, { "id" : "local-1430917381535", @@ -55,9 +85,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380950 }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", @@ -67,9 +97,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380890 } ] }, { "id" : "local-1426533911241", @@ -83,9 +113,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426633945177 }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", @@ -95,9 +125,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426533945177 } ] }, { "id" : "local-1425081759269", @@ -110,9 +140,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1425081758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1425081766912 } ] }, { "id" : "local-1422981780767", @@ -125,9 +155,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981779720, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981788731 } ] }, { "id" : "local-1422981759269", @@ -140,8 +170,8 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981766912 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json new file mode 100644 index 0000000000000..5e9e8230e2745 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json @@ -0,0 +1,639 @@ +{ + "status": "COMPLETE", + "stageId": 0, + "attemptId": 0, + "numTasks": 10, + "numActiveTasks": 0, + "numCompleteTasks": 10, + "numFailedTasks": 2, + "numKilledTasks": 0, + "numCompletedIndices": 10, + "executorRunTime": 761, + "executorCpuTime": 269916000, + "submissionTime": "2018-01-09T10:21:18.152GMT", + "firstTaskLaunchedTime": "2018-01-09T10:21:18.347GMT", + "completionTime": "2018-01-09T10:21:19.062GMT", + "inputBytes": 0, + "inputRecords": 0, + "outputBytes": 0, + "outputRecords": 0, + "shuffleReadBytes": 0, + "shuffleReadRecords": 0, + "shuffleWriteBytes": 460, + "shuffleWriteRecords": 10, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "name": "map at :26", + "details": "org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)", + "schedulingPool": "default", + "rddIds": [ + 1, + 0 + ], + "accumulatorUpdates": [], + "tasks": { + "0": { + "taskId": 0, + "index": 0, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.347GMT", + "duration": 562, + "executorId": "0", + "host": "172.30.65.138", + "status": "FAILED", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "errorMessage": "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics": { + "executorDeserializeTime": 0, + "executorDeserializeCpuTime": 0, + "executorRunTime": 460, + "executorCpuTime": 0, + "resultSize": 0, + "jvmGcTime": 14, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 0, + "writeTime": 3873006, + "recordsWritten": 0 + } + } + }, + "5": { + "taskId": 5, + "index": 3, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.958GMT", + "duration": 22, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 3, + "executorDeserializeCpuTime": 2586000, + "executorRunTime": 9, + "executorCpuTime": 9635000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 262919, + "recordsWritten": 1 + } + } + }, + "10": { + "taskId": 10, + "index": 8, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.034GMT", + "duration": 12, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 1803000, + "executorRunTime": 6, + "executorCpuTime": 6157000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 243647, + "recordsWritten": 1 + } + } + }, + "1": { + "taskId": 1, + "index": 1, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.364GMT", + "duration": 565, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 301, + "executorDeserializeCpuTime": 200029000, + "executorRunTime": 212, + "executorCpuTime": 198479000, + "resultSize": 1115, + "jvmGcTime": 13, + "resultSerializationTime": 1, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 2409488, + "recordsWritten": 1 + } + } + }, + "6": { + "taskId": 6, + "index": 4, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.980GMT", + "duration": 16, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 3, + "executorDeserializeCpuTime": 2610000, + "executorRunTime": 10, + "executorCpuTime": 9622000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 385110, + "recordsWritten": 1 + } + } + }, + "9": { + "taskId": 9, + "index": 7, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.022GMT", + "duration": 12, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 1981000, + "executorRunTime": 7, + "executorCpuTime": 6335000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 259354, + "recordsWritten": 1 + } + } + }, + "2": { + "taskId": 2, + "index": 2, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.899GMT", + "duration": 27, + "executorId": "0", + "host": "172.30.65.138", + "status": "FAILED", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "errorMessage": "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics": { + "executorDeserializeTime": 0, + "executorDeserializeCpuTime": 0, + "executorRunTime": 16, + "executorCpuTime": 0, + "resultSize": 0, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 0, + "writeTime": 126128, + "recordsWritten": 0 + } + } + }, + "7": { + "taskId": 7, + "index": 5, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.996GMT", + "duration": 15, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 2231000, + "executorRunTime": 9, + "executorCpuTime": 8407000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 205520, + "recordsWritten": 1 + } + } + }, + "3": { + "taskId": 3, + "index": 0, + "attempt": 1, + "launchTime": "2018-01-09T10:21:18.919GMT", + "duration": 24, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 8, + "executorDeserializeCpuTime": 8878000, + "executorRunTime": 10, + "executorCpuTime": 9364000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 207014, + "recordsWritten": 1 + } + } + }, + "11": { + "taskId": 11, + "index": 9, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.045GMT", + "duration": 15, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 3, + "executorDeserializeCpuTime": 2017000, + "executorRunTime": 6, + "executorCpuTime": 6676000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 233652, + "recordsWritten": 1 + } + } + }, + "8": { + "taskId": 8, + "index": 6, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.011GMT", + "duration": 11, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 1, + "executorDeserializeCpuTime": 1554000, + "executorRunTime": 7, + "executorCpuTime": 6034000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 213296, + "recordsWritten": 1 + } + } + }, + "4": { + "taskId": 4, + "index": 2, + "attempt": 1, + "launchTime": "2018-01-09T10:21:18.943GMT", + "duration": 16, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 2211000, + "executorRunTime": 9, + "executorCpuTime": 9207000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 292381, + "recordsWritten": 1 + } + } + } + }, + "executorSummary": { + "0": { + "taskTime": 589, + "failedTasks": 2, + "succeededTasks": 0, + "killedTasks": 0, + "inputBytes": 0, + "inputRecords": 0, + "outputBytes": 0, + "outputRecords": 0, + "shuffleRead": 0, + "shuffleReadRecords": 0, + "shuffleWrite": 0, + "shuffleWriteRecords": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "isBlacklistedForStage": true + }, + "1": { + "taskTime": 708, + "failedTasks": 0, + "succeededTasks": 10, + "killedTasks": 0, + "inputBytes": 0, + "inputRecords": 0, + "outputBytes": 0, + "outputRecords": 0, + "shuffleRead": 0, + "shuffleReadRecords": 0, + "shuffleWrite": 460, + "shuffleWriteRecords": 10, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "isBlacklistedForStage": false + } + }, + "killedTasksSummary": {} +} diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json new file mode 100644 index 0000000000000..acd4cc53de6cd --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json @@ -0,0 +1,783 @@ +{ + "status" : "COMPLETE", + "stageId" : 0, + "attemptId" : 0, + "numTasks" : 10, + "numActiveTasks" : 0, + "numCompleteTasks" : 10, + "numFailedTasks" : 4, + "numKilledTasks" : 0, + "numCompletedIndices" : 10, + "executorRunTime" : 5080, + "executorCpuTime" : 1163210819, + "submissionTime" : "2018-01-18T18:33:12.658GMT", + "firstTaskLaunchedTime" : "2018-01-18T18:33:12.816GMT", + "completionTime" : "2018-01-18T18:33:15.279GMT", + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleReadBytes" : 0, + "shuffleReadRecords" : 0, + "shuffleWriteBytes" : 1461, + "shuffleWriteRecords" : 30, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "name" : "map at :27", + "details" : "org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)", + "schedulingPool" : "default", + "rddIds" : [ 1, 0 ], + "accumulatorUpdates" : [ ], + "tasks" : { + "0" : { + "taskId" : 0, + "index" : 0, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.816GMT", + "duration" : 2064, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1081, + "executorDeserializeCpuTime" : 353981050, + "executorRunTime" : 914, + "executorCpuTime" : 368865439, + "resultSize" : 1134, + "jvmGcTime" : 75, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 144, + "writeTime" : 3662221, + "recordsWritten" : 3 + } + } + }, + "5" : { + "taskId" : 5, + "index" : 5, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:14.320GMT", + "duration" : 73, + "executorId" : "5", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 27, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 191901, + "recordsWritten" : 0 + } + } + }, + "10" : { + "taskId" : 10, + "index" : 1, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:15.069GMT", + "duration" : 132, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 4598966, + "executorRunTime" : 76, + "executorCpuTime" : 20826337, + "resultSize" : 1091, + "jvmGcTime" : 0, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 144, + "writeTime" : 301705, + "recordsWritten" : 3 + } + } + }, + "1" : { + "taskId" : 1, + "index" : 1, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.832GMT", + "duration" : 1506, + "executorId" : "5", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 1332, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 33, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 3075188, + "recordsWritten" : 0 + } + } + }, + "6" : { + "taskId" : 6, + "index" : 6, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:14.323GMT", + "duration" : 67, + "executorId" : "4", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 51, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 183718, + "recordsWritten" : 0 + } + } + }, + "9" : { + "taskId" : 9, + "index" : 4, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:14.973GMT", + "duration" : 96, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 4793905, + "executorRunTime" : 48, + "executorCpuTime" : 25678331, + "resultSize" : 1091, + "jvmGcTime" : 0, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 366050, + "recordsWritten" : 3 + } + } + }, + "13" : { + "taskId" : 13, + "index" : 9, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:15.200GMT", + "duration" : 76, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 25, + "executorDeserializeCpuTime" : 5860574, + "executorRunTime" : 25, + "executorCpuTime" : 20585619, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 369513, + "recordsWritten" : 3 + } + } + }, + "2" : { + "taskId" : 2, + "index" : 2, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.832GMT", + "duration" : 1774, + "executorId" : "3", + "host" : "apiros-2.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1206, + "executorDeserializeCpuTime" : 263386625, + "executorRunTime" : 493, + "executorCpuTime" : 278399617, + "resultSize" : 1134, + "jvmGcTime" : 78, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 144, + "writeTime" : 3322956, + "recordsWritten" : 3 + } + } + }, + "12" : { + "taskId" : 12, + "index" : 8, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:15.165GMT", + "duration" : 60, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 4010338, + "executorRunTime" : 34, + "executorCpuTime" : 21657558, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 319101, + "recordsWritten" : 3 + } + } + }, + "7" : { + "taskId" : 7, + "index" : 5, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:14.859GMT", + "duration" : 115, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 11, + "executorDeserializeCpuTime" : 10894331, + "executorRunTime" : 84, + "executorCpuTime" : 28283110, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 377601, + "recordsWritten" : 3 + } + } + }, + "3" : { + "taskId" : 3, + "index" : 3, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.833GMT", + "duration" : 2027, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1282, + "executorDeserializeCpuTime" : 365807898, + "executorRunTime" : 681, + "executorCpuTime" : 349920830, + "resultSize" : 1134, + "jvmGcTime" : 102, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 3587839, + "recordsWritten" : 3 + } + } + }, + "11" : { + "taskId" : 11, + "index" : 7, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:15.072GMT", + "duration" : 93, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 4239884, + "executorRunTime" : 77, + "executorCpuTime" : 21689428, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 323898, + "recordsWritten" : 3 + } + } + }, + "8" : { + "taskId" : 8, + "index" : 6, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:14.879GMT", + "duration" : 194, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 56, + "executorDeserializeCpuTime" : 12246145, + "executorRunTime" : 54, + "executorCpuTime" : 27304550, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 311940, + "recordsWritten" : 3 + } + } + }, + "4" : { + "taskId" : 4, + "index" : 4, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.833GMT", + "duration" : 1522, + "executorId" : "4", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 1184, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 82, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 16858066, + "recordsWritten" : 0 + } + } + } + }, + "executorSummary" : { + "4" : { + "taskTime" : 1589, + "failedTasks" : 2, + "succeededTasks" : 0, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 0, + "shuffleWriteRecords" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true + }, + "5" : { + "taskTime" : 1579, + "failedTasks" : 2, + "succeededTasks" : 0, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 0, + "shuffleWriteRecords" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true + }, + "1" : { + "taskTime" : 2411, + "failedTasks" : 0, + "succeededTasks" : 4, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 585, + "shuffleWriteRecords" : 12, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false + }, + "2" : { + "taskTime" : 2446, + "failedTasks" : 0, + "succeededTasks" : 5, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 732, + "shuffleWriteRecords" : 15, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false + }, + "3" : { + "taskTime" : 1774, + "failedTasks" : 0, + "succeededTasks" : 1, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 144, + "shuffleWriteRecords" : 3, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true + } + }, + "killedTasksSummary" : { } +} \ No newline at end of file diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index c925c1dd8a4d3..4fecf84db65a2 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,9 +39,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479335620587 } ] }, { "id" : "app-20161115172038-0000", @@ -24,9 +54,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479252138874 } ] }, { "id" : "local-1430917381534", @@ -39,9 +69,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917391398 } ] }, { "id" : "local-1430917381535", @@ -55,9 +85,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380950 }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", @@ -67,9 +97,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380890 } ] }, { "id" : "local-1426533911241", @@ -83,9 +113,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426633945177 }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", @@ -95,9 +125,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426533945177 } ] }, { "id" : "local-1425081759269", @@ -110,10 +140,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "appSparkVersion" : "", - "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1425081758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1425081766912 } ] }, { "id" : "local-1422981780767", @@ -126,9 +155,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981779720, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981788731 } ] }, { "id" : "local-1422981759269", @@ -141,8 +170,8 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981766912 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json index cc0b2b0022bd3..79950b0dc6486 100644 --- a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json @@ -1,46 +1,46 @@ [ { - "id" : "app-20161116163331-0000", + "id" : "application_1516285256255_0012", "name" : "Spark shell", "attempts" : [ { - "startTime" : "2016-11-16T22:33:29.916GMT", - "endTime" : "2016-11-16T22:33:40.587GMT", + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", "lastUpdated" : "", - "duration" : 10671, - "sparkUser" : "jose", + "duration" : 472819, + "sparkUser" : "attilapiros", "completed" : true, - "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, - "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 } ] }, { - "id" : "app-20161115172038-0000", + "id" : "app-20180109111548-0000", "name" : "Spark shell", "attempts" : [ { - "startTime" : "2016-11-15T23:20:37.079GMT", - "endTime" : "2016-11-15T23:22:18.874GMT", + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", "lastUpdated" : "", - "duration" : 101795, - "sparkUser" : "jose", + "duration" : 535234, + "sparkUser" : "attilapiros", "completed" : true, - "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, - "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 } ] }, { - "id" : "local-1430917381534", + "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { - "startTime" : "2015-05-06T13:03:00.893GMT", - "endTime" : "2015-05-06T13:03:11.398GMT", + "startTime" : "2016-11-16T22:33:29.916GMT", + "endTime" : "2016-11-16T22:33:40.587GMT", "lastUpdated" : "", - "duration" : 10505, - "sparkUser" : "irashid", + "duration" : 10671, + "sparkUser" : "jose", "completed" : true, - "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, - "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "appSparkVersion" : "2.1.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1479335609916, + "endTimeEpoch" : 1479335620587 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 5af50abd85330..7d60977dcd4fe 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,9 +39,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479335620587 } ] }, { "id" : "app-20161115172038-0000", @@ -24,9 +54,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479252138874 } ] }, { "id" : "local-1430917381534", @@ -39,9 +69,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917391398 } ] }, { "id" : "local-1430917381535", @@ -55,9 +85,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380950 }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", @@ -67,9 +97,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380890 } ] }, { "id" : "local-1426533911241", @@ -83,9 +113,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426633945177 }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", @@ -95,9 +125,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426533945177 } ] }, { "id" : "local-1425081759269", @@ -110,8 +140,8 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1425081758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1425081766912 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json index 7f896c74b5be1..dfbfd8aedcc23 100644 --- a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,8 +39,8 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "startTimeEpoch" : 1479335609916, "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1479335609916, "endTimeEpoch" : 1479335620587 } ] }, { @@ -24,8 +54,8 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "startTimeEpoch" : 1479252037079, "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1479252037079, "endTimeEpoch" : 1479252138874 } ] }, { diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 31093a661663b..03f886afa5413 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -421,7 +421,8 @@ "shuffleWrite" : 13180, "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, "killedTasksSummary" : { } diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 601d70695b17c..947c89906955d 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -421,7 +421,8 @@ "shuffleWrite" : 13180, "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, "killedTasksSummary" : { } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 9cdcef0746185..963f010968b62 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -465,7 +465,8 @@ "shuffleWrite" : 0, "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, "killedTasksSummary" : { } diff --git a/core/src/test/resources/spark-events/app-20180109111548-0000 b/core/src/test/resources/spark-events/app-20180109111548-0000 new file mode 100755 index 0000000000000..50893d3001b95 --- /dev/null +++ b/core/src/test/resources/spark-events/app-20180109111548-0000 @@ -0,0 +1,59 @@ +{"Event":"SparkListenerLogStart","Spark Version":"2.3.0-SNAPSHOT"} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre","Java Version":"1.8.0_152 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark Properties":{"spark.blacklist.enabled":"true","spark.driver.host":"172.30.65.138","spark.eventLog.enabled":"true","spark.driver.port":"64273","spark.repl.class.uri":"spark://172.30.65.138:64273/classes","spark.jars":"","spark.repl.class.outputDir":"/private/var/folders/9g/gf583nd1765cvfgb_lsvwgp00000gp/T/spark-811c1b49-eb66-4bfb-91ae-33b45efa269d/repl-c4438f51-ee23-41ed-8e04-71496e2f40f5","spark.app.name":"Spark shell","spark.scheduler.mode":"FIFO","spark.ui.showConsoleProgress":"true","spark.blacklist.stage.maxFailedTasksPerExecutor":"1","spark.executor.id":"driver","spark.submit.deployMode":"client","spark.master":"local-cluster[2,1,1024]","spark.home":"*********(redacted)","spark.sql.catalogImplementation":"in-memory","spark.blacklist.application.maxFailedTasksPerExecutor":"10","spark.app.id":"app-20180109111548-0000"},"System Properties":{"java.io.tmpdir":"/var/folders/9g/gf583nd1765cvfgb_lsvwgp00000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"*********(redacted)","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib","user.dir":"*********(redacted)","java.library.path":"*********(redacted)","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.152-b16","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_152-b16","java.vm.info":"mixed mode","java.ext.dirs":"*********(redacted)","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"*********(redacted)","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.12.6","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle Corporation","user.country":"*********(redacted)","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"*********(redacted)","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"*********(redacted)","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"*********(redacted)","user.name":"*********(redacted)","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master local-cluster[2,1,1024] --conf spark.blacklist.stage.maxFailedTasksPerExecutor=1 --conf spark.blacklist.enabled=true --conf spark.blacklist.application.maxFailedTasksPerExecutor=10 --conf spark.eventLog.enabled=true --class org.apache.spark.repl.Main --name Spark shell spark-shell","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre","java.version":"1.8.0_152","sun.io.unicode.encoding":"UnicodeBig"},"Classpath Entries":{"/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/api-asn1-api-1.0.0-M20.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/machinist_2.11-0.6.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/lz4-java-1.4.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/activation-1.1.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.5.9.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/curator-framework-2.6.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-jaxrs-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/py4j-0.10.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.7.9.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xercesImpl-2.9.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/gson-2.2.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arrow-format-0.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.7.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/janino-3.0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-client-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.13.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/stax-api-1.0-2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/httpcore-4.4.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arrow-memory-0.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.13.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arrow-vector-0.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/htrace-core-3.0.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/bcprov-jdk15on-1.58.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scalap-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/netty-all-4.1.17.Final.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hppc-0.7.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-io-2.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/httpclient-4.5.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.3.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/java-xmlbuilder-1.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-net-2.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/flatbuffers-1.2.0-3f79e055.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/orc-mapreduce-1.4.1-nohive.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/paranamer-2.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/aircompressor-0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.7.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/orc-core-1.4.1-nohive.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jets3t-0.9.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/base64-2.3.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-lang-2.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/curator-recipes-2.6.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.13.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/apacheds-i18n-2.0.0-M15.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/conf/":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/chill-java-0.8.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xz-1.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spire_2.11-0.13.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-kvstore_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/macro-compat_2.11-1.1.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jaxb-api-2.2.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/zstd-jni-1.3.2-2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/netty-3.9.9.Final.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/apacheds-kerberos-codec-2.0.0-M15.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/api-util-1.0.0-M20.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/curator-client-2.6.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/joda-time-2.9.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-xc-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar":"*********(redacted)"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"app-20180109111548-0000","Timestamp":1515492942372,"User":"attilapiros"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1515492965588,"Executor ID":"0","Executor Info":{"Host":"172.30.65.138","Total Cores":1,"Log Urls":{"stdout":"http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stdout","stderr":"http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stderr"}}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1515492965598,"Executor ID":"1","Executor Info":{"Host":"172.30.65.138","Total Cores":1,"Log Urls":{"stdout":"http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout","stderr":"http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"0","Host":"172.30.65.138","Port":64290},"Maximum Memory":384093388,"Timestamp":1515492965643,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"172.30.65.138","Port":64291},"Maximum Memory":384093388,"Timestamp":1515492965652,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1515493278122,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :26","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]},{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :29","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]}],"Stage IDs":[0,1],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :26","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493278152,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1515493278347,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1515493278364,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1515493278899,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklistedForStage","time":1515493278918,"executorId":"0","taskFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"460","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":5,"Update":"14","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"3873006","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1515493278347,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278909,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3873006,"Value":3873006,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":14,"Value":14,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":460,"Value":460,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":460,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":14,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":3873006,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":0,"Attempt":1,"Launch Time":1515493278919,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278943,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":207014,"Value":6615636,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":92,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":896,"Value":1792,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":2144,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9364000,"Value":207843000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":10,"Value":698,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":8878000,"Value":208907000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":8,"Value":309,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"16","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"126128","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1515493278899,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278926,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":126128,"Value":3999134,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":16,"Value":476,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":16,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":126128,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1515493278364,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278929,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":2409488,"Value":6408622,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":46,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":896,"Value":896,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":13,"Value":27,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1115,"Value":1115,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":198479000,"Value":198479000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":212,"Value":688,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":200029000,"Value":200029000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":301,"Value":301,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":301,"Executor Deserialize CPU Time":200029000,"Executor Run Time":212,"Executor CPU Time":198479000,"Result Size":1115,"JVM GC Time":13,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":2409488,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":2,"Attempt":1,"Launch Time":1515493278943,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278959,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":292381,"Value":6908017,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":2704,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":3173,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9207000,"Value":217050000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":707,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2211000,"Value":211118000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":311,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":3,"Index":0,"Attempt":1,"Launch Time":1515493278919,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278943,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":207014,"Value":6615636,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":92,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":896,"Value":1792,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":2144,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9364000,"Value":207843000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":10,"Value":698,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":8878000,"Value":208907000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":8,"Value":309,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":8,"Executor Deserialize CPU Time":8878000,"Executor Run Time":10,"Executor CPU Time":9364000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":207014,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":3,"Attempt":0,"Launch Time":1515493278958,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278980,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":262919,"Value":7170936,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":184,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":3616,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":4202,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9635000,"Value":226685000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":716,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2586000,"Value":213704000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":314,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":4,"Index":2,"Attempt":1,"Launch Time":1515493278943,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278959,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":292381,"Value":6908017,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":2704,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":3173,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9207000,"Value":217050000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":707,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2211000,"Value":211118000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":311,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2211000,"Executor Run Time":9,"Executor CPU Time":9207000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":292381,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":6,"Index":4,"Attempt":0,"Launch Time":1515493278980,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":5,"Index":3,"Attempt":0,"Launch Time":1515493278958,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278980,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":262919,"Value":7170936,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":184,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":3616,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":4202,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9635000,"Value":226685000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":716,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2586000,"Value":213704000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":314,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2586000,"Executor Run Time":9,"Executor CPU Time":9635000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":262919,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":7,"Index":5,"Attempt":0,"Launch Time":1515493278996,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":6,"Index":4,"Attempt":0,"Launch Time":1515493278980,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278996,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":385110,"Value":7556046,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":5,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":230,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":4528,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":5231,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9622000,"Value":236307000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":10,"Value":726,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2610000,"Value":216314000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":317,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2610000,"Executor Run Time":10,"Executor CPU Time":9622000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":385110,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":8,"Index":6,"Attempt":0,"Launch Time":1515493279011,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":7,"Index":5,"Attempt":0,"Launch Time":1515493278996,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279011,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":205520,"Value":7761566,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":276,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":5440,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":6260,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":8407000,"Value":244714000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":735,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2231000,"Value":218545000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":319,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2231000,"Executor Run Time":9,"Executor CPU Time":8407000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":205520,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":9,"Index":7,"Attempt":0,"Launch Time":1515493279022,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":8,"Index":6,"Attempt":0,"Launch Time":1515493279011,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279022,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":213296,"Value":7974862,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":6352,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":7289,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6034000,"Value":250748000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":7,"Value":742,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1554000,"Value":220099000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1,"Value":320,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1,"Executor Deserialize CPU Time":1554000,"Executor Run Time":7,"Executor CPU Time":6034000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":213296,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":10,"Index":8,"Attempt":0,"Launch Time":1515493279034,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":9,"Index":7,"Attempt":0,"Launch Time":1515493279022,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279034,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":259354,"Value":8234216,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":8,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":368,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":7264,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":8318,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6335000,"Value":257083000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":7,"Value":749,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1981000,"Value":222080000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":322,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1981000,"Executor Run Time":7,"Executor CPU Time":6335000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":259354,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":11,"Index":9,"Attempt":0,"Launch Time":1515493279045,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":10,"Index":8,"Attempt":0,"Launch Time":1515493279034,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279046,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":243647,"Value":8477863,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":9,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":414,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":8176,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":9347,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6157000,"Value":263240000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":6,"Value":755,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1803000,"Value":223883000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":324,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1803000,"Executor Run Time":6,"Executor CPU Time":6157000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":243647,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":11,"Index":9,"Attempt":0,"Launch Time":1515493279045,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279060,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":233652,"Value":8711515,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":460,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":9088,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":10376,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6676000,"Value":269916000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":6,"Value":761,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2017000,"Value":225900000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":327,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2017000,"Executor Run Time":6,"Executor CPU Time":6676000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":233652,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :26","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493278152,"Completion Time":1515493279062,"Accumulables":[{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Value":761,"Internal":true,"Count Failed Values":true},{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Value":8711515,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Value":27,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Value":10376,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Value":225900000,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":10,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Value":9088,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":460,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Value":269916000,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Value":1,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Value":327,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :29","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493279071,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":12,"Index":0,"Attempt":0,"Launch Time":1515493279077,"Executor ID":"0","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":13,"Index":1,"Attempt":0,"Launch Time":1515493279078,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":14,"Index":2,"Attempt":0,"Launch Time":1515493279152,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":13,"Index":1,"Attempt":0,"Launch Time":1515493279078,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279152,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":184,"Value":184,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":944,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":1286,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":41280000,"Value":41280000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":53,"Value":53,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":11820000,"Value":11820000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":17,"Value":17,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":17,"Executor Deserialize CPU Time":11820000,"Executor Run Time":53,"Executor CPU Time":41280000,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":4,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":184,"Total Records Read":4},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":15,"Index":3,"Attempt":0,"Launch Time":1515493279166,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":14,"Index":2,"Attempt":0,"Launch Time":1515493279152,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279167,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":3,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":138,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":3,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":2572,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":7673000,"Value":48953000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":8,"Value":61,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1706000,"Value":13526000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":19,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1706000,"Executor Run Time":8,"Executor CPU Time":7673000,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":3,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":138,"Total Records Read":3},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":16,"Index":4,"Attempt":0,"Launch Time":1515493279179,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":15,"Index":3,"Attempt":0,"Launch Time":1515493279166,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279180,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":3706,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":6972000,"Value":55925000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":7,"Value":68,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1569000,"Value":15095000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":21,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1569000,"Executor Run Time":7,"Executor CPU Time":6972000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":17,"Index":5,"Attempt":0,"Launch Time":1515493279190,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":16,"Index":4,"Attempt":0,"Launch Time":1515493279179,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279190,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":4840,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":4905000,"Value":60830000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":5,"Value":73,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1882000,"Value":16977000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":23,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1882000,"Executor Run Time":5,"Executor CPU Time":4905000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":18,"Index":6,"Attempt":0,"Launch Time":1515493279193,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":12,"Index":0,"Attempt":0,"Launch Time":1515493279077,"Executor ID":"0","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279194,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":3,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":23,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":138,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":3,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":6126,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":56742000,"Value":117572000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":89,"Value":162,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":12625000,"Value":29602000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":18,"Value":41,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":18,"Executor Deserialize CPU Time":12625000,"Executor Run Time":89,"Executor CPU Time":56742000,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":3,"Local Blocks Fetched":0,"Fetch Wait Time":23,"Remote Bytes Read":138,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":3},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":19,"Index":7,"Attempt":0,"Launch Time":1515493279202,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":17,"Index":5,"Attempt":0,"Launch Time":1515493279190,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279203,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":7260,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":6476000,"Value":124048000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":7,"Value":169,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1890000,"Value":31492000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":43,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1890000,"Executor Run Time":7,"Executor CPU Time":6476000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":20,"Index":8,"Attempt":0,"Launch Time":1515493279215,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":19,"Index":7,"Attempt":0,"Launch Time":1515493279202,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279216,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":8394,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":6927000,"Value":130975000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":7,"Value":176,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2038000,"Value":33530000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":45,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2038000,"Executor Run Time":7,"Executor CPU Time":6927000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":21,"Index":9,"Attempt":0,"Launch Time":1515493279218,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":18,"Index":6,"Attempt":0,"Launch Time":1515493279193,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279218,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":9528,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":11214000,"Value":142189000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":16,"Value":192,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2697000,"Value":36227000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":49,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":2697000,"Executor Run Time":16,"Executor CPU Time":11214000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":20,"Index":8,"Attempt":0,"Launch Time":1515493279215,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279226,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":10662,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":4905000,"Value":147094000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":5,"Value":197,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1700000,"Value":37927000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":51,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1700000,"Executor Run Time":5,"Executor CPU Time":4905000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":21,"Index":9,"Attempt":0,"Launch Time":1515493279218,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279232,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":11796,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":7850000,"Value":154944000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":8,"Value":205,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2186000,"Value":40113000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":54,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2186000,"Executor Run Time":8,"Executor CPU Time":7850000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :29","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493279071,"Completion Time":1515493279232,"Accumulables":[{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":23,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Value":40113000,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Value":11796,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":138,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":322,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Value":54,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Value":2832,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":7,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Value":154944000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Value":205,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":3,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Value":10,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1515493279237,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerApplicationEnd","Timestamp":1515493477606} diff --git a/core/src/test/resources/spark-events/application_1516285256255_0012 b/core/src/test/resources/spark-events/application_1516285256255_0012 new file mode 100755 index 0000000000000..3e1736c3fe224 --- /dev/null +++ b/core/src/test/resources/spark-events/application_1516285256255_0012 @@ -0,0 +1,71 @@ +{"Event":"SparkListenerLogStart","Spark Version":"2.3.0-SNAPSHOT"} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre","Java Version":"1.8.0_161 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark Properties":{"spark.blacklist.enabled":"true","spark.driver.host":"apiros-1.gce.test.com","spark.eventLog.enabled":"true","spark.driver.port":"33058","spark.repl.class.uri":"spark://apiros-1.gce.test.com:33058/classes","spark.jars":"","spark.repl.class.outputDir":"/tmp/spark-6781fb17-e07a-4b32-848b-9936c2e88b33/repl-c0fd7008-04be-471e-a173-6ad3e62d53d7","spark.app.name":"Spark shell","spark.blacklist.stage.maxFailedExecutorsPerNode":"1","spark.scheduler.mode":"FIFO","spark.executor.instances":"8","spark.ui.showConsoleProgress":"true","spark.blacklist.stage.maxFailedTasksPerExecutor":"1","spark.executor.id":"driver","spark.submit.deployMode":"client","spark.master":"yarn","spark.ui.filters":"org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter","spark.executor.memory":"2G","spark.home":"/github/spark","spark.sql.catalogImplementation":"hive","spark.driver.appUIAddress":"http://apiros-1.gce.test.com:4040","spark.blacklist.application.maxFailedTasksPerExecutor":"10","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_HOSTS":"apiros-1.gce.test.com","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_URI_BASES":"http://apiros-1.gce.test.com:8088/proxy/application_1516285256255_0012","spark.app.id":"application_1516285256255_0012"},"System Properties":{"java.io.tmpdir":"/tmp","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"*********(redacted)","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","sun.arch.data.model":"64","sun.boot.library.path":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/amd64","user.dir":"*********(redacted)","java.library.path":"/usr/java/packages/lib/amd64:/usr/lib64:/lib64:/lib:/usr/lib","sun.cpu.isalist":"","os.arch":"amd64","java.vm.version":"25.161-b14","java.endorsed.dirs":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/endorsed","java.runtime.version":"1.8.0_161-b14","java.vm.info":"mixed mode","java.ext.dirs":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/ext:/usr/java/packages/lib/ext","java.runtime.name":"OpenJDK Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/resources.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/rt.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/sunrsasign.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jsse.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jce.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/charsets.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jfr.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/classes","file.encoding":"UTF-8","user.timezone":"*********(redacted)","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"3.10.0-693.5.2.el7.x86_64","sun.os.patch.level":"unknown","java.vm.specification.vendor":"Oracle Corporation","user.country":"*********(redacted)","sun.jnu.encoding":"UTF-8","user.language":"*********(redacted)","java.vendor.url":"*********(redacted)","java.awt.printerjob":"sun.print.PSPrinterJob","java.awt.graphicsenv":"sun.awt.X11GraphicsEnvironment","awt.toolkit":"sun.awt.X11.XToolkit","os.name":"Linux","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"*********(redacted)","user.name":"*********(redacted)","java.vm.name":"OpenJDK 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master yarn --deploy-mode client --conf spark.blacklist.stage.maxFailedTasksPerExecutor=1 --conf spark.blacklist.enabled=true --conf spark.blacklist.application.maxFailedTasksPerExecutor=10 --conf spark.blacklist.stage.maxFailedExecutorsPerNode=1 --conf spark.eventLog.enabled=true --class org.apache.spark.repl.Main --name Spark shell --executor-memory 2G --num-executors 8 spark-shell","java.home":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre","java.version":"1.8.0_161","sun.io.unicode.encoding":"UnicodeLittle"},"Classpath Entries":{"/github/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-vector-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-io-2.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-hive_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stax-api-1.0-2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apache-log4j-extras-1.2.17.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hive-metastore-1.2.1.spark2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar":"System Classpath","/github/spark/conf/":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/aircompressor-0.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stax-api-1.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-recipes-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-format-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/libthrift-0.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-lang-2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stringtemplate-3.2.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-api-jdo-3.2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-net-2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/orc-core-1.4.1-nohive.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spire_2.11-0.13.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-memory-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scalap-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/JavaEWAH-0.3.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/bcprov-jdk15on-1.58.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javolution-5.5.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/libfb303-0.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jodd-core-3.5.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/janino-3.0.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/snappy-0.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guice-3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/java-xmlbuilder-1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apacheds-kerberos-codec-2.0.0-M15.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/ST4-4.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-core-3.2.10.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guice-servlet-3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hive-exec-1.2.1.spark2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jetty-6.1.26.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/base64-2.3.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar":"System Classpath","/etc/hadoop/conf/":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-yarn_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.13.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-client-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-jaxrs-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apacheds-i18n-2.0.0-M15.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/lz4-java-1.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.13.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-framework-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-client-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/netty-3.9.9.Final.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-avatica-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/machinist_2.11-0.6.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jaxb-api-2.2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/api-asn1-api-1.0.0-M20.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-linq4j-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/orc-mapreduce-1.4.1-nohive.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xercesImpl-2.9.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hppc-0.7.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.3.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/activation-1.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/py4j-0.10.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-hadoop-bundle-1.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr-runtime-3.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/eigenbase-properties-1.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/paranamer-2.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jta-1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/derby-10.12.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xz-1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-logging-1.1.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-pool-1.5.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/htrace-core-3.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/httpclient-4.5.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/zstd-jni-1.3.2-2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-web-proxy-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-kvstore_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.5.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/chill-java-0.8.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr-2.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-rdbms-3.2.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.inject-1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/api-util-1.0.0-M20.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jdo-api-3.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.7.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-dbcp-1.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/netty-all-4.1.17.Final.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/gson-2.2.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-core-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/macro-compat_2.11-1.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/flatbuffers-1.2.0-3f79e055.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.13.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jets3t-0.9.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-xc-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/aopalliance-1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/bonecp-0.8.0.RELEASE.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/joda-time-2.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/httpcore-4.4.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.8.jar":"System Classpath"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"application_1516285256255_0012","Timestamp":1516300235119,"User":"attilapiros"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252095,"Executor ID":"2","Executor Info":{"Host":"apiros-3.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096","stderr":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"apiros-3.gce.test.com","Port":38670},"Maximum Memory":956615884,"Timestamp":1516300252260,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252715,"Executor ID":"3","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252918,"Executor ID":"1","Executor Info":{"Host":"apiros-3.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096","stderr":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"apiros-2.gce.test.com","Port":38641},"Maximum Memory":956615884,"Timestamp":1516300252959,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"apiros-3.gce.test.com","Port":34970},"Maximum Memory":956615884,"Timestamp":1516300252988,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300253542,"Executor ID":"4","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"4","Host":"apiros-2.gce.test.com","Port":33229},"Maximum Memory":956615884,"Timestamp":1516300253653,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300254323,"Executor ID":"5","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"5","Host":"apiros-2.gce.test.com","Port":45147},"Maximum Memory":956615884,"Timestamp":1516300254385,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1516300392631,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :27","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]},{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]}],"Stage IDs":[0,1],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :27","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300392658,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1516300392816,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1516300394320,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":6,"Index":6,"Attempt":0,"Launch Time":1516300394323,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklistedForStage","time":1516300394348,"executorId":"5","taskFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"org.apache.spark.scheduler.SparkListenerNodeBlacklistedForStage","time":1516300394348,"hostId":"apiros-2.gce.test.com","executorFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklistedForStage","time":1516300394356,"executorId":"4","taskFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"1332","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":5,"Update":"33","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"3075188","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394338,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3075188,"Value":3075188,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":33,"Value":33,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1332,"Value":1332,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":1332,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":33,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":3075188,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"1184","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":5,"Update":"82","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"16858066","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394355,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":16858066,"Value":19933254,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":82,"Value":115,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1184,"Value":2516,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":1184,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":82,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":16858066,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"51","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"183718","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":6,"Index":6,"Attempt":0,"Launch Time":1516300394323,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394390,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":183718,"Value":20116972,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":51,"Value":2567,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":51,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":183718,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"27","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"191901","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1516300394320,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394393,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":191901,"Value":20308873,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":27,"Value":2594,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":27,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":191901,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394606,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3322956,"Value":23631829,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":144,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":1080,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":78,"Value":193,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":1134,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":278399617,"Value":278399617,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":493,"Value":3087,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":263386625,"Value":263386625,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1206,"Value":1206,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1206,"Executor Deserialize CPU Time":263386625,"Executor Run Time":493,"Executor CPU Time":278399617,"Result Size":1134,"JVM GC Time":78,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":3322956,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":7,"Index":5,"Attempt":1,"Launch Time":1516300394859,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394860,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3587839,"Value":27219668,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":291,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":2160,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":102,"Value":295,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":2268,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":349920830,"Value":628320447,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":681,"Value":3768,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":365807898,"Value":629194523,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1282,"Value":2488,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1282,"Executor Deserialize CPU Time":365807898,"Executor Run Time":681,"Executor CPU Time":349920830,"Result Size":1134,"JVM GC Time":102,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":3587839,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":8,"Index":6,"Attempt":1,"Launch Time":1516300394879,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1516300392816,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394880,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3662221,"Value":30881889,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":9,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":435,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":3240,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":75,"Value":370,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":3402,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":368865439,"Value":997185886,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":914,"Value":4682,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":353981050,"Value":983175573,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1081,"Value":3569,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1081,"Executor Deserialize CPU Time":353981050,"Executor Run Time":914,"Executor CPU Time":368865439,"Result Size":1134,"JVM GC Time":75,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":3662221,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":9,"Index":4,"Attempt":1,"Launch Time":1516300394973,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":7,"Index":5,"Attempt":1,"Launch Time":1516300394859,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394974,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":377601,"Value":31259490,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":12,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":582,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":4320,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":4450,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":28283110,"Value":1025468996,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":84,"Value":4766,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":10894331,"Value":994069904,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":11,"Value":3580,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":11,"Executor Deserialize CPU Time":10894331,"Executor Run Time":84,"Executor CPU Time":28283110,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":377601,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":10,"Index":1,"Attempt":1,"Launch Time":1516300395069,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":9,"Index":4,"Attempt":1,"Launch Time":1516300394973,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395069,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":366050,"Value":31625540,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":15,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":729,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":5400,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1091,"Value":5541,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":25678331,"Value":1051147327,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":48,"Value":4814,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4793905,"Value":998863809,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":3585,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":4793905,"Executor Run Time":48,"Executor CPU Time":25678331,"Result Size":1091,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":366050,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":11,"Index":7,"Attempt":0,"Launch Time":1516300395072,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":8,"Index":6,"Attempt":1,"Launch Time":1516300394879,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395073,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":311940,"Value":31937480,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":18,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":876,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":6480,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":6589,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":27304550,"Value":1078451877,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":54,"Value":4868,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":12246145,"Value":1011109954,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":56,"Value":3641,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":56,"Executor Deserialize CPU Time":12246145,"Executor Run Time":54,"Executor CPU Time":27304550,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":311940,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":12,"Index":8,"Attempt":0,"Launch Time":1516300395165,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":11,"Index":7,"Attempt":0,"Launch Time":1516300395072,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395165,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":323898,"Value":32261378,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":21,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1023,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":7560,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":7637,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":21689428,"Value":1100141305,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":77,"Value":4945,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4239884,"Value":1015349838,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":3645,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":4239884,"Executor Run Time":77,"Executor CPU Time":21689428,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":323898,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":13,"Index":9,"Attempt":0,"Launch Time":1516300395200,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":10,"Index":1,"Attempt":1,"Launch Time":1516300395069,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395201,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":301705,"Value":32563083,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":24,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":1167,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":8640,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":5,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1091,"Value":8728,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":20826337,"Value":1120967642,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":76,"Value":5021,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4598966,"Value":1019948804,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":3650,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":4598966,"Executor Run Time":76,"Executor CPU Time":20826337,"Result Size":1091,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":301705,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":12,"Index":8,"Attempt":0,"Launch Time":1516300395165,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395225,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":319101,"Value":32882184,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":27,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1314,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":9720,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":9776,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":21657558,"Value":1142625200,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":34,"Value":5055,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4010338,"Value":1023959142,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":3654,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":4010338,"Executor Run Time":34,"Executor CPU Time":21657558,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":319101,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":13,"Index":9,"Attempt":0,"Launch Time":1516300395200,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395276,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":369513,"Value":33251697,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":30,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1461,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":10800,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":10824,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":20585619,"Value":1163210819,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":25,"Value":5080,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":5860574,"Value":1029819716,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":25,"Value":3679,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":25,"Executor Deserialize CPU Time":5860574,"Executor Run Time":25,"Executor CPU Time":20585619,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":369513,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :27","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300392658,"Completion Time":1516300395279,"Accumulables":[{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Value":5080,"Internal":true,"Count Failed Values":true},{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Value":33251697,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Value":370,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Value":10824,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Value":1029819716,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":30,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Value":10800,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":1461,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Value":1163210819,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Value":5,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Value":3679,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300395292,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":14,"Index":0,"Attempt":0,"Launch Time":1516300395302,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":15,"Index":1,"Attempt":0,"Launch Time":1516300395303,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":16,"Index":3,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":17,"Index":4,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":18,"Index":5,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":19,"Index":6,"Attempt":0,"Launch Time":1516300395525,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":17,"Index":4,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395525,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":1134,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":52455999,"Value":52455999,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":95,"Value":95,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":23136577,"Value":23136577,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":82,"Value":82,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":82,"Executor Deserialize CPU Time":23136577,"Executor Run Time":95,"Executor CPU Time":52455999,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":20,"Index":7,"Attempt":0,"Launch Time":1516300395575,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":19,"Index":6,"Attempt":0,"Launch Time":1516300395525,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395576,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":2268,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":13617615,"Value":66073614,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":29,"Value":124,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3469612,"Value":26606189,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":86,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3469612,"Executor Run Time":29,"Executor CPU Time":13617615,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":21,"Index":8,"Attempt":0,"Launch Time":1516300395581,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":18,"Index":5,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395581,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":3402,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":55540208,"Value":121613822,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":179,"Value":303,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":22400065,"Value":49006254,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":78,"Value":164,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":78,"Executor Deserialize CPU Time":22400065,"Executor Run Time":179,"Executor CPU Time":55540208,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":22,"Index":9,"Attempt":0,"Launch Time":1516300395593,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":16,"Index":3,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395593,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":4536,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":52311573,"Value":173925395,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":153,"Value":456,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":20519033,"Value":69525287,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":67,"Value":231,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":67,"Executor Deserialize CPU Time":20519033,"Executor Run Time":153,"Executor CPU Time":52311573,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":20,"Index":7,"Attempt":0,"Launch Time":1516300395575,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395660,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":5670,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":11294260,"Value":185219655,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":33,"Value":489,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3570887,"Value":73096174,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":235,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3570887,"Executor Run Time":33,"Executor CPU Time":11294260,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":22,"Index":9,"Attempt":0,"Launch Time":1516300395593,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395669,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":6804,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":12983732,"Value":198203387,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":44,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3518757,"Value":76614931,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":239,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3518757,"Executor Run Time":44,"Executor CPU Time":12983732,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":21,"Index":8,"Attempt":0,"Launch Time":1516300395581,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395674,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":7938,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":14706240,"Value":212909627,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":64,"Value":597,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":7698059,"Value":84312990,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":21,"Value":260,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":21,"Executor Deserialize CPU Time":7698059,"Executor Run Time":64,"Executor CPU Time":14706240,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":23,"Index":2,"Attempt":0,"Launch Time":1516300395686,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":14,"Index":0,"Attempt":0,"Launch Time":1516300395302,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395687,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":52,"Value":52,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":195,"Value":195,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":292,"Value":292,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":6,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":944,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":9224,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":91696783,"Value":304606410,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":221,"Value":818,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":24063461,"Value":108376451,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":150,"Value":410,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":150,"Executor Deserialize CPU Time":24063461,"Executor Run Time":221,"Executor CPU Time":91696783,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":6,"Local Blocks Fetched":4,"Fetch Wait Time":52,"Remote Bytes Read":292,"Remote Bytes Read To Disk":0,"Local Bytes Read":195,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":15,"Index":1,"Attempt":0,"Launch Time":1516300395303,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395687,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":20,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":107,"Value":159,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":244,"Value":439,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":243,"Value":535,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":5,"Value":9,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":5,"Value":11,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":10510,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":91683507,"Value":396289917,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":289,"Value":1107,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":22106726,"Value":130483177,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":79,"Value":489,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":79,"Executor Deserialize CPU Time":22106726,"Executor Run Time":289,"Executor CPU Time":91683507,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":5,"Local Blocks Fetched":5,"Fetch Wait Time":107,"Remote Bytes Read":243,"Remote Bytes Read To Disk":0,"Local Bytes Read":244,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":23,"Index":2,"Attempt":0,"Launch Time":1516300395686,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395728,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":30,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":159,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":195,"Value":634,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":292,"Value":827,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":4,"Value":13,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":6,"Value":17,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":11796,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":17607810,"Value":413897727,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":33,"Value":1140,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2897647,"Value":133380824,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":491,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2897647,"Executor Run Time":33,"Executor CPU Time":17607810,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":6,"Local Blocks Fetched":4,"Fetch Wait Time":0,"Remote Bytes Read":292,"Remote Bytes Read To Disk":0,"Local Bytes Read":195,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300395292,"Completion Time":1516300395728,"Accumulables":[{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":159,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Value":133380824,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Value":11796,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":827,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":634,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Value":491,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Value":2832,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":13,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Value":413897727,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Value":1140,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":17,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Value":30,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1516300395734,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerApplicationEnd","Timestamp":1516300707938} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 7aa60f2b60796..87f12f303cd5e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -156,6 +156,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "applications/local-1426533911241/1/stages/0/0/taskList", "stage task list from multi-attempt app json(2)" -> "applications/local-1426533911241/2/stages/0/0/taskList", + "blacklisting for stage" -> "applications/app-20180109111548-0000/stages/0/0", + "blacklisting node for stage" -> "applications/application_1516285256255_0012/stages/0/0", "rdd list storage json" -> "applications/local-1422981780767/storage/rdd", "executor node blacklisting" -> "applications/app-20161116163331-0000/executors", diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index cd1b7a9e5ab18..afebcdd7b9e31 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -92,7 +92,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } def createTaskSetBlacklist(stageId: Int = 0): TaskSetBlacklist = { - new TaskSetBlacklist(conf, stageId, clock) + new TaskSetBlacklist(listenerBusMock, conf, stageId, stageAttemptId = 0, clock = clock) } test("executors can be blacklisted with only a few failures per stage") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala index 18981d5be2f94..6e2709dbe1e8b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala @@ -16,18 +16,32 @@ */ package org.apache.spark.scheduler +import org.mockito.Matchers.isA +import org.mockito.Mockito.{never, verify} +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mockito.MockitoSugar + import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.config -import org.apache.spark.util.{ManualClock, SystemClock} +import org.apache.spark.util.ManualClock + +class TaskSetBlacklistSuite extends SparkFunSuite with BeforeAndAfterEach with MockitoSugar { -class TaskSetBlacklistSuite extends SparkFunSuite { + private var listenerBusMock: LiveListenerBus = _ + + override def beforeEach(): Unit = { + listenerBusMock = mock[LiveListenerBus] + super.beforeEach() + } test("Blacklisting tasks, executors, and nodes") { val conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") val clock = new ManualClock + val attemptId = 0 + val taskSetBlacklist = new TaskSetBlacklist( + listenerBusMock, conf, stageId = 0, stageAttemptId = attemptId, clock = clock) - val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, clock = clock) clock.setTime(0) // We will mark task 0 & 1 failed on both executor 1 & 2. // We should blacklist all executors on that host, for all tasks for the stage. Note the API @@ -46,27 +60,53 @@ class TaskSetBlacklistSuite extends SparkFunSuite { val shouldBeBlacklisted = (executor == "exec1" && index == 0) assert(taskSetBlacklist.isExecutorBlacklistedForTask(executor, index) === shouldBeBlacklisted) } + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerExecutorBlacklistedForStage])) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) // Mark task 1 failed on exec1 -- this pushes the executor into the blacklist taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "exec1", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + verify(listenerBusMock).post( + SparkListenerExecutorBlacklistedForStage(0, "exec1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) + // Mark one task as failed on exec2 -- not enough for any further blacklisting yet. taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "exec2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) + // Mark another task as failed on exec2 -- now we blacklist exec2, which also leads to // blacklisting the entire node. taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "exec2", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) + verify(listenerBusMock).post( + SparkListenerExecutorBlacklistedForStage(0, "exec2", 2, 0, attemptId)) + assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock).post( + SparkListenerNodeBlacklistedForStage(0, "hostA", 2, 0, attemptId)) + // Make sure the blacklist has the correct per-task && per-executor responses, over a wider // range of inputs. for { @@ -81,6 +121,10 @@ class TaskSetBlacklistSuite extends SparkFunSuite { // intentional, it keeps it fast and is sufficient for usage in the scheduler. taskSetBlacklist.isExecutorBlacklistedForTask(executor, index) === (badExec && badIndex)) assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet(executor) === badExec) + if (badExec) { + verify(listenerBusMock).post( + SparkListenerExecutorBlacklistedForStage(0, executor, 2, 0, attemptId)) + } } } assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) @@ -110,7 +154,14 @@ class TaskSetBlacklistSuite extends SparkFunSuite { .set(config.MAX_TASK_ATTEMPTS_PER_NODE, 3) .set(config.MAX_FAILURES_PER_EXEC_STAGE, 2) .set(config.MAX_FAILED_EXEC_PER_NODE_STAGE, 3) - val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) + val clock = new ManualClock + + val attemptId = 0 + val taskSetBlacklist = new TaskSetBlacklist( + listenerBusMock, conf, stageId = 0, stageAttemptId = attemptId, clock = clock) + + var time = 0 + clock.setTime(time) // Fail a task twice on hostA, exec:1 taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 0, failureReason = "testing") @@ -118,37 +169,75 @@ class TaskSetBlacklistSuite extends SparkFunSuite { "hostA", exec = "1", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTask("1", 0)) assert(!taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + verify(listenerBusMock, never()).post( + SparkListenerExecutorBlacklistedForStage(time, "1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 2, 0, attemptId)) // Fail the same task once more on hostA, exec:2 + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + verify(listenerBusMock, never()).post( + SparkListenerExecutorBlacklistedForStage(time, "2", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 2, 0, attemptId)) // Fail another task on hostA, exec:1. Now that executor has failures on two different tasks, // so its blacklisted + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) // Fail a third task on hostA, exec:2, so that exec is blacklisted for the whole task set + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "2", index = 2, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "2", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) // Fail a fourth & fifth task on hostA, exec:3. Now we've got three executors that are // blacklisted for the taskset, so blacklist the whole node. + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "3", index = 3, failureReason = "testing") taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "3", index = 4, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("3")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "3", 2, 0, attemptId)) + assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 3, 0, attemptId)) } test("only blacklist nodes for the task set when all the blacklisted executors are all on " + @@ -157,22 +246,42 @@ class TaskSetBlacklistSuite extends SparkFunSuite { // lead to any node blacklisting val conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") - val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) + val clock = new ManualClock + + val attemptId = 0 + val taskSetBlacklist = new TaskSetBlacklist( + listenerBusMock, conf, stageId = 0, stageAttemptId = attemptId, clock = clock) + var time = 0 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 0, failureReason = "testing") taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 2, 0, attemptId)) + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostB", exec = "2", index = 0, failureReason = "testing") taskSetBlacklist.updateBlacklistForFailedTask( "hostB", exec = "2", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "2", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostB")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) } } diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index e7981bec6d64b..042bba7f226fd 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -251,6 +251,49 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + // Blacklisting executor for stage + time += 1 + listener.onExecutorBlacklistedForStage(SparkListenerExecutorBlacklistedForStage( + time = time, + executorId = execIds.head, + taskFailures = 2, + stageId = stages.head.stageId, + stageAttemptId = stages.head.attemptId)) + + val executorStageSummaryWrappers = + store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)) + .last(key(stages.head)) + .asScala.toSeq + + assert(executorStageSummaryWrappers.nonEmpty) + executorStageSummaryWrappers.foreach { exec => + // only the first executor is expected to be blacklisted + val expectedBlacklistedFlag = exec.executorId == execIds.head + assert(exec.info.isBlacklistedForStage === expectedBlacklistedFlag) + } + + // Blacklisting node for stage + time += 1 + listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage( + time = time, + hostId = "2.example.com", // this is where the second executor is hosted + executorFailures = 1, + stageId = stages.head.stageId, + stageAttemptId = stages.head.attemptId)) + + val executorStageSummaryWrappersForNode = + store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)) + .last(key(stages.head)) + .asScala.toSeq + + assert(executorStageSummaryWrappersForNode.nonEmpty) + executorStageSummaryWrappersForNode.foreach { exec => + // both executor is expected to be blacklisted + assert(exec.info.isBlacklistedForStage === true) + } + // Fail one of the tasks, re-start it. time += 1 s1Tasks.head.markFinished(TaskState.FAILED, time) diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 607234b4068d0..243fbe3e1bc24 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -73,8 +73,10 @@ logs .*dependency-reduced-pom.xml known_translations json_expectation +app-20180109111548-0000 app-20161115172038-0000 app-20161116163331-0000 +application_1516285256255_0012 local-1422981759269 local-1422981780767 local-1425081759269 From e18d6f5326e0d9ea03d31de5ce04cb84d3b8ab37 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 24 Jan 2018 09:37:54 -0800 Subject: [PATCH 192/774] [SPARK-20906][SPARKR] Add API doc example for Constrained Logistic Regression ## What changes were proposed in this pull request? doc only changes ## How was this patch tested? manual Author: Felix Cheung Closes #20380 from felixcheung/rclrdoc. --- R/pkg/R/mllib_classification.R | 15 ++++++++++++++- R/pkg/tests/fulltests/test_mllib_classification.R | 10 +++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 7cd072a1d6f89..f6e9b1357561b 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -279,11 +279,24 @@ function(object, path, overwrite = FALSE) { #' savedModel <- read.ml(path) #' summary(savedModel) #' -#' # multinomial logistic regression +#' # binary logistic regression against two classes with +#' # upperBoundsOnCoefficients and upperBoundsOnIntercepts +#' ubc <- matrix(c(1.0, 0.0, 1.0, 0.0), nrow = 1, ncol = 4) +#' model <- spark.logit(training, Species ~ ., +#' upperBoundsOnCoefficients = ubc, +#' upperBoundsOnIntercepts = 1.0) #' +#' # multinomial logistic regression #' model <- spark.logit(training, Class ~ ., regParam = 0.5) #' summary <- summary(model) #' +#' # multinomial logistic regression with +#' # lowerBoundsOnCoefficients and lowerBoundsOnIntercepts +#' lbc <- matrix(c(0.0, -1.0, 0.0, -1.0, 0.0, -1.0, 0.0, -1.0), nrow = 2, ncol = 4) +#' lbi <- as.array(c(0.0, 0.0)) +#' model <- spark.logit(training, Species ~ ., family = "multinomial", +#' lowerBoundsOnCoefficients = lbc, +#' lowerBoundsOnIntercepts = lbi) #' } #' @note spark.logit since 2.1.0 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R index ad47717ddc12f..a46c47dccd02e 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -124,7 +124,7 @@ test_that("spark.logit", { # Petal.Width 0.42122607 # nolint end - # Test multinomial logistic regression againt three classes + # Test multinomial logistic regression against three classes df <- suppressWarnings(createDataFrame(iris)) model <- spark.logit(df, Species ~ ., regParam = 0.5) summary <- summary(model) @@ -196,7 +196,7 @@ test_that("spark.logit", { # # nolint end - # Test multinomial logistic regression againt two classes + # Test multinomial logistic regression against two classes df <- suppressWarnings(createDataFrame(iris)) training <- df[df$Species %in% c("versicolor", "virginica"), ] model <- spark.logit(training, Species ~ ., regParam = 0.5, family = "multinomial") @@ -208,7 +208,7 @@ test_that("spark.logit", { expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) - # Test binomial logistic regression againt two classes + # Test binomial logistic regression against two classes model <- spark.logit(training, Species ~ ., regParam = 0.5) summary <- summary(model) coefsR <- c(-6.08, 0.25, 0.16, 0.48, 1.04) @@ -239,7 +239,7 @@ test_that("spark.logit", { prediction2 <- collect(select(predict(model2, df2), "prediction")) expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0")) - # Test binomial logistic regression againt two classes with upperBoundsOnCoefficients + # Test binomial logistic regression against two classes with upperBoundsOnCoefficients # and upperBoundsOnIntercepts u <- matrix(c(1.0, 0.0, 1.0, 0.0), nrow = 1, ncol = 4) model <- spark.logit(training, Species ~ ., upperBoundsOnCoefficients = u, @@ -252,7 +252,7 @@ test_that("spark.logit", { expect_error(spark.logit(training, Species ~ ., upperBoundsOnCoefficients = as.array(c(1, 2)), upperBoundsOnIntercepts = 1.0)) - # Test binomial logistic regression againt two classes with lowerBoundsOnCoefficients + # Test binomial logistic regression against two classes with lowerBoundsOnCoefficients # and lowerBoundsOnIntercepts l <- matrix(c(0.0, -1.0, 0.0, -1.0), nrow = 1, ncol = 4) model <- spark.logit(training, Species ~ ., lowerBoundsOnCoefficients = l, From 8c273b4162b6138c4abba64f595c2750d1ef8bcb Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 24 Jan 2018 10:00:42 -0800 Subject: [PATCH 193/774] [SPARK-23020][CORE][FOLLOWUP] Fix Java style check issues. ## What changes were proposed in this pull request? This is a follow-up of #20297 which broke lint-java checks. This pr fixes the lint-java issues. ``` [ERROR] src/test/java/org/apache/spark/launcher/BaseSuite.java:[21,8] (imports) UnusedImports: Unused import - java.util.concurrent.TimeUnit. [ERROR] src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java:[27,8] (imports) UnusedImports: Unused import - java.util.concurrent.TimeUnit. ``` ## How was this patch tested? Checked manually in my local environment. Author: Takuya UESHIN Closes #20376 from ueshin/issues/SPARK-23020/fup1. --- .../test/java/org/apache/spark/launcher/SparkLauncherSuite.java | 1 - launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java | 1 - 2 files changed, 2 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index a042375c6ae91..1543f4fdb0162 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -24,7 +24,6 @@ import java.util.List; import java.util.Map; import java.util.Properties; -import java.util.concurrent.TimeUnit; import org.junit.Test; import static org.junit.Assert.*; diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3722a59d9438e..438349e027a24 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -18,7 +18,6 @@ package org.apache.spark.launcher; import java.time.Duration; -import java.util.concurrent.TimeUnit; import org.junit.After; import org.slf4j.bridge.SLF4JBridgeHandler; From bbb87b350d9d0d393db3fb7ca61dcbae538553bb Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Wed, 24 Jan 2018 10:07:24 -0800 Subject: [PATCH 194/774] [SPARK-22837][SQL] Session timeout checker does not work in SessionManager. ## What changes were proposed in this pull request? Currently we do not call the `super.init(hiveConf)` in `SparkSQLSessionManager.init`. So we do not load the config `HIVE_SERVER2_SESSION_CHECK_INTERVAL HIVE_SERVER2_IDLE_SESSION_TIMEOUT HIVE_SERVER2_IDLE_SESSION_CHECK_OPERATION` , which cause the session timeout checker does not work. ## How was this patch tested? manual tests Author: zuotingbing Closes #20025 from zuotingbing/SPARK-22837. --- .../thriftserver/SparkSQLSessionManager.scala | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 48c0ebef3e0ce..2958b771f3648 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -40,22 +40,8 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: private lazy val sparkSqlOperationManager = new SparkSQLOperationManager() override def init(hiveConf: HiveConf) { - setSuperField(this, "hiveConf", hiveConf) - - // Create operation log root directory, if operation logging is enabled - if (hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED)) { - invoke(classOf[SessionManager], this, "initOperationLogRootDir") - } - - val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) - setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) - getAncestorField[Log](this, 3, "LOG").info( - s"HiveServer2: Async execution pool size $backgroundPoolSize") - setSuperField(this, "operationManager", sparkSqlOperationManager) - addService(sparkSqlOperationManager) - - initCompositeService(hiveConf) + super.init(hiveConf) } override def openSession( From 840dea64abd8a3a5960de830f19a57f5f1aa3bf6 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Wed, 24 Jan 2018 13:13:44 -0500 Subject: [PATCH 195/774] [SPARK-23152][ML] - Correctly guard against empty datasets ## What changes were proposed in this pull request? Correctly guard against empty datasets in `org.apache.spark.ml.classification.Classifier` ## How was this patch tested? existing tests Author: Matthew Tovbin Closes #20321 from tovbinm/SPARK-23152. --- .../org/apache/spark/ml/classification/Classifier.scala | 2 +- .../apache/spark/ml/classification/ClassifierSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index bc0b49d48d323..9d1d5aa1e0cff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -109,7 +109,7 @@ abstract class Classifier[ case None => // Get number of classes from dataset itself. val maxLabelRow: Array[Row] = dataset.select(max($(labelCol))).take(1) - if (maxLabelRow.isEmpty) { + if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) { throw new SparkException("ML algorithm was given empty dataset.") } val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index de712079329da..87bf2be06c2be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -90,6 +90,13 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { } assert(e.getMessage.contains("requires integers in range")) } + val df3 = getTestData(Seq.empty[Double]) + withClue("getNumClasses should fail if dataset is empty") { + val e: SparkException = intercept[SparkException] { + c.getNumClasses(df3) + } + assert(e.getMessage == "ML algorithm was given empty dataset.") + } } } From 0e178e1523175a0be9437920045e80deb0a2712b Mon Sep 17 00:00:00 2001 From: Mark Petruska Date: Wed, 24 Jan 2018 10:25:14 -0800 Subject: [PATCH 196/774] [SPARK-22297][CORE TESTS] Flaky test: BlockManagerSuite "Shuffle registration timeout and maxAttempts conf" ## What changes were proposed in this pull request? [Ticket](https://issues.apache.org/jira/browse/SPARK-22297) - one of the tests seems to produce unreliable results due to execution speed variability Since the original test was trying to connect to the test server with `40 ms` timeout, and the test server replied after `50 ms`, the error might be produced under the following conditions: - it might occur that the test server replies correctly after `50 ms` - but the client does only receive the timeout after `51 ms`s - this might happen if the executor has to schedule a big number of threads, and decides to delay the thread/actor that is responsible to watch the timeout, because of high CPU load - running an entire test suite usually produces high loads on the CPU executing the tests ## How was this patch tested? The test's check cases remain the same and the set-up emulates the previous version's. Author: Mark Petruska Closes #19671 from mpetruska/SPARK-22297. --- .../spark/storage/BlockManagerSuite.scala | 55 +++++++++++++------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 629eed49b04cc..b19d8ebf72c61 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.storage import java.nio.ByteBuffer import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.Future import scala.concurrent.duration._ @@ -44,8 +43,9 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} +import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} @@ -1325,9 +1325,18 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("SPARK-20640: Shuffle registration timeout and maxAttempts conf are working") { val tryAgainMsg = "test_spark_20640_try_again" + val timingoutExecutor = "timingoutExecutor" + val tryAgainExecutor = "tryAgainExecutor" + val succeedingExecutor = "succeedingExecutor" + // a server which delays response 50ms and must try twice for success. def newShuffleServer(port: Int): (TransportServer, Int) = { - val attempts = new mutable.HashMap[String, Int]() + val failure = new Exception(tryAgainMsg) + val success = ByteBuffer.wrap(new Array[Byte](0)) + + var secondExecutorFailedOnce = false + var thirdExecutorFailedOnce = false + val handler = new NoOpRpcHandler { override def receive( client: TransportClient, @@ -1335,15 +1344,26 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE callback: RpcResponseCallback): Unit = { val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message) msgObj match { - case exec: RegisterExecutor => - Thread.sleep(50) - val attempt = attempts.getOrElse(exec.execId, 0) + 1 - attempts(exec.execId) = attempt - if (attempt < 2) { - callback.onFailure(new Exception(tryAgainMsg)) - return - } - callback.onSuccess(ByteBuffer.wrap(new Array[Byte](0))) + + case exec: RegisterExecutor if exec.execId == timingoutExecutor => + () // No reply to generate client-side timeout + + case exec: RegisterExecutor + if exec.execId == tryAgainExecutor && !secondExecutorFailedOnce => + secondExecutorFailedOnce = true + callback.onFailure(failure) + + case exec: RegisterExecutor if exec.execId == tryAgainExecutor => + callback.onSuccess(success) + + case exec: RegisterExecutor + if exec.execId == succeedingExecutor && !thirdExecutorFailedOnce => + thirdExecutorFailedOnce = true + callback.onFailure(failure) + + case exec: RegisterExecutor if exec.execId == succeedingExecutor => + callback.onSuccess(success) + } } } @@ -1352,6 +1372,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val transCtx = new TransportContext(transConf, handler, true) (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port) } + val candidatePort = RandomUtils.nextInt(1024, 65536) val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, newShuffleServer, conf, "ShuffleServer") @@ -1360,21 +1381,21 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.shuffle.service.port", shufflePort.toString) conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40") conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") - var e = intercept[SparkException]{ - makeBlockManager(8000, "executor1") + var e = intercept[SparkException] { + makeBlockManager(8000, timingoutExecutor) }.getMessage assert(e.contains("TimeoutException")) conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") - e = intercept[SparkException]{ - makeBlockManager(8000, "executor2") + e = intercept[SparkException] { + makeBlockManager(8000, tryAgainExecutor) }.getMessage assert(e.contains(tryAgainMsg)) conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "2") - makeBlockManager(8000, "executor3") + makeBlockManager(8000, succeedingExecutor) server.close() } From bc9641d9026aeae3571915b003ac971f6245d53c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 24 Jan 2018 12:58:44 -0800 Subject: [PATCH 197/774] [SPARK-23198][SS][TEST] Fix KafkaContinuousSourceStressForDontFailOnDataLossSuite to test ContinuousExecution ## What changes were proposed in this pull request? Currently, `KafkaContinuousSourceStressForDontFailOnDataLossSuite` runs on `MicroBatchExecution`. It should test `ContinuousExecution`. ## How was this patch tested? Pass the updated test suite. Author: Dongjoon Hyun Closes #20374 from dongjoon-hyun/SPARK-23198. --- .../apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index b3dade414f625..a7083fa4e3417 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -91,6 +91,7 @@ class KafkaContinuousSourceStressForDontFailOnDataLossSuite ds.writeStream .format("memory") .queryName("memory") + .trigger(Trigger.Continuous("1 second")) .start() } } From 6f0ba8472d1128551fa8090deebcecde0daebc53 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Wed, 24 Jan 2018 13:06:09 -0800 Subject: [PATCH 198/774] [MINOR][SQL] add new unit test to LimitPushdown MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR is repaired as follows 1、update y -> x in "left outer join" test case ,maybe is mistake. 2、add a new test case:"left outer join and left sides are limited" 3、add a new test case:"left outer join and right sides are limited" 4、add a new test case: "right outer join and right sides are limited" 5、add a new test case: "right outer join and left sides are limited" 6、Remove annotations without code implementation ## How was this patch tested? add new unit test case. Author: caoxuewen Closes #20381 from heary-cao/LimitPushdownSuite. --- .../sql/catalyst/optimizer/Optimizer.scala | 1 - .../optimizer/LimitPushdownSuite.scala | 30 ++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0f9daa5f04c76..8d207708c12ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -352,7 +352,6 @@ object LimitPushDown extends Rule[LogicalPlan] { // on both sides if it is applied multiple times. Therefore: // - If one side is already limited, stack another limit on top if the new limit is smaller. // The redundant limit will be collapsed by the CombineLimits rule. - // - If neither side is limited, limit the side that is estimated to be bigger. case LocalLimit(exp, join @ Join(left, right, joinType, _)) => val newJoin = joinType match { case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index cc98d2350c777..17fb9fc5d11e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -93,7 +93,21 @@ class LimitPushdownSuite extends PlanTest { test("left outer join") { val originalQuery = x.join(y, LeftOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, LocalLimit(1, y).join(y, LeftOuter)).analyze + val correctAnswer = Limit(1, LocalLimit(1, x).join(y, LeftOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("left outer join and left sides are limited") { + val originalQuery = x.limit(2).join(y, LeftOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(y, LeftOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("left outer join and right sides are limited") { + val originalQuery = x.join(y.limit(2), LeftOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(Limit(2, y), LeftOuter)).analyze comparePlans(optimized, correctAnswer) } @@ -104,6 +118,20 @@ class LimitPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("right outer join and right sides are limited") { + val originalQuery = x.join(y.limit(2), RightOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, x.join(LocalLimit(1, y), RightOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("right outer join and left sides are limited") { + val originalQuery = x.limit(2).join(y, RightOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, Limit(2, x).join(LocalLimit(1, y), RightOuter)).analyze + comparePlans(optimized, correctAnswer) + } + test("larger limits are not pushed on top of smaller ones in right outer join") { val originalQuery = x.join(y.limit(5), RightOuter).limit(10) val optimized = Optimize.execute(originalQuery.analyze) From 45b4bbfddc18a77011c3bc1bfd71b2cd3466443c Mon Sep 17 00:00:00 2001 From: zhoukang Date: Thu, 25 Jan 2018 15:24:52 +0800 Subject: [PATCH 199/774] [SPARK-23129][CORE] Make deserializeStream of DiskMapIterator init lazily ## What changes were proposed in this pull request? Currently,the deserializeStream in ExternalAppendOnlyMap#DiskMapIterator init when DiskMapIterator instance created.This will cause memory use overhead when ExternalAppendOnlyMap spill too much times. We can avoid this by making deserializeStream init when it is used the first time. This patch make deserializeStream init lazily. ## How was this patch tested? Exist tests Author: zhoukang Closes #20292 from caneGuy/zhoukang/lay-diskmapiterator. --- .../util/collection/ExternalAppendOnlyMap.scala | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 375f4a6921225..5c6dd45ec58e3 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -463,7 +463,7 @@ class ExternalAppendOnlyMap[K, V, C]( // An intermediate stream that reads from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams - private var deserializeStream = nextBatchStream() + private var deserializeStream: DeserializationStream = null private var nextItem: (K, C) = null private var objectsRead = 0 @@ -528,7 +528,11 @@ class ExternalAppendOnlyMap[K, V, C]( override def hasNext: Boolean = { if (nextItem == null) { if (deserializeStream == null) { - return false + // In case of deserializeStream has not been initialized + deserializeStream = nextBatchStream() + if (deserializeStream == null) { + return false + } } nextItem = readNextItem() } @@ -536,19 +540,18 @@ class ExternalAppendOnlyMap[K, V, C]( } override def next(): (K, C) = { - val item = if (nextItem == null) readNextItem() else nextItem - if (item == null) { + if (!hasNext) { throw new NoSuchElementException } + val item = nextItem nextItem = null item } private def cleanup() { batchIndex = batchOffsets.length // Prevent reading any other batch - val ds = deserializeStream - if (ds != null) { - ds.close() + if (deserializeStream != null) { + deserializeStream.close() deserializeStream = null } if (fileStream != null) { From e29b08add92462a6505fef966629e74ba30e994e Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 25 Jan 2018 16:40:41 +0800 Subject: [PATCH 200/774] [SPARK-23208][SQL] Fix code generation for complex create array (related) expressions ## What changes were proposed in this pull request? The `GenArrayData.genCodeToCreateArrayData` produces illegal java code when code splitting is enabled. This is used in `CreateArray` and `CreateMap` expressions for complex object arrays. This issue is caused by a typo. ## How was this patch tested? Added a regression test in `complexTypesSuite`. Author: Herman van Hovell Closes #20391 from hvanhovell/SPARK-23208. --- .../sql/catalyst/expressions/complexTypeCreator.scala | 2 +- .../spark/sql/catalyst/optimizer/complexTypesSuite.scala | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 3dc2ee03a86e3..047b80ac5289c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -111,7 +111,7 @@ private [sql] object GenArrayData { val assignmentString = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, funcName = "apply", - extraArguments = ("Object[]", arrayDataName) :: Nil) + extraArguments = ("Object[]", arrayName) :: Nil) (s"Object[] $arrayName = new Object[$numElements];", assignmentString, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 0d11958876ce9..de544ac314789 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ /** @@ -31,7 +32,7 @@ import org.apache.spark.sql.types._ * i.e. {{{create_named_struct(square, `x` * `x`).square}}} can be simplified to {{{`x` * `x`}}}. * sam applies to create_array and create_map */ -class ComplexTypesSuite extends PlanTest{ +class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { object Optimizer extends RuleExecutor[LogicalPlan] { val batches = @@ -171,6 +172,11 @@ class ComplexTypesSuite extends PlanTest{ assert(ctx.inlinedMutableStates.length == 0) } + test("SPARK-23208: Test code splitting for create array related methods") { + val inputs = (1 to 2500).map(x => Literal(s"l_$x")) + checkEvaluation(CreateArray(inputs), new GenericArrayData(inputs.map(_.eval()))) + } + test("simplify map ops") { val rel = relation .select( From 39ee2acf96f1e1496cff8e4d2614d27fca76d43b Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 25 Jan 2018 01:48:11 -0800 Subject: [PATCH 201/774] [SPARK-23163][DOC][PYTHON] Sync ML Python API with Scala ## What changes were proposed in this pull request? This syncs the ML Python API with Scala for differences found after the 2.3 QA audit. ## How was this patch tested? NA Author: Bryan Cutler Closes #20354 from BryanCutler/pyspark-ml-doc-sync-23163. --- python/pyspark/ml/evaluation.py | 8 +++++++- python/pyspark/ml/feature.py | 2 +- python/pyspark/ml/fpm.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index aa8dbe708a115..0cbce9b40048f 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -334,7 +334,13 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, .. note:: Experimental Evaluator for Clustering results, which expects two input - columns: prediction and features. + columns: prediction and features. The metric computes the Silhouette + measure using the squared Euclidean distance. + + The Silhouette is a measure for the validation of the consistency + within clusters. It ranges between 1 and -1, where a value close to + 1 means that the points in a cluster are close to the other points + in the same cluster and far from the points of the other clusters. >>> from pyspark.ml.linalg import Vectors >>> featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index eb79b193103e2..da85ba761a145 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3440,7 +3440,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja selectorType = Param(Params._dummy(), "selectorType", "The selector type of the ChisqSelector. " + - "Supported options: numTopFeatures (default), percentile and fpr.", + "Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.", typeConverter=TypeConverters.toString) numTopFeatures = \ diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index dd7dda5f03124..b8dafd49d354d 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -144,7 +144,7 @@ def freqItemsets(self): @since("2.2.0") def associationRules(self): """ - Data with three columns: + DataFrame with three columns: * `antecedent` - Array of the same type as the input column. * `consequent` - Array of the same type as the input column. * `confidence` - Confidence for the rule (`DoubleType`). From d20bbc2d87ae6bd56d236a7c3d036b52c5f20ff5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 25 Jan 2018 19:49:58 +0800 Subject: [PATCH 202/774] [SPARK-21717][SQL] Decouple consume functions of physical operators in whole-stage codegen ## What changes were proposed in this pull request? It has been observed in SPARK-21603 that whole-stage codegen suffers performance degradation, if the generated functions are too long to be optimized by JIT. We basically produce a single function to incorporate generated codes from all physical operators in whole-stage. Thus, it is possibly to grow the size of generated function over a threshold that we can't have JIT optimization for it anymore. This patch is trying to decouple the logic of consuming rows in physical operators to avoid a giant function processing rows. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #18931 from viirya/SPARK-21717. --- .../expressions/codegen/CodeGenerator.scala | 38 ++++- .../apache/spark/sql/internal/SQLConf.scala | 12 ++ .../sql/execution/WholeStageCodegenExec.scala | 135 +++++++++++++++--- .../execution/WholeStageCodegenSuite.scala | 47 +++++- 4 files changed, 203 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f96ed7628fda1..4dcbb702893da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1245,6 +1245,31 @@ class CodegenContext { "" } } + + /** + * Returns the length of parameters for a Java method descriptor. `this` contributes one unit + * and a parameter of type long or double contributes two units. Besides, for nullable parameter, + * we also need to pass a boolean parameter for the null status. + */ + def calculateParamLength(params: Seq[Expression]): Int = { + def paramLengthForExpr(input: Expression): Int = { + // For a nullable expression, we need to pass in an extra boolean parameter. + (if (input.nullable) 1 else 0) + javaType(input.dataType) match { + case JAVA_LONG | JAVA_DOUBLE => 2 + case _ => 1 + } + } + // Initial value is 1 for `this`. + 1 + params.map(paramLengthForExpr(_)).sum + } + + /** + * In Java, a method descriptor is valid only if it represents method parameters with a total + * length less than a pre-defined constant. + */ + def isValidParamLength(paramLength: Int): Boolean = { + paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + } } /** @@ -1311,26 +1336,29 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin object CodeGenerator extends Logging { // This is the value of HugeMethodLimit in the OpenJDK JVM settings - val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + final val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + + // The max valid length of method parameters in JVM. + final val MAX_JVM_METHOD_PARAMS_LENGTH = 255 // This is the threshold over which the methods in an inner class are grouped in a single // method which is going to be called by the outer class instead of the many small ones - val MERGE_SPLIT_METHODS_THRESHOLD = 3 + final val MERGE_SPLIT_METHODS_THRESHOLD = 3 // The number of named constants that can exist in the class is limited by the Constant Pool // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a // threshold of 1000k bytes to determine when a function should be inlined to a private, inner // class. - val GENERATED_CLASS_SIZE_THRESHOLD = 1000000 + final val GENERATED_CLASS_SIZE_THRESHOLD = 1000000 // This is the threshold for the number of global variables, whose types are primitive type or // complex type (e.g. more than one-dimensional array), that will be placed at the outer class - val OUTER_CLASS_VARIABLES_THRESHOLD = 10000 + final val OUTER_CLASS_VARIABLES_THRESHOLD = 10000 // This is the maximum number of array elements to keep global variables in one Java array // 32767 is the maximum integer value that does not require a constant pool entry in a Java // bytecode instruction - val MUTABLESTATEARRAY_SIZE_LIMIT = 32768 + final val MUTABLESTATEARRAY_SIZE_LIMIT = 32768 /** * Compile the Java source code into a Java class, using Janino. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1cef09a5bf053..470f88c213561 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -661,6 +661,15 @@ object SQLConf { .intConf .createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT) + val WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR = + buildConf("spark.sql.codegen.splitConsumeFuncByOperator") + .internal() + .doc("When true, whole stage codegen would put the logic of consuming rows of each " + + "physical operator into individual methods, instead of a single big method. This can be " + + "used to avoid oversized function that can miss the opportunity of JIT optimization.") + .booleanConf + .createWithDefault(true) + val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") .longConf @@ -1263,6 +1272,9 @@ class SQLConf extends Serializable with Logging { def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) + def wholeStageSplitConsumeFuncByOperator: Boolean = + getConf(WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR) + def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 6102937852347..8ea9e81b2e53b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.util.Locale +import scala.collection.mutable + import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -106,6 +108,31 @@ trait CodegenSupport extends SparkPlan { */ protected def doProduce(ctx: CodegenContext): String + private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { + if (row != null) { + ExprCode("", "false", row) + } else { + if (colVars.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + val evaluateInputs = evaluateVariables(colVars) + // generate the code to create a UnsafeRow + ctx.INPUT_ROW = row + ctx.currentVars = colVars + val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + val code = s""" + |$evaluateInputs + |${ev.code.trim} + """.stripMargin.trim + ExprCode(code, "false", ev.value) + } else { + // There is no columns + ExprCode("", "false", "unsafeRow") + } + } + } + /** * Consume the generated columns or row from current SparkPlan, call its parent's `doConsume()`. * @@ -126,28 +153,7 @@ trait CodegenSupport extends SparkPlan { } } - val rowVar = if (row != null) { - ExprCode("", "false", row) - } else { - if (outputVars.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - val evaluateInputs = evaluateVariables(outputVars) - // generate the code to create a UnsafeRow - ctx.INPUT_ROW = row - ctx.currentVars = outputVars - val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - val code = s""" - |$evaluateInputs - |${ev.code.trim} - """.stripMargin.trim - ExprCode(code, "false", ev.value) - } else { - // There is no columns - ExprCode("", "false", "unsafeRow") - } - } + val rowVar = prepareRowVar(ctx, row, outputVars) // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars` // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to @@ -156,13 +162,96 @@ trait CodegenSupport extends SparkPlan { ctx.INPUT_ROW = null ctx.freshNamePrefix = parent.variablePrefix val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs) + + // Under certain conditions, we can put the logic to consume the rows of this operator into + // another function. So we can prevent a generated function too long to be optimized by JIT. + // The conditions: + // 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled. + // 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses + // all variables in output (see `requireAllOutput`). + // 3. The number of output variables must less than maximum number of parameters in Java method + // declaration. + val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator + val requireAllOutput = output.forall(parent.usedInputs.contains(_)) + val paramLength = ctx.calculateParamLength(output) + (if (row != null) 1 else 0) + val consumeFunc = if (confEnabled && requireAllOutput && ctx.isValidParamLength(paramLength)) { + constructDoConsumeFunction(ctx, inputVars, row) + } else { + parent.doConsume(ctx, inputVars, rowVar) + } s""" |${ctx.registerComment(s"CONSUME: ${parent.simpleString}")} |$evaluated - |${parent.doConsume(ctx, inputVars, rowVar)} + |$consumeFunc + """.stripMargin + } + + /** + * To prevent concatenated function growing too long to be optimized by JIT. We can separate the + * parent's `doConsume` codes of a `CodegenSupport` operator into a function to call. + */ + private def constructDoConsumeFunction( + ctx: CodegenContext, + inputVars: Seq[ExprCode], + row: String): String = { + val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row) + val rowVar = prepareRowVar(ctx, row, inputVarsInFunc) + + val doConsume = ctx.freshName("doConsume") + ctx.currentVars = inputVarsInFunc + ctx.INPUT_ROW = null + + val doConsumeFuncName = ctx.addNewFunction(doConsume, + s""" + | private void $doConsume(${params.mkString(", ")}) throws java.io.IOException { + | ${parent.doConsume(ctx, inputVarsInFunc, rowVar)} + | } + """.stripMargin) + + s""" + | $doConsumeFuncName(${args.mkString(", ")}); """.stripMargin } + /** + * Returns arguments for calling method and method definition parameters of the consume function. + * And also returns the list of `ExprCode` for the parameters. + */ + private def constructConsumeParameters( + ctx: CodegenContext, + attributes: Seq[Attribute], + variables: Seq[ExprCode], + row: String): (Seq[String], Seq[String], Seq[ExprCode]) = { + val arguments = mutable.ArrayBuffer[String]() + val parameters = mutable.ArrayBuffer[String]() + val paramVars = mutable.ArrayBuffer[ExprCode]() + + if (row != null) { + arguments += row + parameters += s"InternalRow $row" + } + + variables.zipWithIndex.foreach { case (ev, i) => + val paramName = ctx.freshName(s"expr_$i") + val paramType = ctx.javaType(attributes(i).dataType) + + arguments += ev.value + parameters += s"$paramType $paramName" + val paramIsNull = if (!attributes(i).nullable) { + // Use constant `false` without passing `isNull` for non-nullable variable. + "false" + } else { + val isNull = ctx.freshName(s"exprIsNull_$i") + arguments += ev.isNull + parameters += s"boolean $isNull" + isNull + } + + paramVars += ExprCode("", paramIsNull, paramName) + } + (arguments, parameters, paramVars) + } + /** * Returns source code to evaluate all the variables, and clear the code of them, to prevent * them to be evaluated twice. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 22ca128c27768..242bb48c22942 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -205,7 +205,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val codeWithShortFunctions = genGroupByCode(3) val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) - val codeWithLongFunctions = genGroupByCode(20) + val codeWithLongFunctions = genGroupByCode(50) val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions) assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } @@ -228,4 +228,49 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } } + + test("Control splitting consume function by operators with config") { + import testImplicits._ + val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*) + + Seq(true, false).foreach { config => + withSQLConf(SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> s"$config") { + val plan = df.queryExecution.executedPlan + val wholeStageCodeGenExec = plan.find(p => p match { + case wp: WholeStageCodegenExec => true + case _ => false + }) + assert(wholeStageCodeGenExec.isDefined) + val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 + assert(code.body.contains("project_doConsume") == config) + } + } + } + + test("Skip splitting consume function when parameter number exceeds JVM limit") { + import testImplicits._ + + Seq((255, false), (254, true)).foreach { case (columnNum, hasSplit) => + withTempPath { dir => + val path = dir.getCanonicalPath + spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*) + .write.mode(SaveMode.Overwrite).parquet(path) + + withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255", + SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") { + val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as newC$i") + val df = spark.read.parquet(path).selectExpr(projection: _*) + + val plan = df.queryExecution.executedPlan + val wholeStageCodeGenExec = plan.find(p => p match { + case wp: WholeStageCodegenExec => true + case _ => false + }) + assert(wholeStageCodeGenExec.isDefined) + val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 + assert(code.body.contains("project_doConsume") == hasSplit) + } + } + } + } } From 8532e26f335b67b74c976712ad82c20ea6dbbf80 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 25 Jan 2018 15:01:22 +0200 Subject: [PATCH 203/774] [SPARK-23112][DOC] Add highlights and migration guide for 2.3 Update ML user guide with highlights and migration guide for `2.3`. ## How was this patch tested? Doc only. Author: Nick Pentreath Closes #20363 from MLnick/SPARK-23112-ml-guide. --- docs/ml-guide.md | 78 ++++++++++++++----------------------- docs/ml-migration-guides.md | 23 +++++++++++ 2 files changed, 52 insertions(+), 49 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index f6288e7c32d97..b957445579ffd 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -72,32 +72,31 @@ To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 [^1]: To learn more about the benefits and background of system optimised natives, you may wish to watch Sam Halliday's ScalaX talk on [High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/). -# Highlights in 2.2 +# Highlights in 2.3 -The list below highlights some of the new features and enhancements added to MLlib in the `2.2` +The list below highlights some of the new features and enhancements added to MLlib in the `2.3` release of Spark: -* [`ALS`](ml-collaborative-filtering.html) methods for _top-k_ recommendations for all - users or items, matching the functionality in `mllib` - ([SPARK-19535](https://issues.apache.org/jira/browse/SPARK-19535)). - Performance was also improved for both `ml` and `mllib` - ([SPARK-11968](https://issues.apache.org/jira/browse/SPARK-11968) and - [SPARK-20587](https://issues.apache.org/jira/browse/SPARK-20587)) -* [`Correlation`](ml-statistics.html#correlation) and - [`ChiSquareTest`](ml-statistics.html#hypothesis-testing) stats functions for `DataFrames` - ([SPARK-19636](https://issues.apache.org/jira/browse/SPARK-19636) and - [SPARK-19635](https://issues.apache.org/jira/browse/SPARK-19635)) -* [`FPGrowth`](ml-frequent-pattern-mining.html#fp-growth) algorithm for frequent pattern mining - ([SPARK-14503](https://issues.apache.org/jira/browse/SPARK-14503)) -* `GLM` now supports the full `Tweedie` family - ([SPARK-18929](https://issues.apache.org/jira/browse/SPARK-18929)) -* [`Imputer`](ml-features.html#imputer) feature transformer to impute missing values in a dataset - ([SPARK-13568](https://issues.apache.org/jira/browse/SPARK-13568)) -* [`LinearSVC`](ml-classification-regression.html#linear-support-vector-machine) - for linear Support Vector Machine classification - ([SPARK-14709](https://issues.apache.org/jira/browse/SPARK-14709)) -* Logistic regression now supports constraints on the coefficients during training - ([SPARK-20047](https://issues.apache.org/jira/browse/SPARK-20047)) +* Built-in support for reading images into a `DataFrame` was added +([SPARK-21866](https://issues.apache.org/jira/browse/SPARK-21866)). +* [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator) was added, and should be +used instead of the existing `OneHotEncoder` transformer. The new estimator supports +transforming multiple columns. +* Multiple column support was also added to `QuantileDiscretizer` and `Bucketizer` +([SPARK-22397](https://issues.apache.org/jira/browse/SPARK-22397) and +[SPARK-20542](https://issues.apache.org/jira/browse/SPARK-20542)) +* A new [`FeatureHasher`](ml-features.html#featurehasher) transformer was added + ([SPARK-13969](https://issues.apache.org/jira/browse/SPARK-13969)). +* Added support for evaluating multiple models in parallel when performing cross-validation using +[`TrainValidationSplit` or `CrossValidator`](ml-tuning.html) +([SPARK-19357](https://issues.apache.org/jira/browse/SPARK-19357)). +* Improved support for custom pipeline components in Python (see +[SPARK-21633](https://issues.apache.org/jira/browse/SPARK-21633) and +[SPARK-21542](https://issues.apache.org/jira/browse/SPARK-21542)). +* `DataFrame` functions for descriptive summary statistics over vector columns +([SPARK-19634](https://issues.apache.org/jira/browse/SPARK-19634)). +* Robust linear regression with Huber loss +([SPARK-3181](https://issues.apache.org/jira/browse/SPARK-3181)). # Migration guide @@ -115,36 +114,17 @@ There are no breaking changes. **Deprecations** -There are no deprecations. +* `OneHotEncoder` has been deprecated and will be removed in `3.0`. It has been replaced by the +new [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator) +(see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030)). **Note** that +`OneHotEncoderEstimator` will be renamed to `OneHotEncoder` in `3.0` (but +`OneHotEncoderEstimator` will be kept as an alias). **Changes of behavior** * [SPARK-21027](https://issues.apache.org/jira/browse/SPARK-21027): - We are now setting the default parallelism used in `OneVsRest` to be 1 (i.e. serial), in 2.2 and earlier version, - the `OneVsRest` parallelism would be parallelism of the default threadpool in scala. - -## From 2.1 to 2.2 - -### Breaking changes - -There are no breaking changes. - -### Deprecations and changes of behavior - -**Deprecations** - -There are no deprecations. - -**Changes of behavior** - -* [SPARK-19787](https://issues.apache.org/jira/browse/SPARK-19787): - Default value of `regParam` changed from `1.0` to `0.1` for `ALS.train` method (marked `DeveloperApi`). - **Note** this does _not affect_ the `ALS` Estimator or Model, nor MLlib's `ALS` class. -* [SPARK-14772](https://issues.apache.org/jira/browse/SPARK-14772): - Fixed inconsistency between Python and Scala APIs for `Param.copy` method. -* [SPARK-11569](https://issues.apache.org/jira/browse/SPARK-11569): - `StringIndexer` now handles `NULL` values in the same way as unseen values. Previously an exception - would always be thrown regardless of the setting of the `handleInvalid` parameter. + We are now setting the default parallelism used in `OneVsRest` to be 1 (i.e. serial). In 2.2 and + earlier versions, the level of parallelism was set to the default threadpool size in Scala. ## Previous Spark versions diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md index 687d7c8930362..f4b0df58cf63b 100644 --- a/docs/ml-migration-guides.md +++ b/docs/ml-migration-guides.md @@ -7,6 +7,29 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT The migration guide for the current Spark version is kept on the [MLlib Guide main page](ml-guide.html#migration-guide). +## From 2.1 to 2.2 + +### Breaking changes + +There are no breaking changes. + +### Deprecations and changes of behavior + +**Deprecations** + +There are no deprecations. + +**Changes of behavior** + +* [SPARK-19787](https://issues.apache.org/jira/browse/SPARK-19787): + Default value of `regParam` changed from `1.0` to `0.1` for `ALS.train` method (marked `DeveloperApi`). + **Note** this does _not affect_ the `ALS` Estimator or Model, nor MLlib's `ALS` class. +* [SPARK-14772](https://issues.apache.org/jira/browse/SPARK-14772): + Fixed inconsistency between Python and Scala APIs for `Param.copy` method. +* [SPARK-11569](https://issues.apache.org/jira/browse/SPARK-11569): + `StringIndexer` now handles `NULL` values in the same way as unseen values. Previously an exception + would always be thrown regardless of the setting of the `handleInvalid` parameter. + ## From 2.0 to 2.1 ### Breaking changes From 8480c0c57698b7dcccec5483d67b17cf2c7527ed Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 26 Jan 2018 07:50:48 +0900 Subject: [PATCH 204/774] [SPARK-23081][PYTHON] Add colRegex API to PySpark ## What changes were proposed in this pull request? Add colRegex API to PySpark ## How was this patch tested? add a test in sql/tests.py Author: Huaxin Gao Closes #20390 from huaxingao/spark-23081. --- python/pyspark/sql/dataframe.py | 23 +++++++++++++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 8 +++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2d5e9b91468cf..ac403080acfdf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -819,6 +819,29 @@ def columns(self): """ return [f.name for f in self.schema.fields] + @since(2.3) + def colRegex(self, colName): + """ + Selects column based on the column name specified as a regex and returns it + as :class:`Column`. + + :param colName: string, column name specified as a regex. + + >>> df = spark.createDataFrame([("a", 1), ("b", 2), ("c", 3)], ["Col1", "Col2"]) + >>> df.select(df.colRegex("`(Col1)?+.+`")).show() + +----+ + |Col2| + +----+ + | 1| + | 2| + | 3| + +----+ + """ + if not isinstance(colName, basestring): + raise ValueError("colName should be provided as string") + jc = self._jdf.colRegex(colName) + return Column(jc) + @ignore_unicode_prefix @since(1.3) def alias(self, alias): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 912f411fa3845..edb6644ed5ac0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1194,7 +1194,7 @@ class Dataset[T] private[sql]( def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs : _*) /** - * Selects column based on the column name and return it as a [[Column]]. + * Selects column based on the column name and returns it as a [[Column]]. * * @note The column name can also reference to a nested column like `a.b`. * @@ -1220,7 +1220,7 @@ class Dataset[T] private[sql]( } /** - * Selects column based on the column name and return it as a [[Column]]. + * Selects column based on the column name and returns it as a [[Column]]. * * @note The column name can also reference to a nested column like `a.b`. * @@ -1240,7 +1240,7 @@ class Dataset[T] private[sql]( } /** - * Selects column based on the column name specified as a regex and return it as [[Column]]. + * Selects column based on the column name specified as a regex and returns it as [[Column]]. * @group untypedrel * @since 2.3.0 */ @@ -2729,7 +2729,7 @@ class Dataset[T] private[sql]( } /** - * Return an iterator that contains all rows in this Dataset. + * Returns an iterator that contains all rows in this Dataset. * * The iterator will consume as much memory as the largest partition in this Dataset. * From e57f394818b0a62f99609e1032fede7e981f306f Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Thu, 25 Jan 2018 16:11:33 -0800 Subject: [PATCH 205/774] [SPARK-23032][SQL] Add a per-query codegenStageId to WholeStageCodegenExec ## What changes were proposed in this pull request? **Proposal** Add a per-query ID to the codegen stages as represented by `WholeStageCodegenExec` operators. This ID will be used in - the explain output of the physical plan, and in - the generated class name. Specifically, this ID will be stable within a query, counting up from 1 in depth-first post-order for all the `WholeStageCodegenExec` inserted into a plan. The ID value 0 is reserved for "free-floating" `WholeStageCodegenExec` objects, which may have been created for one-off purposes, e.g. for fallback handling of codegen stages that failed to codegen the whole stage and wishes to codegen a subset of the children operators (as seen in `org.apache.spark.sql.execution.FileSourceScanExec#doExecute`). Example: for the following query: ```scala scala> spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 1) scala> val df1 = spark.range(10).select('id as 'x, 'id + 1 as 'y).orderBy('x).select('x + 1 as 'z, 'y) df1: org.apache.spark.sql.DataFrame = [z: bigint, y: bigint] scala> val df2 = spark.range(5) df2: org.apache.spark.sql.Dataset[Long] = [id: bigint] scala> val query = df1.join(df2, 'z === 'id) query: org.apache.spark.sql.DataFrame = [z: bigint, y: bigint ... 1 more field] ``` The explain output before the change is: ```scala scala> query.explain == Physical Plan == *SortMergeJoin [z#9L], [id#13L], Inner :- *Sort [z#9L ASC NULLS FIRST], false, 0 : +- Exchange hashpartitioning(z#9L, 200) : +- *Project [(x#3L + 1) AS z#9L, y#4L] : +- *Sort [x#3L ASC NULLS FIRST], true, 0 : +- Exchange rangepartitioning(x#3L ASC NULLS FIRST, 200) : +- *Project [id#0L AS x#3L, (id#0L + 1) AS y#4L] : +- *Range (0, 10, step=1, splits=8) +- *Sort [id#13L ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(id#13L, 200) +- *Range (0, 5, step=1, splits=8) ``` Note how codegen'd operators are annotated with a prefix `"*"`. See how the `SortMergeJoin` operator and its direct children `Sort` operators are adjacent and all annotated with the `"*"`, so it's hard to tell they're actually in separate codegen stages. and after this change it'll be: ```scala scala> query.explain == Physical Plan == *(6) SortMergeJoin [z#9L], [id#13L], Inner :- *(3) Sort [z#9L ASC NULLS FIRST], false, 0 : +- Exchange hashpartitioning(z#9L, 200) : +- *(2) Project [(x#3L + 1) AS z#9L, y#4L] : +- *(2) Sort [x#3L ASC NULLS FIRST], true, 0 : +- Exchange rangepartitioning(x#3L ASC NULLS FIRST, 200) : +- *(1) Project [id#0L AS x#3L, (id#0L + 1) AS y#4L] : +- *(1) Range (0, 10, step=1, splits=8) +- *(5) Sort [id#13L ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(id#13L, 200) +- *(4) Range (0, 5, step=1, splits=8) ``` Note that the annotated prefix becomes `"*(id) "`. See how the `SortMergeJoin` operator and its direct children `Sort` operators have different codegen stage IDs. It'll also show up in the name of the generated class, as a suffix in the format of `GeneratedClass$GeneratedIterator$id`. For example, note how `GeneratedClass$GeneratedIteratorForCodegenStage3` and `GeneratedClass$GeneratedIteratorForCodegenStage6` in the following stack trace corresponds to the IDs shown in the explain output above: ``` "Executor task launch worker for task 42412957" daemon prio=5 tid=0x58 nid=NA runnable java.lang.Thread.State: RUNNABLE at org.apache.spark.sql.execution.UnsafeExternalRowSorter.insertRow(UnsafeExternalRowSorter.java:109) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.sort_addToSorter$(generated.java:32) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(generated.java:41) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$9$$anon$1.hasNext(WholeStageCodegenExec.scala:494) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage6.findNextInnerJoinRows$(generated.java:42) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage6.processNext(generated.java:101) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$2.hasNext(WholeStageCodegenExec.scala:513) at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253) at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:828) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:828) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324) at org.apache.spark.rdd.RDD.iterator(RDD.scala:288) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87) at org.apache.spark.scheduler.Task.run(Task.scala:109) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:748) ``` **Rationale** Right now, the codegen from Spark SQL lacks the means to differentiate between a couple of things: 1. It's hard to tell which physical operators are in the same WholeStageCodegen stage. Note that this "stage" is a separate notion from Spark's RDD execution stages; this one is only to delineate codegen units. There can be adjacent physical operators that are both codegen'd but are in separate codegen stages. Some of this is due to hacky implementation details, such as the case with `SortMergeJoin` and its `Sort` inputs -- they're hard coded to be split into separate stages although both are codegen'd. When printing out the explain output of the physical plan, you'd only see the codegen'd physical operators annotated with a preceding star (`'*'`) but would have no way to figure out if they're in the same stage. 2. Performance/error diagnosis The generated code has class/method names that are hard to differentiate between queries or even between codegen stages within the same query. If we use a Java-level profiler to collect profiles, or if we encounter a Java-level exception with a stack trace in it, it's really hard to tell which part of a query it's at. By introducing a per-query codegen stage ID, we'd at least be able to know which codegen stage (and in turn, which group of physical operators) was a profile tick or an exception happened. The reason why this proposal uses a per-query ID is because it's stable within a query, so that multiple runs of the same query will see the same resulting IDs. This both benefits understandability for users, and also it plays well with the codegen cache in Spark SQL which uses the generated source code as the key. The downside to using per-query IDs as opposed to a per-session or globally incrementing ID is of course we can't tell apart different query runs with this ID alone. But for now I believe this is a good enough tradeoff. ## How was this patch tested? Existing tests. This PR does not involve any runtime behavior changes other than some name changes. The SQL query test suites that compares explain outputs have been updates to ignore the newly added `codegenStageId`. Author: Kris Mok Closes #20224 from rednaxelafx/wsc-codegenstageid. --- .../apache/spark/sql/internal/SQLConf.scala | 10 +++ .../sql/execution/DataSourceScanExec.scala | 2 +- .../sql/execution/WholeStageCodegenExec.scala | 85 +++++++++++++++++-- .../columnar/InMemoryTableScanExec.scala | 2 +- .../datasources/v2/DataSourceV2ScanExec.scala | 2 +- .../apache/spark/sql/SQLQueryTestSuite.scala | 3 +- .../execution/WholeStageCodegenSuite.scala | 34 ++++++++ .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../sql/hive/execution/HiveExplainSuite.scala | 39 +++++++-- 9 files changed, 158 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 470f88c213561..b0d18b6dced76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -629,6 +629,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME = + buildConf("spark.sql.codegen.useIdInClassName") + .internal() + .doc("When true, embed the (whole-stage) codegen stage ID into " + + "the class name of the generated class as a suffix") + .booleanConf + .createWithDefault(true) + val WHOLESTAGE_MAX_NUM_FIELDS = buildConf("spark.sql.codegen.maxFields") .internal() .doc("The maximum number of fields (including nested fields) that will be supported before" + @@ -1264,6 +1272,8 @@ class SQLConf extends Serializable with Logging { def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + def wholeStageUseIdInClassName: Boolean = getConf(WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME) + def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS) def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 7c7d79c2bbd7c..aa66ee7e948ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -324,7 +324,7 @@ case class FileSourceScanExec( // in the case of fallback, this batched scan should never fail because of: // 1) only primitive types are supported // 2) the number of columns should be smaller than spark.sql.codegen.maxFields - WholeStageCodegenExec(this).execute() + WholeStageCodegenExec(this)(codegenStageId = 0).execute() } else { val unsafeRows = { val scan = inputRDD diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 8ea9e81b2e53b..0e525b1e22eb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.util.Locale +import java.util.function.Supplier import scala.collection.mutable @@ -414,6 +415,58 @@ object WholeStageCodegenExec { } } +object WholeStageCodegenId { + // codegenStageId: ID for codegen stages within a query plan. + // It does not affect equality, nor does it participate in destructuring pattern matching + // of WholeStageCodegenExec. + // + // This ID is used to help differentiate between codegen stages. It is included as a part + // of the explain output for physical plans, e.g. + // + // == Physical Plan == + // *(5) SortMergeJoin [x#3L], [y#9L], Inner + // :- *(2) Sort [x#3L ASC NULLS FIRST], false, 0 + // : +- Exchange hashpartitioning(x#3L, 200) + // : +- *(1) Project [(id#0L % 2) AS x#3L] + // : +- *(1) Filter isnotnull((id#0L % 2)) + // : +- *(1) Range (0, 5, step=1, splits=8) + // +- *(4) Sort [y#9L ASC NULLS FIRST], false, 0 + // +- Exchange hashpartitioning(y#9L, 200) + // +- *(3) Project [(id#6L % 2) AS y#9L] + // +- *(3) Filter isnotnull((id#6L % 2)) + // +- *(3) Range (0, 5, step=1, splits=8) + // + // where the ID makes it obvious that not all adjacent codegen'd plan operators are of the + // same codegen stage. + // + // The codegen stage ID is also optionally included in the name of the generated classes as + // a suffix, so that it's easier to associate a generated class back to the physical operator. + // This is controlled by SQLConf: spark.sql.codegen.useIdInClassName + // + // The ID is also included in various log messages. + // + // Within a query, a codegen stage in a plan starts counting from 1, in "insertion order". + // WholeStageCodegenExec operators are inserted into a plan in depth-first post-order. + // See CollapseCodegenStages.insertWholeStageCodegen for the definition of insertion order. + // + // 0 is reserved as a special ID value to indicate a temporary WholeStageCodegenExec object + // is created, e.g. for special fallback handling when an existing WholeStageCodegenExec + // failed to generate/compile code. + + private val codegenStageCounter = ThreadLocal.withInitial(new Supplier[Integer] { + override def get() = 1 // TODO: change to Scala lambda syntax when upgraded to Scala 2.12+ + }) + + def resetPerQuery(): Unit = codegenStageCounter.set(1) + + def getNextStageId(): Int = { + val counter = codegenStageCounter + val id = counter.get() + counter.set(id + 1) + id + } +} + /** * WholeStageCodegen compiles a subtree of plans that support codegen together into single Java * function. @@ -442,7 +495,8 @@ object WholeStageCodegenExec { * `doCodeGen()` will create a `CodeGenContext`, which will hold a list of variables for input, * used to generated code for [[BoundReference]]. */ -case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { +case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) + extends UnaryExecNode with CodegenSupport { override def output: Seq[Attribute] = child.output @@ -454,6 +508,12 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext, WholeStageCodegenExec.PIPELINE_DURATION_METRIC)) + def generatedClassName(): String = if (conf.wholeStageUseIdInClassName) { + s"GeneratedIteratorForCodegenStage$codegenStageId" + } else { + "GeneratedIterator" + } + /** * Generates code for this subtree. * @@ -471,19 +531,23 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co } """, inlineToOuterClass = true) + val className = generatedClassName() + val source = s""" public Object generate(Object[] references) { - return new GeneratedIterator(references); + return new $className(references); } - ${ctx.registerComment(s"""Codegend pipeline for\n${child.treeString.trim}""")} - final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { + ${ctx.registerComment( + s"""Codegend pipeline for stage (id=$codegenStageId) + |${this.treeString.trim}""".stripMargin)} + final class $className extends ${classOf[BufferedRowIterator].getName} { private Object[] references; private scala.collection.Iterator[] inputs; ${ctx.declareMutableStates()} - public GeneratedIterator(Object[] references) { + public $className(Object[] references) { this.references = references; } @@ -516,7 +580,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co } catch { case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback => // We should already saw the error message - logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString") + logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString") return child.execute() } @@ -525,7 +589,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co logInfo(s"Found too long generated codes and JIT optimization might not work: " + s"the bytecode size ($maxCodeSize) is above the limit " + s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " + - s"for this plan. To avoid this, you can raise the limit " + + s"for this plan (id=$codegenStageId). To avoid this, you can raise the limit " + s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString") child match { // The fallback solution of batch file source scan still uses WholeStageCodegenExec @@ -603,10 +667,12 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co verbose: Boolean, prefix: String = "", addSuffix: Boolean = false): StringBuilder = { - child.generateTreeString(depth, lastChildren, builder, verbose, "*") + child.generateTreeString(depth, lastChildren, builder, verbose, s"*($codegenStageId) ") } override def needStopCheck: Boolean = true + + override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer]) } @@ -657,13 +723,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] => plan.withNewChildren(plan.children.map(insertWholeStageCodegen)) case plan: CodegenSupport if supportCodegen(plan) => - WholeStageCodegenExec(insertInputAdapter(plan)) + WholeStageCodegenExec(insertInputAdapter(plan))(WholeStageCodegenId.getNextStageId()) case other => other.withNewChildren(other.children.map(insertWholeStageCodegen)) } def apply(plan: SparkPlan): SparkPlan = { if (conf.wholeStageEnabled) { + WholeStageCodegenId.resetPerQuery() insertWholeStageCodegen(plan) } else { plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 28b3875505cd2..c167f1e7dc621 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -274,7 +274,7 @@ case class InMemoryTableScanExec( protected override def doExecute(): RDD[InternalRow] = { if (supportsBatch) { - WholeStageCodegenExec(this).execute() + WholeStageCodegenExec(this)(codegenStageId = 0).execute() } else { inputRDD } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 69d871df3e1dd..2c22239e81869 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -88,7 +88,7 @@ case class DataSourceV2ScanExec( override protected def doExecute(): RDD[InternalRow] = { if (supportsBatch) { - WholeStageCodegenExec(this).execute() + WholeStageCodegenExec(this)(codegenStageId = 0).execute() } else { val numOutputRows = longMetric("numOutputRows") inputRDD.map { r => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 054ada56d99ad..beac9699585d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -230,7 +230,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { .replaceAll("Location.*/sql/core/", s"Location ${notIncludedMsg}sql/core/") .replaceAll("Created By.*", s"Created By $notIncludedMsg") .replaceAll("Created Time.*", s"Created Time $notIncludedMsg") - .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")) + .replaceAll("Last Access.*", s"Last Access $notIncludedMsg") + .replaceAll("\\*\\(\\d+\\) ", "*")) // remove the WholeStageCodegen codegenStageIds // If the output is not pre-sorted, sort it. if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 242bb48c22942..28ad712feaae6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -273,4 +274,37 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } } + + test("codegen stage IDs should be preserved in transformations after CollapseCodegenStages") { + // test case adapted from DataFrameSuite to trigger ReuseExchange + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2") { + val df = spark.range(100) + val join = df.join(df, "id") + val plan = join.queryExecution.executedPlan + assert(!plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isDefined, + "codegen stage IDs should be preserved through ReuseExchange") + checkAnswer(join, df.toDF) + } + } + + test("including codegen stage ID in generated class name should not regress codegen caching") { + import testImplicits._ + + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") { + val bytecodeSizeHisto = CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE + + // the same query run twice should hit the codegen cache + spark.range(3).select('id + 2).collect + val after1 = bytecodeSizeHisto.getCount + spark.range(3).select('id + 2).collect + val after2 = bytecodeSizeHisto.getCount // same query shape as above, deliberately + // bytecodeSizeHisto's count is always monotonically increasing if new compilation to + // bytecode had occurred. If the count stayed the same that means we've got a cache hit. + assert(after1 == after2, "Should hit codegen cache. No new compilation to bytecode expected") + + // a different query can result in codegen cache miss, that's by design + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index ff7c5e58e9863..2280da927cf70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -477,7 +477,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec]) val execPlan = if (enabled == "true") { - WholeStageCodegenExec(planBeforeFilter.head) + WholeStageCodegenExec(planBeforeFilter.head)(codegenStageId = 0) } else { planBeforeFilter.head } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index a4273de5fe260..f84d188075b72 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -154,14 +154,39 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } } - test("EXPLAIN CODEGEN command") { - checkKeywordsExist(sql("EXPLAIN CODEGEN SELECT 1"), - "WholeStageCodegen", - "Generated code:", - "/* 001 */ public Object generate(Object[] references) {", - "/* 002 */ return new GeneratedIterator(references);", - "/* 003 */ }" + test("explain output of physical plan should contain proper codegen stage ID") { + checkKeywordsExist(sql( + """ + |EXPLAIN SELECT t1.id AS a, t2.id AS b FROM + |(SELECT * FROM range(3)) t1 JOIN + |(SELECT * FROM range(10)) t2 ON t1.id == t2.id % 3 + """.stripMargin), + "== Physical Plan ==", + "*(2) Project ", + "+- *(2) BroadcastHashJoin ", + " :- BroadcastExchange ", + " : +- *(1) Range ", + " +- *(2) Range " ) + } + + test("EXPLAIN CODEGEN command") { + // the generated class name in this test should stay in sync with + // org.apache.spark.sql.execution.WholeStageCodegenExec.generatedClassName() + for ((useIdInClassName, expectedClassName) <- Seq( + ("true", "GeneratedIteratorForCodegenStage1"), + ("false", "GeneratedIterator"))) { + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> useIdInClassName) { + checkKeywordsExist(sql("EXPLAIN CODEGEN SELECT 1"), + "WholeStageCodegen", + "Generated code:", + "/* 001 */ public Object generate(Object[] references) {", + s"/* 002 */ return new $expectedClassName(references);", + "/* 003 */ }" + ) + } + } checkKeywordsNotExist(sql("EXPLAIN CODEGEN SELECT 1"), "== Physical Plan ==" From 7bd46d9871567597216cc02e1dc72ff5806ecdf8 Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Thu, 25 Jan 2018 18:15:29 -0600 Subject: [PATCH 206/774] [SPARK-23205][ML] Update ImageSchema.readImages to correctly set alpha values for four-channel images ## What changes were proposed in this pull request? When parsing raw image data in ImageSchema.decode(), we use a [java.awt.Color](https://docs.oracle.com/javase/7/docs/api/java/awt/Color.html#Color(int)) constructor that sets alpha = 255, even for four-channel images (which may have different alpha values). This PR fixes this issue & adds a unit test to verify correctness of reading four-channel images. ## How was this patch tested? Updates an existing unit test ("readImages pixel values test" in `ImageSchemaSuite`) to also verify correctness when reading a four-channel image. Author: Sid Murching Closes #20389 from smurching/image-schema-bugfix. --- data/mllib/images/multi-channel/BGRA_alpha_60.png | Bin 0 -> 747 bytes .../org/apache/spark/ml/image/ImageSchema.scala | 5 ++--- .../apache/spark/ml/image/ImageSchemaSuite.scala | 9 ++++++--- 3 files changed, 8 insertions(+), 6 deletions(-) create mode 100644 data/mllib/images/multi-channel/BGRA_alpha_60.png diff --git a/data/mllib/images/multi-channel/BGRA_alpha_60.png b/data/mllib/images/multi-channel/BGRA_alpha_60.png new file mode 100644 index 0000000000000000000000000000000000000000..913637cd2828ab4e2ff4b2bbd92c4cf362f871c4 GIT binary patch literal 747 zcmV zL3V>M390}vf5ExS_g|iB()io}T zdzc-a9(QkC83JQkyn5Lt1Z5CrRi#xHpD{9oI-@&0jtqe@HUD3>kZs0McGO5TM~1+d z7FE_NXiaZ3z?maMU@%u%R=mvsm?J}AJhD?K_8)UCLtrp7B&y$7129;Iz!*D2yjN8< z0y9X4z_>+*xc3{0=Ex8j*D|Cx*=8j4K{5o!c7|A?%)le4CMiSsY+o_Vo>AV}VFh50 z41v-2iea`H1DYg5U<}czr}rCyox2Qyfqu6aXGTB<$q*RG3>nT0#|)AoFmf{_nrt%+ z=R=0T=#wE~ZFN$PgH=pIpR#r$}~v0vQ6s<%(gm8QC*8vEQiGG6cq@D~4&`dke3xtTJT?jHV1R zn*p1-WHaVkhQK(LAu?mT_I%GyhQKgo$ZgD^p$y@(n<2xRQG=Ep8^{nCn;A09ds8(A zT2-xU83JRGAy_kN4BT(jY8e8f?Uz2S`+1phqfY#&mLV|C{nDp(Kbg^7%Mci;cTmZU z&sv7SNV$V*STh2UAuvMkprV>#Myssnk&_`_<2W4$=?*U$0wXp)$=dj2RgM;~p7O*^V>AfDC~V@@+tF<5118 dqE*&-`~khx#S9AHa*F@}002ovPDHLkV1lO_Oc($F literal 0 HcmV?d00001 diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala index f7850b238465b..dcc40b6668c7a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala @@ -169,12 +169,11 @@ object ImageSchema { var offset = 0 for (h <- 0 until height) { for (w <- 0 until width) { - val color = new Color(img.getRGB(w, h)) - + val color = new Color(img.getRGB(w, h), hasAlpha) decoded(offset) = color.getBlue.toByte decoded(offset + 1) = color.getGreen.toByte decoded(offset + 2) = color.getRed.toByte - if (nChannels == 4) { + if (hasAlpha) { decoded(offset + 3) = color.getAlpha.toByte } offset += nChannels diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala index dba61cd1eb1cc..a8833c615865d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -53,11 +53,11 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(df.count === 1) df = readImages(imagePath, null, true, -1, false, 1.0, 0) - assert(df.count === 9) + assert(df.count === 10) df = readImages(imagePath, null, true, -1, true, 1.0, 0) val countTotal = df.count - assert(countTotal === 7) + assert(countTotal === 8) df = readImages(imagePath, null, true, -1, true, 0.5, 0) // Random number about half of the size of the original dataset @@ -103,6 +103,9 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { -71, -58, -56, -73, -64))), "BGRA.png" -> (("CV_8UC4", Array[Byte](-128, -128, -8, -1, -128, -128, -8, -1, -128, - -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))) + -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))), + "BGRA_alpha_60.png" -> (("CV_8UC4", + Array[Byte](-128, -128, -8, 60, -128, -128, -8, 60, -128, + -128, -8, 60, 127, 127, -9, 60, 127, 127, -9, 60))) ) } From 70a68b328b856c17eb22cc86fee0ebe8d64f8825 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 26 Jan 2018 11:58:20 +0800 Subject: [PATCH 207/774] [SPARK-23020][CORE] Fix race in SparkAppHandle cleanup, again. Third time is the charm? There was still a race that was left in previous attempts. If the handle closes the connection, the close() implementation would clean up state that would prevent the thread from waiting on the connection thread to finish. That could cause the race causing the test flakiness reported in the bug. The fix is to move the "wait for connection thread" code to a separate close method that is used by the handle; that also simplifies the code a bit and makes it also easier to follow. I included an unrelated, but correct, change to a YARN test so that it triggers when the PR is built. Tested by inserting a sleep in the connection thread to mimic the race; test failed reliably with the sleep, passes now. (Sleep not included in the patch.) Also ran YARN tests to make sure. Author: Marcelo Vanzin Closes #20388 from vanzin/SPARK-23020. --- .../spark/launcher/AbstractAppHandle.java | 42 ++++++++------ .../spark/launcher/ChildProcAppHandle.java | 11 +--- .../spark/launcher/InProcessAppHandle.java | 9 +-- .../apache/spark/launcher/LauncherServer.java | 55 +++++++++---------- .../spark/deploy/yarn/YarnClusterSuite.scala | 5 +- 5 files changed, 55 insertions(+), 67 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index daf0972f824dd..84a25a5254151 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; @@ -29,15 +30,15 @@ abstract class AbstractAppHandle implements SparkAppHandle { private final LauncherServer server; - private LauncherConnection connection; + private LauncherServer.ServerConnection connection; private List listeners; - private State state; + private AtomicReference state; private String appId; private volatile boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; - this.state = State.UNKNOWN; + this.state = new AtomicReference<>(State.UNKNOWN); } @Override @@ -50,7 +51,7 @@ public synchronized void addListener(Listener l) { @Override public State getState() { - return state; + return state.get(); } @Override @@ -73,7 +74,7 @@ public synchronized void disconnect() { if (!isDisposed()) { if (connection != null) { try { - connection.close(); + connection.closeAndWait(); } catch (IOException ioe) { // no-op. } @@ -82,7 +83,7 @@ public synchronized void disconnect() { } } - void setConnection(LauncherConnection connection) { + void setConnection(LauncherServer.ServerConnection connection) { this.connection = connection; } @@ -99,12 +100,9 @@ boolean isDisposed() { */ synchronized void dispose() { if (!isDisposed()) { - // Unregister first to make sure that the connection with the app has been really - // terminated. server.unregister(this); - if (!getState().isFinal()) { - setState(State.LOST); - } + // Set state to LOST if not yet final. + setState(State.LOST, false); this.disposed = true; } } @@ -113,14 +111,24 @@ void setState(State s) { setState(s, false); } - synchronized void setState(State s, boolean force) { - if (force || !state.isFinal()) { - state = s; + void setState(State s, boolean force) { + if (force) { + state.set(s); fireEvent(false); - } else { - LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", - new Object[] { state, s }); + return; } + + State current = state.get(); + while (!current.isFinal()) { + if (state.compareAndSet(current, s)) { + fireEvent(false); + return; + } + current = state.get(); + } + + LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", + new Object[] { current, s }); } synchronized void setAppId(String appId) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 2b99461652e1f..5e3c95676ecbe 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -104,19 +104,12 @@ void monitorChild() { ec = 1; } - State currState = getState(); - State newState = null; if (ec != 0) { + State currState = getState(); // Override state with failure if the current state is not final, or is success. if (!currState.isFinal() || currState == State.FINISHED) { - newState = State.FAILED; + setState(State.FAILED, true); } - } else if (!currState.isFinal()) { - newState = State.LOST; - } - - if (newState != null) { - setState(newState, true); } disconnect(); diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index f04263cb74a58..b8030e0063a37 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -66,14 +66,7 @@ synchronized void start(String appName, Method main, String[] args) { setState(State.FAILED); } - synchronized (InProcessAppHandle.this) { - if (!isDisposed()) { - disconnect(); - if (!getState().isFinal()) { - setState(State.LOST, true); - } - } - } + disconnect(); }); app.setName(String.format(THREAD_NAME_FMT, THREAD_IDS.incrementAndGet(), appName)); diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 8091885c4f562..f4ecd52fdeab8 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -218,32 +218,6 @@ void unregister(AbstractAppHandle handle) { } } - // If there is a live connection for this handle, we need to wait for it to finish before - // returning, otherwise there might be a race between the connection thread processing - // buffered data and the handle cleaning up after itself, leading to potentially the wrong - // state being reported for the handle. - ServerConnection conn = null; - synchronized (clients) { - for (ServerConnection c : clients) { - if (c.handle == handle) { - conn = c; - break; - } - } - } - - if (conn != null) { - synchronized (conn) { - if (conn.isOpen()) { - try { - conn.wait(); - } catch (InterruptedException ie) { - // Ignore. - } - } - } - } - unref(); } @@ -312,9 +286,10 @@ private String createSecret() { } } - private class ServerConnection extends LauncherConnection { + class ServerConnection extends LauncherConnection { private TimerTask timeout; + private volatile Thread connectionThread; volatile AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { @@ -322,6 +297,12 @@ private class ServerConnection extends LauncherConnection { this.timeout = timeout; } + @Override + public void run() { + this.connectionThread = Thread.currentThread(); + super.run(); + } + @Override protected void handle(Message msg) throws IOException { try { @@ -376,9 +357,23 @@ public void close() throws IOException { clients.remove(this); } - synchronized (this) { - super.close(); - notifyAll(); + super.close(); + } + + /** + * Close the connection and wait for any buffered data to be processed before returning. + * This ensures any changes reported by the child application take effect. + */ + public void closeAndWait() throws IOException { + close(); + + Thread connThread = this.connectionThread; + if (Thread.currentThread() != connThread) { + try { + connThread.join(); + } catch (InterruptedException ie) { + // Ignore. + } } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index e9dcfaf6ba4f0..5003326b440bf 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -45,8 +45,7 @@ import org.apache.spark.util.Utils /** * Integration tests for YARN; these tests use a mini Yarn cluster to run Spark-on-YARN - * applications, and require the Spark assembly to be built before they can be successfully - * run. + * applications. */ @ExtendedYarnTest class YarnClusterSuite extends BaseYarnClusterSuite { @@ -152,7 +151,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { } test("run Python application in yarn-cluster mode using " + - " spark.yarn.appMasterEnv to override local envvar") { + "spark.yarn.appMasterEnv to override local envvar") { testPySpark( clientMode = false, extraConf = Map( From d1721816d26bedee3c72eeb75db49da500568376 Mon Sep 17 00:00:00 2001 From: Santiago Saavedra Date: Fri, 26 Jan 2018 15:24:06 +0800 Subject: [PATCH 208/774] [SPARK-23200] Reset Kubernetes-specific config on Checkpoint restore ## What changes were proposed in this pull request? When using the Kubernetes cluster-manager and spawning a Streaming workload, it is important to reset many spark.kubernetes.* properties that are generated by spark-submit but which would get rewritten when restoring a Checkpoint. This is so, because the spark-submit codepath creates Kubernetes resources, such as a ConfigMap, a Secret and other variables, which have an autogenerated name and the previous one will not resolve anymore. In short, this change enables checkpoint restoration for streaming workloads, and thus enables Spark Streaming workloads in Kubernetes, which were not possible to restore from a checkpoint before if the workload went down. ## How was this patch tested? This patch was tested with the twitter-streaming example in AWS, using checkpoints in s3 with the s3a:// protocol, as supported by Hadoop. This is similar to the YARN related code for resetting a Spark Streaming workload, but for the Kubernetes scheduler. I'm adding the initcontainers properties because even if the discussion is not completely settled on the mailing list, my understanding is that at this moment they are going forward for the moment. For a previous discussion, see the non-rebased work at: https://github.com/apache-spark-on-k8s/spark/pull/516 Author: Santiago Saavedra Closes #20383 from ssaavedra/fix-k8s-checkpointing. --- .../org/apache/spark/streaming/Checkpoint.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index aed67a5027433..ed2a896033749 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -53,6 +53,21 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.driver.host", "spark.driver.bindAddress", "spark.driver.port", + "spark.kubernetes.driver.pod.name", + "spark.kubernetes.executor.podNamePrefix", + "spark.kubernetes.initcontainer.executor.configmapname", + "spark.kubernetes.initcontainer.executor.configmapkey", + "spark.kubernetes.initcontainer.downloadJarsResourceIdentifier", + "spark.kubernetes.initcontainer.downloadJarsSecretLocation", + "spark.kubernetes.initcontainer.downloadFilesResourceIdentifier", + "spark.kubernetes.initcontainer.downloadFilesSecretLocation", + "spark.kubernetes.initcontainer.remoteJars", + "spark.kubernetes.initcontainer.remoteFiles", + "spark.kubernetes.mountdependencies.jarsDownloadDir", + "spark.kubernetes.mountdependencies.filesDownloadDir", + "spark.kubernetes.initcontainer.executor.stagingServerSecret.name", + "spark.kubernetes.initcontainer.executor.stagingServerSecret.mountDir", + "spark.kubernetes.executor.limit.cores", "spark.master", "spark.yarn.jars", "spark.yarn.keytab", @@ -66,6 +81,7 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") .remove("spark.driver.bindAddress") + .remove("spark.kubernetes.driver.pod.name") .remove("spark.driver.port") val newReloadConf = new SparkConf(loadDefaults = true) propertiesToReload.foreach { prop => From cd3956df0f96dd416b6161bf7ce2962e06d0a62e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 26 Jan 2018 12:23:14 +0200 Subject: [PATCH 209/774] [SPARK-22799][ML] Bucketizer should throw exception if single- and multi-column params are both set ## What changes were proposed in this pull request? Currently there is a mixed situation when both single- and multi-column are supported. In some cases exceptions are thrown, in others only a warning log is emitted. In this discussion https://issues.apache.org/jira/browse/SPARK-8418?focusedCommentId=16275049&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-16275049, the decision was to throw an exception. The PR throws an exception in `Bucketizer`, instead of logging a warning. ## How was this patch tested? modified UT Author: Marco Gaido Author: Joseph K. Bradley Closes #19993 from mgaido91/SPARK-22799. --- .../apache/spark/ml/feature/Bucketizer.scala | 44 +++++------- .../org/apache/spark/ml/param/params.scala | 69 +++++++++++++++++++ .../spark/ml/feature/BucketizerSuite.scala | 41 +++++------ .../apache/spark/ml/param/ParamsSuite.scala | 22 ++++++ 4 files changed, 131 insertions(+), 45 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 8299a3e95d822..c13bf47eacb94 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -32,11 +32,13 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** - * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0, + * `Bucketizer` maps a column of continuous features to a column of feature buckets. + * + * Since 2.3.0, * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that - * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and - * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is - * only used for single column usage, and `splitsArray` is for multiple columns. + * when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The + * `splits` parameter is only used for single column usage, and `splitsArray` is for multiple + * columns. */ @Since("1.4.0") final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) @@ -134,28 +136,11 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("2.3.0") def setOutputCols(value: Array[String]): this.type = set(outputCols, value) - /** - * Determines whether this `Bucketizer` is going to map multiple columns. If and only if - * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified - * by `inputCol`. A warning will be printed if both are set. - */ - private[feature] def isBucketizeMultipleColumns(): Boolean = { - if (isSet(inputCols) && isSet(inputCol)) { - logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " + - "`Bucketizer` only map one column specified by `inputCol`") - false - } else if (isSet(inputCols)) { - true - } else { - false - } - } - @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema) - val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) { + val (inputColumns, outputColumns) = if (isSet(inputCols)) { ($(inputCols).toSeq, $(outputCols).toSeq) } else { (Seq($(inputCol)), Seq($(outputCol))) @@ -170,7 +155,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String } } - val seqOfSplits = if (isBucketizeMultipleColumns()) { + val seqOfSplits = if (isSet(inputCols)) { $(splitsArray).toSeq } else { Seq($(splits)) @@ -201,9 +186,18 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - if (isBucketizeMultipleColumns()) { + ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits), + Seq(outputCols, splitsArray)) + + if (isSet(inputCols)) { + require(getInputCols.length == getOutputCols.length && + getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " + + s"for multi-column transform. Params (inputCols, outputCols, splitsArray) should have " + + s"equal lengths, but they have different lengths: " + + s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).") + var transformedSchema = schema - $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) => + $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) => SchemaUtils.checkNumericType(transformedSchema, inputCol) transformedSchema = SchemaUtils.appendColumn(transformedSchema, prepOutputField($(splitsArray)(idx), outputCol)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 1b4b401ac4aa0..9a83a5882ce29 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -249,6 +249,75 @@ object ParamValidators { def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) => value.length > lowerBound } + + /** + * Utility for Param validity checks for Transformers which have both single- and multi-column + * support. This utility assumes that `inputCol` indicates single-column usage and + * that `inputCols` indicates multi-column usage. + * + * This checks to ensure that exactly one set of Params has been set, and it + * raises an `IllegalArgumentException` if not. + * + * @param singleColumnParams Params which should be set (or have defaults) if `inputCol` has been + * set. This does not need to include `inputCol`. + * @param multiColumnParams Params which should be set (or have defaults) if `inputCols` has been + * set. This does not need to include `inputCols`. + */ + def checkSingleVsMultiColumnParams( + model: Params, + singleColumnParams: Seq[Param[_]], + multiColumnParams: Seq[Param[_]]): Unit = { + val name = s"${model.getClass.getSimpleName} $model" + + def checkExclusiveParams( + isSingleCol: Boolean, + requiredParams: Seq[Param[_]], + excludedParams: Seq[Param[_]]): Unit = { + val badParamsMsgBuilder = new mutable.StringBuilder() + + val mustUnsetParams = excludedParams.filter(p => model.isSet(p)) + .map(_.name).mkString(", ") + if (mustUnsetParams.nonEmpty) { + badParamsMsgBuilder ++= + s"The following Params are not applicable and should not be set: $mustUnsetParams." + } + + val mustSetParams = requiredParams.filter(p => !model.isDefined(p)) + .map(_.name).mkString(", ") + if (mustSetParams.nonEmpty) { + badParamsMsgBuilder ++= + s"The following Params must be defined but are not set: $mustSetParams." + } + + val badParamsMsg = badParamsMsgBuilder.toString() + + if (badParamsMsg.nonEmpty) { + val errPrefix = if (isSingleCol) { + s"$name has the inputCol Param set for single-column transform." + } else { + s"$name has the inputCols Param set for multi-column transform." + } + throw new IllegalArgumentException(s"$errPrefix $badParamsMsg") + } + } + + val inputCol = model.getParam("inputCol") + val inputCols = model.getParam("inputCols") + + if (model.isSet(inputCol)) { + require(!model.isSet(inputCols), s"$name requires " + + s"exactly one of inputCol, inputCols Params to be set, but both are set.") + + checkExclusiveParams(isSingleCol = true, requiredParams = singleColumnParams, + excludedParams = multiColumnParams) + } else if (model.isSet(inputCols)) { + checkExclusiveParams(isSingleCol = false, requiredParams = multiColumnParams, + excludedParams = singleColumnParams) + } else { + throw new IllegalArgumentException(s"$name requires " + + s"exactly one of inputCol, inputCols Params to be set, but neither is set.") + } + } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index d9c97ae8067d3..7403680ae3fdc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -216,8 +216,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer1.isBucketizeMultipleColumns()) - bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2") BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame), Seq("result1", "result2"), @@ -233,8 +231,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result")) .setSplitsArray(Array(splits(0))) - assert(bucketizer2.isBucketizeMultipleColumns()) - withClue("Invalid feature value -0.9 was not caught as an invalid feature!") { intercept[SparkException] { bucketizer2.transform(badDF1).collect() @@ -268,8 +264,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer.isBucketizeMultipleColumns()) - BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), Seq("result1", "result2"), Seq("expected1", "expected2")) @@ -295,8 +289,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer.isBucketizeMultipleColumns()) - bucketizer.setHandleInvalid("keep") BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), Seq("result1", "result2"), @@ -335,7 +327,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCols(Array("myInputCol")) .setOutputCols(Array("myOutputCol")) .setSplitsArray(Array(Array(0.1, 0.8, 0.9))) - assert(t.isBucketizeMultipleColumns()) testDefaultReadWrite(t) } @@ -348,8 +339,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))) - assert(bucket.isBucketizeMultipleColumns()) - val pl = new Pipeline() .setStages(Array(bucket)) .fit(df) @@ -401,15 +390,27 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } - test("Both inputCol and inputCols are set") { - val bucket = new Bucketizer() - .setInputCol("feature1") - .setOutputCol("result") - .setSplits(Array(-0.5, 0.0, 0.5)) - .setInputCols(Array("feature1", "feature2")) - - // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`. - assert(bucket.isBucketizeMultipleColumns() == false) + test("assert exception is thrown if both multi-column and single-column params are set") { + val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2") + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("inputCols", Array("feature1", "feature2"))) + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)), + ("outputCols", Array("result1", "result2"))) + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)), + ("splitsArray", Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))) + + // this should fail because at least one of inputCol and inputCols must be set + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("outputCol", "feature1"), + ("splits", Array(-0.5, 0.0, 0.5))) + + // the following should fail because not all the params are set + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("outputCol", "result1")) + ParamsSuite.testExclusiveParams(new Bucketizer, df, + ("inputCols", Array("feature1", "feature2")), + ("outputCols", Array("result1", "result2"))) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 85198ad4c913a..36e06091d24de 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.ml.param import java.io.{ByteArrayOutputStream, ObjectOutputStream} import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Estimator, Transformer} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.MyParams +import org.apache.spark.sql.Dataset class ParamsSuite extends SparkFunSuite { @@ -430,4 +432,24 @@ object ParamsSuite extends SparkFunSuite { require(copyReturnType === obj.getClass, s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.") } + + /** + * Checks that the class throws an exception in case multiple exclusive params are set. + * The params to be checked are passed as arguments with their value. + */ + def testExclusiveParams( + model: Params, + dataset: Dataset[_], + paramsAndValues: (String, Any)*): Unit = { + val m = model.copy(ParamMap.empty) + paramsAndValues.foreach { case (paramName, paramValue) => + m.set(m.getParam(paramName), paramValue) + } + intercept[IllegalArgumentException] { + m match { + case t: Transformer => t.transform(dataset) + case e: Estimator[_] => e.fit(dataset) + } + } + } } From c22eaa94e85aaac649566495dcf763a5de3c8d06 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 26 Jan 2018 12:28:27 +0200 Subject: [PATCH 210/774] [SPARK-22797][PYSPARK] Bucketizer support multi-column ## What changes were proposed in this pull request? Bucketizer support multi-column in the python side ## How was this patch tested? existing tests and added tests Author: Zheng RuiFeng Closes #19892 from zhengruifeng/20542_py. --- python/pyspark/ml/feature.py | 105 +++++++++++++++++++++------- python/pyspark/ml/param/__init__.py | 10 +++ python/pyspark/ml/tests.py | 9 +++ 3 files changed, 99 insertions(+), 25 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index da85ba761a145..fdc7787140490 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -317,26 +317,33 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable) @inherit_doc -class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, - JavaMLReadable, JavaMLWritable): - """ - Maps a column of continuous features to a column of feature buckets. - - >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)] - >>> df = spark.createDataFrame(values, ["values"]) +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols, + HasHandleInvalid, JavaMLReadable, JavaMLWritable): + """ + Maps a column of continuous features to a column of feature buckets. Since 2.3.0, + :py:class:`Bucketizer` can map multiple columns at once by setting the :py:attr:`inputCols` + parameter. Note that when both the :py:attr:`inputCol` and :py:attr:`inputCols` parameters + are set, an Exception will be thrown. The :py:attr:`splits` parameter is only used for single + column usage, and :py:attr:`splitsArray` is for multiple columns. + + >>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")), + ... (float("nan"), 1.0), (float("nan"), 0.0)] + >>> df = spark.createDataFrame(values, ["values1", "values2"]) >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")], - ... inputCol="values", outputCol="buckets") - >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect() - >>> len(bucketed) - 6 - >>> bucketed[0].buckets - 0.0 - >>> bucketed[1].buckets - 0.0 - >>> bucketed[2].buckets - 1.0 - >>> bucketed[3].buckets - 2.0 + ... inputCol="values1", outputCol="buckets") + >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1")) + >>> bucketed.show(truncate=False) + +-------+-------+ + |values1|buckets| + +-------+-------+ + |0.1 |0.0 | + |0.4 |0.0 | + |1.2 |1.0 | + |1.5 |2.0 | + |NaN |3.0 | + |NaN |3.0 | + +-------+-------+ + ... >>> bucketizer.setParams(outputCol="b").transform(df).head().b 0.0 >>> bucketizerPath = temp_path + "/bucketizer" @@ -347,6 +354,22 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, >>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect() >>> len(bucketed) 4 + >>> bucketizer2 = Bucketizer(splitsArray= + ... [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]], + ... inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"]) + >>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df) + >>> bucketed2.show(truncate=False) + +-------+-------+--------+--------+ + |values1|values2|buckets1|buckets2| + +-------+-------+--------+--------+ + |0.1 |0.0 |0.0 |0.0 | + |0.4 |1.0 |0.0 |1.0 | + |1.2 |1.3 |1.0 |1.0 | + |1.5 |NaN |2.0 |2.0 | + |NaN |1.0 |3.0 |1.0 | + |NaN |0.0 |3.0 |0.0 | + +-------+-------+--------+--------+ + ... .. versionadded:: 1.4.0 """ @@ -363,14 +386,30 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " + "Options are 'skip' (filter out rows with invalid values), " + - "'error' (throw an error), or 'keep' (keep invalid values in a special " + - "additional bucket).", + "'error' (throw an error), or 'keep' (keep invalid values in a " + + "special additional bucket). Note that in the multiple column " + + "case, the invalid handling is applied to all columns. That said " + + "for 'error' it will throw an error if any invalids are found in " + + "any column, for 'skip' it will skip rows with any invalids in " + + "any columns, etc.", typeConverter=TypeConverters.toString) + splitsArray = Param(Params._dummy(), "splitsArray", "The array of split points for mapping " + + "continuous features into buckets for multiple columns. For each input " + + "column, with n+1 splits, there are n buckets. A bucket defined by " + + "splits x,y holds values in the range [x,y) except the last bucket, " + + "which also includes y. The splits should be of length >= 3 and " + + "strictly increasing. Values at -inf, inf must be explicitly provided " + + "to cover all Double values; otherwise, values outside the splits " + + "specified will be treated as errors.", + typeConverter=TypeConverters.toListListFloat) + @keyword_only - def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"): + def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", + splitsArray=None, inputCols=None, outputCols=None): """ - __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") + __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \ + splitsArray=None, inputCols=None, outputCols=None) """ super(Bucketizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) @@ -380,9 +419,11 @@ def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="er @keyword_only @since("1.4.0") - def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"): + def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", + splitsArray=None, inputCols=None, outputCols=None): """ - setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") + setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \ + splitsArray=None, inputCols=None, outputCols=None) Sets params for this Bucketizer. """ kwargs = self._input_kwargs @@ -402,6 +443,20 @@ def getSplits(self): """ return self.getOrDefault(self.splits) + @since("2.3.0") + def setSplitsArray(self, value): + """ + Sets the value of :py:attr:`splitsArray`. + """ + return self._set(splitsArray=value) + + @since("2.3.0") + def getSplitsArray(self): + """ + Gets the array of split points or its default value. + """ + return self.getOrDefault(self.splitsArray) + @inherit_doc class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 043c25cf9feb4..5b6b70292f099 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -134,6 +134,16 @@ def toListFloat(value): return [float(v) for v in value] raise TypeError("Could not convert %s to list of floats" % value) + @staticmethod + def toListListFloat(value): + """ + Convert a value to list of list of floats, if possible. + """ + if TypeConverters._can_convert_to_list(value): + value = TypeConverters.toList(value) + return [TypeConverters.toListFloat(v) for v in value] + raise TypeError("Could not convert %s to list of list of floats" % value) + @staticmethod def toListInt(value): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 1af2b91da900d..b8bddbd06f165 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -238,6 +238,15 @@ def test_bool(self): self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) + def test_list_list_float(self): + b = Bucketizer(splitsArray=[[-0.1, 0.5, 3], [-5, 1.5]]) + self.assertEqual(b.getSplitsArray(), [[-0.1, 0.5, 3.0], [-5.0, 1.5]]) + self.assertTrue(all([type(v) == list for v in b.getSplitsArray()])) + self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[0]])) + self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[1]])) + self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=["a", 1.0])) + self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=[[-5, 1.5], ["a", 1.0]])) + class PipelineTests(PySparkTestCase): From 3e252514741447004f3c18ddd77c617b4e37cfaa Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Fri, 26 Jan 2018 19:18:18 +0800 Subject: [PATCH 211/774] [SPARK-22068][CORE] Reduce the duplicate code between putIteratorAsValues and putIteratorAsBytes ## What changes were proposed in this pull request? The code logic between `MemoryStore.putIteratorAsValues` and `Memory.putIteratorAsBytes` are almost same, so we should reduce the duplicate code between them. ## How was this patch tested? Existing UT. Author: Xianyang Liu Closes #19285 from ConeyLiu/rmemorystore. --- .../spark/storage/memory/MemoryStore.scala | 336 ++++++++++-------- 1 file changed, 178 insertions(+), 158 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 17f7a69ad6ba1..4cc5bcb7f9baf 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -162,7 +162,7 @@ private[spark] class MemoryStore( } /** - * Attempt to put the given block in memory store as values. + * Attempt to put the given block in memory store as values or bytes. * * It's possible that the iterator is too large to materialize and store in memory. To avoid * OOM exceptions, this method will gradually unroll the iterator while periodically checking @@ -170,18 +170,24 @@ private[spark] class MemoryStore( * temporary unroll memory used during the materialization is "transferred" to storage memory, * so we won't acquire more memory than is actually needed to store the block. * - * @return in case of success, the estimated size of the stored data. In case of failure, return - * an iterator containing the values of the block. The returned iterator will be backed - * by the combination of the partially-unrolled block and the remaining elements of the - * original input iterator. The caller must either fully consume this iterator or call - * `close()` on it in order to free the storage memory consumed by the partially-unrolled - * block. + * @param blockId The block id. + * @param values The values which need be stored. + * @param classTag the [[ClassTag]] for the block. + * @param memoryMode The values saved memory mode(ON_HEAP or OFF_HEAP). + * @param valuesHolder A holder that supports storing record of values into memory store as + * values or bytes. + * @return if the block is stored successfully, return the stored data size. Else return the + * memory has reserved for unrolling the block (There are two reasons for store failed: + * First, the block is partially-unrolled; second, the block is entirely unrolled and + * the actual stored data size is larger than reserved, but we can't request extra + * memory). */ - private[storage] def putIteratorAsValues[T]( + private def putIterator[T]( blockId: BlockId, values: Iterator[T], - classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { - + classTag: ClassTag[T], + memoryMode: MemoryMode, + valuesHolder: ValuesHolder[T]): Either[Long, Long] = { require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") // Number of elements unrolled so far @@ -198,12 +204,10 @@ private[spark] class MemoryStore( val memoryGrowthFactor = conf.get(UNROLL_MEMORY_GROWTH_FACTOR) // Keep track of unroll memory used by this particular block / putIterator() operation var unrollMemoryUsedByThisBlock = 0L - // Underlying vector for unrolling the block - var vector = new SizeTrackingVector[T]()(classTag) // Request enough memory to begin unrolling keepUnrolling = - reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, MemoryMode.ON_HEAP) + reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, memoryMode) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -214,14 +218,14 @@ private[spark] class MemoryStore( // Unroll this block safely, checking whether we have exceeded our threshold periodically while (values.hasNext && keepUnrolling) { - vector += values.next() + valuesHolder.storeValue(values.next()) if (elementsUnrolled % memoryCheckPeriod == 0) { + val currentSize = valuesHolder.estimatedSize() // If our vector's size has exceeded the threshold, request more memory - val currentSize = vector.estimateSize() if (currentSize >= memoryThreshold) { val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong keepUnrolling = - reserveUnrollMemoryForThisTask(blockId, amountToRequest, MemoryMode.ON_HEAP) + reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) if (keepUnrolling) { unrollMemoryUsedByThisBlock += amountToRequest } @@ -232,78 +236,86 @@ private[spark] class MemoryStore( elementsUnrolled += 1 } + // Make sure that we have enough memory to store the block. By this point, it is possible that + // the block's actual memory usage has exceeded the unroll memory by a small amount, so we + // perform one final call to attempt to allocate additional memory if necessary. if (keepUnrolling) { - // We successfully unrolled the entirety of this block - val arrayValues = vector.toArray - vector = null - val entry = - new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag) - val size = entry.size - def transferUnrollToStorage(amount: Long): Unit = { + val entryBuilder = valuesHolder.getBuilder() + val size = entryBuilder.preciseSize + if (size > unrollMemoryUsedByThisBlock) { + val amountToRequest = size - unrollMemoryUsedByThisBlock + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) + if (keepUnrolling) { + unrollMemoryUsedByThisBlock += amountToRequest + } + } + + if (keepUnrolling) { + val entry = entryBuilder.build() // Synchronize so that transfer is atomic memoryManager.synchronized { - releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, amount) - val success = memoryManager.acquireStorageMemory(blockId, amount, MemoryMode.ON_HEAP) + releaseUnrollMemoryForThisTask(memoryMode, unrollMemoryUsedByThisBlock) + val success = memoryManager.acquireStorageMemory(blockId, entry.size, memoryMode) assert(success, "transferring unroll memory to storage memory failed") } - } - // Acquire storage memory if necessary to store this block in memory. - val enoughStorageMemory = { - if (unrollMemoryUsedByThisBlock <= size) { - val acquiredExtra = - memoryManager.acquireStorageMemory( - blockId, size - unrollMemoryUsedByThisBlock, MemoryMode.ON_HEAP) - if (acquiredExtra) { - transferUnrollToStorage(unrollMemoryUsedByThisBlock) - } - acquiredExtra - } else { // unrollMemoryUsedByThisBlock > size - // If this task attempt already owns more unroll memory than is necessary to store the - // block, then release the extra memory that will not be used. - val excessUnrollMemory = unrollMemoryUsedByThisBlock - size - releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, excessUnrollMemory) - transferUnrollToStorage(size) - true - } - } - if (enoughStorageMemory) { + entries.synchronized { entries.put(blockId, entry) } - logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) - Right(size) + + logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(blockId, + Utils.bytesToString(entry.size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) + Right(entry.size) } else { - assert(currentUnrollMemoryForThisTask >= unrollMemoryUsedByThisBlock, - "released too much unroll memory") + // We ran out of space while unrolling the values for this block + logUnrollFailureMessage(blockId, entryBuilder.preciseSize) + Left(unrollMemoryUsedByThisBlock) + } + } else { + // We ran out of space while unrolling the values for this block + logUnrollFailureMessage(blockId, valuesHolder.estimatedSize()) + Left(unrollMemoryUsedByThisBlock) + } + } + + /** + * Attempt to put the given block in memory store as values. + * + * @return in case of success, the estimated size of the stored data. In case of failure, return + * an iterator containing the values of the block. The returned iterator will be backed + * by the combination of the partially-unrolled block and the remaining elements of the + * original input iterator. The caller must either fully consume this iterator or call + * `close()` on it in order to free the storage memory consumed by the partially-unrolled + * block. + */ + private[storage] def putIteratorAsValues[T]( + blockId: BlockId, + values: Iterator[T], + classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { + + val valuesHolder = new DeserializedValuesHolder[T](classTag) + + putIterator(blockId, values, classTag, MemoryMode.ON_HEAP, valuesHolder) match { + case Right(storedSize) => Right(storedSize) + case Left(unrollMemoryUsedByThisBlock) => + val unrolledIterator = if (valuesHolder.vector != null) { + valuesHolder.vector.iterator + } else { + valuesHolder.arrayValues.toIterator + } + Left(new PartiallyUnrolledIterator( this, MemoryMode.ON_HEAP, unrollMemoryUsedByThisBlock, - unrolled = arrayValues.toIterator, - rest = Iterator.empty)) - } - } else { - // We ran out of space while unrolling the values for this block - logUnrollFailureMessage(blockId, vector.estimateSize()) - Left(new PartiallyUnrolledIterator( - this, - MemoryMode.ON_HEAP, - unrollMemoryUsedByThisBlock, - unrolled = vector.iterator, - rest = values)) + unrolled = unrolledIterator, + rest = values)) } } /** * Attempt to put the given block in memory store as bytes. * - * It's possible that the iterator is too large to materialize and store in memory. To avoid - * OOM exceptions, this method will gradually unroll the iterator while periodically checking - * whether there is enough free memory. If the block is successfully materialized, then the - * temporary unroll memory used during the materialization is "transferred" to storage memory, - * so we won't acquire more memory than is actually needed to store the block. - * * @return in case of success, the estimated size of the stored data. In case of failure, * return a handle which allows the caller to either finish the serialization by * spilling to disk or to deserialize the partially-serialized block and reconstruct @@ -319,25 +331,8 @@ private[spark] class MemoryStore( require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") - val allocator = memoryMode match { - case MemoryMode.ON_HEAP => ByteBuffer.allocate _ - case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ - } - - // Whether there is still enough memory for us to continue unrolling this block - var keepUnrolling = true - // Number of elements unrolled so far - var elementsUnrolled = 0L - // How often to check whether we need to request more memory - val memoryCheckPeriod = conf.get(UNROLL_MEMORY_CHECK_PERIOD) - // Memory to request as a multiple of current bbos size - val memoryGrowthFactor = conf.get(UNROLL_MEMORY_GROWTH_FACTOR) // Initial per-task memory to request for unrolling blocks (bytes). val initialMemoryThreshold = unrollMemoryThreshold - // Keep track of unroll memory used by this particular block / putIterator() operation - var unrollMemoryUsedByThisBlock = 0L - // Underlying buffer for unrolling the block - val redirectableStream = new RedirectableOutputStream val chunkSize = if (initialMemoryThreshold > Int.MaxValue) { logWarning(s"Initial memory threshold of ${Utils.bytesToString(initialMemoryThreshold)} " + s"is too large to be set as chunk size. Chunk size has been capped to " + @@ -346,85 +341,22 @@ private[spark] class MemoryStore( } else { initialMemoryThreshold.toInt } - val bbos = new ChunkedByteBufferOutputStream(chunkSize, allocator) - redirectableStream.setOutputStream(bbos) - val serializationStream: SerializationStream = { - val autoPick = !blockId.isInstanceOf[StreamBlockId] - val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() - ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) - } - // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, memoryMode) + val valuesHolder = new SerializedValuesHolder[T](blockId, chunkSize, classTag, + memoryMode, serializerManager) - if (!keepUnrolling) { - logWarning(s"Failed to reserve initial memory threshold of " + - s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.") - } else { - unrollMemoryUsedByThisBlock += initialMemoryThreshold - } - - def reserveAdditionalMemoryIfNecessary(): Unit = { - if (bbos.size > unrollMemoryUsedByThisBlock) { - val amountToRequest = (bbos.size * memoryGrowthFactor - unrollMemoryUsedByThisBlock).toLong - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) - if (keepUnrolling) { - unrollMemoryUsedByThisBlock += amountToRequest - } - } - } - - // Unroll this block safely, checking whether we have exceeded our threshold - while (values.hasNext && keepUnrolling) { - serializationStream.writeObject(values.next())(classTag) - elementsUnrolled += 1 - if (elementsUnrolled % memoryCheckPeriod == 0) { - reserveAdditionalMemoryIfNecessary() - } - } - - // Make sure that we have enough memory to store the block. By this point, it is possible that - // the block's actual memory usage has exceeded the unroll memory by a small amount, so we - // perform one final call to attempt to allocate additional memory if necessary. - if (keepUnrolling) { - serializationStream.close() - if (bbos.size > unrollMemoryUsedByThisBlock) { - val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) - if (keepUnrolling) { - unrollMemoryUsedByThisBlock += amountToRequest - } - } - } - - if (keepUnrolling) { - val entry = SerializedMemoryEntry[T](bbos.toChunkedByteBuffer, memoryMode, classTag) - // Synchronize so that transfer is atomic - memoryManager.synchronized { - releaseUnrollMemoryForThisTask(memoryMode, unrollMemoryUsedByThisBlock) - val success = memoryManager.acquireStorageMemory(blockId, entry.size, memoryMode) - assert(success, "transferring unroll memory to storage memory failed") - } - entries.synchronized { - entries.put(blockId, entry) - } - logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(entry.size), - Utils.bytesToString(maxMemory - blocksMemoryUsed))) - Right(entry.size) - } else { - // We ran out of space while unrolling the values for this block - logUnrollFailureMessage(blockId, bbos.size) - Left( - new PartiallySerializedBlock( + putIterator(blockId, values, classTag, memoryMode, valuesHolder) match { + case Right(storedSize) => Right(storedSize) + case Left(unrollMemoryUsedByThisBlock) => + Left(new PartiallySerializedBlock( this, serializerManager, blockId, - serializationStream, - redirectableStream, + valuesHolder.serializationStream, + valuesHolder.redirectableStream, unrollMemoryUsedByThisBlock, memoryMode, - bbos, + valuesHolder.bbos, values, classTag)) } @@ -702,6 +634,94 @@ private[spark] class MemoryStore( } } +private trait MemoryEntryBuilder[T] { + def preciseSize: Long + def build(): MemoryEntry[T] +} + +private trait ValuesHolder[T] { + def storeValue(value: T): Unit + def estimatedSize(): Long + + /** + * Note: After this method is called, the ValuesHolder is invalid, we can't store data and + * get estimate size again. + * @return a MemoryEntryBuilder which is used to build a memory entry and get the stored data + * size. + */ + def getBuilder(): MemoryEntryBuilder[T] +} + +/** + * A holder for storing the deserialized values. + */ +private class DeserializedValuesHolder[T] (classTag: ClassTag[T]) extends ValuesHolder[T] { + // Underlying vector for unrolling the block + var vector = new SizeTrackingVector[T]()(classTag) + var arrayValues: Array[T] = null + + override def storeValue(value: T): Unit = { + vector += value + } + + override def estimatedSize(): Long = { + vector.estimateSize() + } + + override def getBuilder(): MemoryEntryBuilder[T] = new MemoryEntryBuilder[T] { + // We successfully unrolled the entirety of this block + arrayValues = vector.toArray + vector = null + + override val preciseSize: Long = SizeEstimator.estimate(arrayValues) + + override def build(): MemoryEntry[T] = + DeserializedMemoryEntry[T](arrayValues, preciseSize, classTag) + } +} + +/** + * A holder for storing the serialized values. + */ +private class SerializedValuesHolder[T]( + blockId: BlockId, + chunkSize: Int, + classTag: ClassTag[T], + memoryMode: MemoryMode, + serializerManager: SerializerManager) extends ValuesHolder[T] { + val allocator = memoryMode match { + case MemoryMode.ON_HEAP => ByteBuffer.allocate _ + case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ + } + + val redirectableStream = new RedirectableOutputStream + val bbos = new ChunkedByteBufferOutputStream(chunkSize, allocator) + redirectableStream.setOutputStream(bbos) + val serializationStream: SerializationStream = { + val autoPick = !blockId.isInstanceOf[StreamBlockId] + val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() + ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) + } + + override def storeValue(value: T): Unit = { + serializationStream.writeObject(value)(classTag) + } + + override def estimatedSize(): Long = { + bbos.size + } + + override def getBuilder(): MemoryEntryBuilder[T] = new MemoryEntryBuilder[T] { + // We successfully unrolled the entirety of this block + serializationStream.close() + + override def preciseSize(): Long = bbos.size + + override def build(): MemoryEntry[T] = + SerializedMemoryEntry[T](bbos.toChunkedByteBuffer, memoryMode, classTag) + } +} + /** * The result of a failed [[MemoryStore.putIteratorAsValues()]] call. * From dd8e257d1ccf20f4383dd7f30d634010b176f0d3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 26 Jan 2018 09:17:05 -0800 Subject: [PATCH 212/774] [SPARK-23218][SQL] simplify ColumnVector.getArray ## What changes were proposed in this pull request? `ColumnVector` is very flexible about how to implement array type. As a result `ColumnVector` has 3 abstract methods for array type: `arrayData`, `getArrayOffset`, `getArrayLength`. For example, in `WritableColumnVector` we use the first child vector as the array data vector, and store offsets and lengths in 2 arrays in the parent vector. `ArrowColumnVector` has a different implementation. This PR simplifies `ColumnVector` by using only one abstract method for array type: `getArray`. ## How was this patch tested? existing tests. rerun `ColumnarBatchBenchmark`, there is no performance regression. Author: Wenchen Fan Closes #20395 from cloud-fan/vector. --- .../datasources/orc/OrcColumnVector.java | 13 +-- .../vectorized/WritableColumnVector.java | 13 ++- .../sql/vectorized/ArrowColumnVector.java | 48 ++++------ .../spark/sql/vectorized/ColumnVector.java | 88 ++++++++++--------- .../spark/sql/vectorized/ColumnarArray.java | 2 + .../spark/sql/vectorized/ColumnarBatch.java | 2 + .../spark/sql/vectorized/ColumnarRow.java | 2 + .../vectorized/ColumnarBatchBenchmark.scala | 14 ++- 8 files changed, 87 insertions(+), 95 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index aaf2a380034a9..5078bc7922ee2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -24,6 +24,7 @@ import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.unsafe.types.UTF8String; /** @@ -145,16 +146,6 @@ public double getDouble(int rowId) { return doubleData.vector[getRowIndex(rowId)]; } - @Override - public int getArrayLength(int rowId) { - throw new UnsupportedOperationException(); - } - - @Override - public int getArrayOffset(int rowId) { - throw new UnsupportedOperationException(); - } - @Override public Decimal getDecimal(int rowId, int precision, int scale) { BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); @@ -177,7 +168,7 @@ public byte[] getBinary(int rowId) { } @Override - public org.apache.spark.sql.vectorized.ColumnVector arrayData() { + public ColumnarArray getArray(int rowId) { throw new UnsupportedOperationException(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index ca4f00985c2a3..a8ec8ef2aadf8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -24,6 +24,7 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -602,7 +603,17 @@ public final int appendStruct(boolean isNull) { // `WritableColumnVector` puts the data of array in the first child column vector, and puts the // array offsets and lengths in the current column vector. @Override - public WritableColumnVector arrayData() { return childColumns[0]; } + public final ColumnarArray getArray(int rowId) { + return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); + } + + public WritableColumnVector arrayData() { + return childColumns[0]; + } + + public abstract int getArrayLength(int rowId); + + public abstract int getArrayOffset(int rowId); @Override public WritableColumnVector getChild(int ordinal) { return childColumns[ordinal]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index ca7a4751450d4..9803c3dec6de2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -17,17 +17,21 @@ package org.apache.spark.sql.vectorized; +import io.netty.buffer.ArrowBuf; import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; import org.apache.arrow.vector.holders.NullableVarCharHolder; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.arrow.ArrowUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.UTF8String; /** - * A column vector backed by Apache Arrow. + * A column vector backed by Apache Arrow. Currently time interval type and map type are not + * supported. */ +@InterfaceStability.Evolving public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; @@ -90,16 +94,6 @@ public double getDouble(int rowId) { return accessor.getDouble(rowId); } - @Override - public int getArrayLength(int rowId) { - return accessor.getArrayLength(rowId); - } - - @Override - public int getArrayOffset(int rowId) { - return accessor.getArrayOffset(rowId); - } - @Override public Decimal getDecimal(int rowId, int precision, int scale) { return accessor.getDecimal(rowId, precision, scale); @@ -116,7 +110,9 @@ public byte[] getBinary(int rowId) { } @Override - public ArrowColumnVector arrayData() { return childColumns[0]; } + public ColumnarArray getArray(int rowId) { + return accessor.getArray(rowId); + } @Override public ArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } @@ -151,9 +147,6 @@ public ArrowColumnVector(ValueVector vector) { } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); - - childColumns = new ArrowColumnVector[1]; - childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); } else if (vector instanceof NullableMapVector) { NullableMapVector mapVector = (NullableMapVector) vector; accessor = new StructAccessor(mapVector); @@ -180,10 +173,6 @@ boolean isNullAt(int rowId) { return vector.isNull(rowId); } - final int getValueCount() { - return vector.getValueCount(); - } - final int getNullCount() { return vector.getNullCount(); } @@ -232,11 +221,7 @@ byte[] getBinary(int rowId) { throw new UnsupportedOperationException(); } - int getArrayLength(int rowId) { - throw new UnsupportedOperationException(); - } - - int getArrayOffset(int rowId) { + ColumnarArray getArray(int rowId) { throw new UnsupportedOperationException(); } } @@ -433,10 +418,12 @@ final long getLong(int rowId) { private static class ArrayAccessor extends ArrowVectorAccessor { private final ListVector accessor; + private final ArrowColumnVector arrayData; ArrayAccessor(ListVector vector) { super(vector); this.accessor = vector; + this.arrayData = new ArrowColumnVector(vector.getDataVector()); } @Override @@ -450,13 +437,12 @@ final boolean isNullAt(int rowId) { } @Override - final int getArrayLength(int rowId) { - return accessor.getInnerValueCountAt(rowId); - } - - @Override - final int getArrayOffset(int rowId) { - return accessor.getOffsetBuffer().getInt(rowId * accessor.OFFSET_WIDTH); + final ColumnarArray getArray(int rowId) { + ArrowBuf offsets = accessor.getOffsetBuffer(); + int index = rowId * accessor.OFFSET_WIDTH; + int start = offsets.getInt(index); + int end = offsets.getInt(index + accessor.OFFSET_WIDTH); + return new ColumnarArray(arrayData, start, end - start); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index f9936214035b6..4b955ceddd0f2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.vectorized; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; @@ -29,11 +30,14 @@ * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values * in this ColumnVector. * + * Spark only calls specific `get` method according to the data type of this {@link ColumnVector}, + * e.g. if it's int type, Spark is guaranteed to only call {@link #getInt(int)} or + * {@link #getInts(int, int)}. + * * ColumnVector supports all the data types including nested types. To handle nested types, - * ColumnVector can have children and is a tree structure. For struct type, it stores the actual - * data of each field in the corresponding child ColumnVector, and only stores null information in - * the parent ColumnVector. For array type, it stores the actual array elements in the child - * ColumnVector, and stores null information, array offsets and lengths in the parent ColumnVector. + * ColumnVector can have children and is a tree structure. Please refer to {@link #getStruct(int)}, + * {@link #getArray(int)} and {@link #getMap(int)} for the details about how to implement nested + * types. * * ColumnVector is expected to be reused during the entire data loading process, to avoid allocating * memory again and again. @@ -43,6 +47,7 @@ * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage * footprint is negligible. */ +@InterfaceStability.Evolving public abstract class ColumnVector implements AutoCloseable { /** @@ -70,12 +75,12 @@ public abstract class ColumnVector implements AutoCloseable { public abstract boolean isNullAt(int rowId); /** - * Returns the value for rowId. + * Returns the boolean type value for rowId. */ public abstract boolean getBoolean(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets boolean type values from [rowId, rowId + count) */ public boolean[] getBooleans(int rowId, int count) { boolean[] res = new boolean[count]; @@ -86,12 +91,12 @@ public boolean[] getBooleans(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the byte type value for rowId. */ public abstract byte getByte(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets byte type values from [rowId, rowId + count) */ public byte[] getBytes(int rowId, int count) { byte[] res = new byte[count]; @@ -102,12 +107,12 @@ public byte[] getBytes(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the short type value for rowId. */ public abstract short getShort(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets short type values from [rowId, rowId + count) */ public short[] getShorts(int rowId, int count) { short[] res = new short[count]; @@ -118,12 +123,12 @@ public short[] getShorts(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the int type value for rowId. */ public abstract int getInt(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets int type values from [rowId, rowId + count) */ public int[] getInts(int rowId, int count) { int[] res = new int[count]; @@ -134,12 +139,12 @@ public int[] getInts(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the long type value for rowId. */ public abstract long getLong(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets long type values from [rowId, rowId + count) */ public long[] getLongs(int rowId, int count) { long[] res = new long[count]; @@ -150,12 +155,12 @@ public long[] getLongs(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the float type value for rowId. */ public abstract float getFloat(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets float type values from [rowId, rowId + count) */ public float[] getFloats(int rowId, int count) { float[] res = new float[count]; @@ -166,12 +171,12 @@ public float[] getFloats(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the double type value for rowId. */ public abstract double getDouble(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets double type values from [rowId, rowId + count) */ public double[] getDoubles(int rowId, int count) { double[] res = new double[count]; @@ -182,57 +187,54 @@ public double[] getDoubles(int rowId, int count) { } /** - * Returns the length of the array for rowId. - */ - public abstract int getArrayLength(int rowId); - - /** - * Returns the offset of the array for rowId. - */ - public abstract int getArrayOffset(int rowId); - - /** - * Returns the struct for rowId. + * Returns the struct type value for rowId. + * + * To support struct type, implementations must implement {@link #getChild(int)} and make this + * vector a tree structure. The number of child vectors must be same as the number of fields of + * the struct type, and each child vector is responsible to store the data for its corresponding + * struct field. */ public final ColumnarRow getStruct(int rowId) { return new ColumnarRow(this, rowId); } /** - * Returns the array for rowId. + * Returns the array type value for rowId. + * + * To support array type, implementations must construct an {@link ColumnarArray} and return it in + * this method. {@link ColumnarArray} requires a {@link ColumnVector} that stores the data of all + * the elements of all the arrays in this vector, and an offset and length which points to a range + * in that {@link ColumnVector}, and the range represents the array for rowId. Implementations + * are free to decide where to put the data vector and offsets and lengths. For example, we can + * use the first child vector as the data vector, and store offsets and lengths in 2 int arrays in + * this vector. */ - public final ColumnarArray getArray(int rowId) { - return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); - } + public abstract ColumnarArray getArray(int rowId); /** - * Returns the map for rowId. + * Returns the map type value for rowId. */ public MapData getMap(int ordinal) { throw new UnsupportedOperationException(); } /** - * Returns the decimal for rowId. + * Returns the decimal type value for rowId. */ public abstract Decimal getDecimal(int rowId, int precision, int scale); /** - * Returns the UTF8String for rowId. Note that the returned UTF8String may point to the data of - * this column vector, please copy it if you want to keep it after this column vector is freed. + * Returns the string type value for rowId. Note that the returned UTF8String may point to the + * data of this column vector, please copy it if you want to keep it after this column vector is + * freed. */ public abstract UTF8String getUTF8String(int rowId); /** - * Returns the byte array for rowId. + * Returns the binary type value for rowId. */ public abstract byte[] getBinary(int rowId); - /** - * Returns the data for the underlying array. - */ - public abstract ColumnVector arrayData(); - /** * Returns the ordinal's child column vector. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 522c39580389f..0d2c3ec8648d3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.vectorized; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; @@ -25,6 +26,7 @@ /** * Array abstraction in {@link ColumnVector}. */ +@InterfaceStability.Evolving public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from // data[offset] to data[offset + length). diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index 4dc826cf60c15..d206c1df42abb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -18,6 +18,7 @@ import java.util.*; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; @@ -26,6 +27,7 @@ * batch so that Spark can access the data row by row. Instance of it is meant to be reused during * the entire data loading process. */ +@InterfaceStability.Evolving public final class ColumnarBatch { private int numRows; private final ColumnVector[] columns; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 2e59085a82768..25db7e09d20d0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.vectorized; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.MapData; @@ -26,6 +27,7 @@ /** * Row abstraction in {@link ColumnVector}. */ +@InterfaceStability.Evolving public final class ColumnarRow extends InternalRow { // The data for this row. // E.g. the value of 3rd int field is `data.getChild(3).getInt(rowId)`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index ad74fb99b0c73..1f31aa45a1220 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.parquet +package org.apache.spark.sql.execution.vectorized import java.nio.ByteBuffer import java.nio.charset.StandardCharsets @@ -23,8 +23,6 @@ import scala.util.Random import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType} import org.apache.spark.unsafe.Platform import org.apache.spark.util.Benchmark @@ -434,7 +432,6 @@ object ColumnarBatchBenchmark { } def readArrays(onHeap: Boolean): Unit = { - System.gc() val vector = if (onHeap) onHeapVector else offHeapVector var sum = 0L @@ -448,7 +445,6 @@ object ColumnarBatchBenchmark { } def readArrayElements(onHeap: Boolean): Unit = { - System.gc() val vector = if (onHeap) onHeapVector else offHeapVector var sum = 0L @@ -479,10 +475,10 @@ object ColumnarBatchBenchmark { Array Vector Read: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - On Heap Read Size Only 416 / 423 393.5 2.5 1.0X - Off Heap Read Size Only 396 / 404 413.6 2.4 1.1X - On Heap Read Elements 2569 / 2590 63.8 15.7 0.2X - Off Heap Read Elements 3302 / 3333 49.6 20.2 0.1X + On Heap Read Size Only 426 / 437 384.9 2.6 1.0X + Off Heap Read Size Only 406 / 421 404.0 2.5 1.0X + On Heap Read Elements 2636 / 2642 62.2 16.1 0.2X + Off Heap Read Elements 3770 / 3774 43.5 23.0 0.1X */ benchmark.run } From a8a3e9b7cf7b9346c43cfbbf7b26fd2fd28dd521 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Fri, 26 Jan 2018 23:48:02 +0200 Subject: [PATCH 213/774] Revert "[SPARK-22797][PYSPARK] Bucketizer support multi-column" This reverts commit c22eaa94e85aaac649566495dcf763a5de3c8d06. --- python/pyspark/ml/feature.py | 105 +++++++--------------------- python/pyspark/ml/param/__init__.py | 10 --- python/pyspark/ml/tests.py | 9 --- 3 files changed, 25 insertions(+), 99 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index fdc7787140490..da85ba761a145 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -317,33 +317,26 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable) @inherit_doc -class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols, - HasHandleInvalid, JavaMLReadable, JavaMLWritable): - """ - Maps a column of continuous features to a column of feature buckets. Since 2.3.0, - :py:class:`Bucketizer` can map multiple columns at once by setting the :py:attr:`inputCols` - parameter. Note that when both the :py:attr:`inputCol` and :py:attr:`inputCols` parameters - are set, an Exception will be thrown. The :py:attr:`splits` parameter is only used for single - column usage, and :py:attr:`splitsArray` is for multiple columns. - - >>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")), - ... (float("nan"), 1.0), (float("nan"), 0.0)] - >>> df = spark.createDataFrame(values, ["values1", "values2"]) +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, + JavaMLReadable, JavaMLWritable): + """ + Maps a column of continuous features to a column of feature buckets. + + >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)] + >>> df = spark.createDataFrame(values, ["values"]) >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")], - ... inputCol="values1", outputCol="buckets") - >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1")) - >>> bucketed.show(truncate=False) - +-------+-------+ - |values1|buckets| - +-------+-------+ - |0.1 |0.0 | - |0.4 |0.0 | - |1.2 |1.0 | - |1.5 |2.0 | - |NaN |3.0 | - |NaN |3.0 | - +-------+-------+ - ... + ... inputCol="values", outputCol="buckets") + >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect() + >>> len(bucketed) + 6 + >>> bucketed[0].buckets + 0.0 + >>> bucketed[1].buckets + 0.0 + >>> bucketed[2].buckets + 1.0 + >>> bucketed[3].buckets + 2.0 >>> bucketizer.setParams(outputCol="b").transform(df).head().b 0.0 >>> bucketizerPath = temp_path + "/bucketizer" @@ -354,22 +347,6 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOu >>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect() >>> len(bucketed) 4 - >>> bucketizer2 = Bucketizer(splitsArray= - ... [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]], - ... inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"]) - >>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df) - >>> bucketed2.show(truncate=False) - +-------+-------+--------+--------+ - |values1|values2|buckets1|buckets2| - +-------+-------+--------+--------+ - |0.1 |0.0 |0.0 |0.0 | - |0.4 |1.0 |0.0 |1.0 | - |1.2 |1.3 |1.0 |1.0 | - |1.5 |NaN |2.0 |2.0 | - |NaN |1.0 |3.0 |1.0 | - |NaN |0.0 |3.0 |0.0 | - +-------+-------+--------+--------+ - ... .. versionadded:: 1.4.0 """ @@ -386,30 +363,14 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOu handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " + "Options are 'skip' (filter out rows with invalid values), " + - "'error' (throw an error), or 'keep' (keep invalid values in a " + - "special additional bucket). Note that in the multiple column " + - "case, the invalid handling is applied to all columns. That said " + - "for 'error' it will throw an error if any invalids are found in " + - "any column, for 'skip' it will skip rows with any invalids in " + - "any columns, etc.", + "'error' (throw an error), or 'keep' (keep invalid values in a special " + + "additional bucket).", typeConverter=TypeConverters.toString) - splitsArray = Param(Params._dummy(), "splitsArray", "The array of split points for mapping " + - "continuous features into buckets for multiple columns. For each input " + - "column, with n+1 splits, there are n buckets. A bucket defined by " + - "splits x,y holds values in the range [x,y) except the last bucket, " + - "which also includes y. The splits should be of length >= 3 and " + - "strictly increasing. Values at -inf, inf must be explicitly provided " + - "to cover all Double values; otherwise, values outside the splits " + - "specified will be treated as errors.", - typeConverter=TypeConverters.toListListFloat) - @keyword_only - def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", - splitsArray=None, inputCols=None, outputCols=None): + def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"): """ - __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \ - splitsArray=None, inputCols=None, outputCols=None) + __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") """ super(Bucketizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) @@ -419,11 +380,9 @@ def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="er @keyword_only @since("1.4.0") - def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", - splitsArray=None, inputCols=None, outputCols=None): + def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"): """ - setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \ - splitsArray=None, inputCols=None, outputCols=None) + setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this Bucketizer. """ kwargs = self._input_kwargs @@ -443,20 +402,6 @@ def getSplits(self): """ return self.getOrDefault(self.splits) - @since("2.3.0") - def setSplitsArray(self, value): - """ - Sets the value of :py:attr:`splitsArray`. - """ - return self._set(splitsArray=value) - - @since("2.3.0") - def getSplitsArray(self): - """ - Gets the array of split points or its default value. - """ - return self.getOrDefault(self.splitsArray) - @inherit_doc class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 5b6b70292f099..043c25cf9feb4 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -134,16 +134,6 @@ def toListFloat(value): return [float(v) for v in value] raise TypeError("Could not convert %s to list of floats" % value) - @staticmethod - def toListListFloat(value): - """ - Convert a value to list of list of floats, if possible. - """ - if TypeConverters._can_convert_to_list(value): - value = TypeConverters.toList(value) - return [TypeConverters.toListFloat(v) for v in value] - raise TypeError("Could not convert %s to list of list of floats" % value) - @staticmethod def toListInt(value): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index b8bddbd06f165..1af2b91da900d 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -238,15 +238,6 @@ def test_bool(self): self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) - def test_list_list_float(self): - b = Bucketizer(splitsArray=[[-0.1, 0.5, 3], [-5, 1.5]]) - self.assertEqual(b.getSplitsArray(), [[-0.1, 0.5, 3.0], [-5.0, 1.5]]) - self.assertTrue(all([type(v) == list for v in b.getSplitsArray()])) - self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[0]])) - self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[1]])) - self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=["a", 1.0])) - self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=[[-5, 1.5], ["a", 1.0]])) - class PipelineTests(PySparkTestCase): From 94c67a76ec1fda908a671a47a2a1fa63b3ab1b06 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 26 Jan 2018 15:01:03 -0800 Subject: [PATCH 214/774] [SPARK-23207][SQL] Shuffle+Repartition on a DataFrame could lead to incorrect answers ## What changes were proposed in this pull request? Currently shuffle repartition uses RoundRobinPartitioning, the generated result is nondeterministic since the sequence of input rows are not determined. The bug can be triggered when there is a repartition call following a shuffle (which would lead to non-deterministic row ordering), as the pattern shows below: upstream stage -> repartition stage -> result stage (-> indicate a shuffle) When one of the executors process goes down, some tasks on the repartition stage will be retried and generate inconsistent ordering, and some tasks of the result stage will be retried generating different data. The following code returns 931532, instead of 1000000: ``` import scala.sys.process._ import org.apache.spark.TaskContext val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x => x }.repartition(200).map { x => if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) { throw new Exception("pkill -f java".!!) } x } res.distinct().count() ``` In this PR, we propose a most straight-forward way to fix this problem by performing a local sort before partitioning, after we make the input row ordering deterministic, the function from rows to partitions is fully deterministic too. The downside of the approach is that with extra local sort inserted, the performance of repartition() will go down, so we add a new config named `spark.sql.execution.sortBeforeRepartition` to control whether this patch is applied. The patch is default enabled to be safe-by-default, but user may choose to manually turn it off to avoid performance regression. This patch also changes the output rows ordering of repartition(), that leads to a bunch of test cases failure because they are comparing the results directly. ## How was this patch tested? Add unit test in ExchangeSuite. With this patch(and `spark.sql.execution.sortBeforeRepartition` set to true), the following query returns 1000000: ``` import scala.sys.process._ import org.apache.spark.TaskContext spark.conf.set("spark.sql.execution.sortBeforeRepartition", "true") val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x => x }.repartition(200).map { x => if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) { throw new Exception("pkill -f java".!!) } x } res.distinct().count() res7: Long = 1000000 ``` Author: Xingbo Jiang Closes #20393 from jiangxb1987/shuffle-repartition. --- .../unsafe/sort/RecordComparator.java | 4 +- .../unsafe/sort/UnsafeInMemorySorter.java | 7 +- .../unsafe/sort/UnsafeSorterSpillMerger.java | 4 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 2 + .../sort/UnsafeExternalSorterSuite.java | 4 +- .../sort/UnsafeInMemorySorterSuite.java | 8 ++- .../spark/ml/feature/Word2VecSuite.scala | 3 +- .../sql/execution/RecordBinaryComparator.java | 70 +++++++++++++++++++ .../execution/UnsafeExternalRowSorter.java | 44 ++++++++++-- .../apache/spark/sql/internal/SQLConf.scala | 14 ++++ .../sql/execution/UnsafeKVExternalSorter.java | 8 ++- .../apache/spark/sql/execution/SortExec.scala | 2 +- .../exchange/ShuffleExchangeExec.scala | 52 +++++++++++++- .../spark/sql/execution/ExchangeSuite.scala | 26 ++++++- .../datasources/parquet/ParquetIOSuite.scala | 6 +- .../datasources/text/WholeTextFileSuite.scala | 2 +- .../streaming/ForeachSinkSuite.scala | 6 +- 17 files changed, 233 insertions(+), 29 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java index 09e4258792204..02b5de8e128c9 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java @@ -32,6 +32,8 @@ public abstract class RecordComparator { public abstract int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset); + long rightBaseOffset, + int rightBaseLength); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 951d076420ee6..b3c27d83da172 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -62,12 +62,13 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { int uaoSize = UnsafeAlignedOffset.getUaoSize(); if (prefixComparisonResult == 0) { final Object baseObject1 = memoryManager.getPage(r1.recordPointer); - // skip length final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + uaoSize; + final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize); final Object baseObject2 = memoryManager.getPage(r2.recordPointer); - // skip length final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + uaoSize; - return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2); + final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize); + return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2, + baseOffset2, baseLength2); } else { return prefixComparisonResult; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index cf4dfde86ca91..ff0dcc259a4ad 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -35,8 +35,8 @@ final class UnsafeSorterSpillMerger { prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); if (prefixComparisonResult == 0) { return recordComparator.compare( - left.getBaseObject(), left.getBaseOffset(), - right.getBaseObject(), right.getBaseOffset()); + left.getBaseObject(), left.getBaseOffset(), left.getRecordLength(), + right.getBaseObject(), right.getBaseOffset(), right.getRecordLength()); } else { return prefixComparisonResult; } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 7859781e98223..0574abdca32ac 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -414,6 +414,8 @@ abstract class RDD[T: ClassTag]( * * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, * which can avoid performing a shuffle. + * + * TODO Fix the Shuffle+Repartition data loss issue described in SPARK-23207. */ def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope { coalesce(numPartitions, shuffle = true) diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index af4975c888d65..411cd5cb57331 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -72,8 +72,10 @@ public class UnsafeExternalSorterSuite { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 594f07dd780f9..c145532328514 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -98,8 +98,10 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; @@ -164,8 +166,10 @@ public void freeAfterOOM() { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 6183606a7b2ac..10682ba176aca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -222,7 +222,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val oldModel = new OldWord2VecModel(word2VecMap) val instance = new Word2VecModel("myWord2VecModel", oldModel) val newInstance = testDefaultReadWrite(instance) - assert(newInstance.getVectors.collect() === instance.getVectors.collect()) + assert(newInstance.getVectors.collect().sortBy(_.getString(0)) === + instance.getVectors.collect().sortBy(_.getString(0))) } test("Word2Vec works with input that is non-nullable (NGram)") { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java new file mode 100644 index 0000000000000..bb77b5bf6de2a --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.collection.unsafe.sort.RecordComparator; + +public final class RecordBinaryComparator extends RecordComparator { + + // TODO(jiangxb) Add test suite for this. + @Override + public int compare( + Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { + int i = 0; + int res = 0; + + // If the arrays have different length, the longer one is larger. + if (leftLen != rightLen) { + return leftLen - rightLen; + } + + // The following logic uses `leftLen` as the length for both `leftObj` and `rightObj`, since + // we have guaranteed `leftLen` == `rightLen`. + + // check if stars align and we can get both offsets to be aligned + if ((leftOff % 8) == (rightOff % 8)) { + while ((leftOff + i) % 8 != 0 && i < leftLen) { + res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - + (Platform.getByte(rightObj, rightOff + i) & 0xff); + if (res != 0) return res; + i += 1; + } + } + // for architectures that support unaligned accesses, chew it up 8 bytes at a time + if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { + while (i <= leftLen - 8) { + res = (int) ((Platform.getLong(leftObj, leftOff + i) - + Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE); + if (res != 0) return res; + i += 8; + } + } + // this will finish off the unaligned comparisons, or do the entire aligned comparison + // whichever is needed. + while (i < leftLen) { + res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - + (Platform.getByte(rightObj, rightOff + i) & 0xff); + if (res != 0) return res; + i += 1; + } + + // The two arrays are equal. + return 0; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 6b002f0d3f8e8..78647b56d621f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -18,7 +18,9 @@ package org.apache.spark.sql.execution; import java.io.IOException; +import java.util.function.Supplier; +import org.apache.spark.sql.catalyst.util.TypeUtils; import scala.collection.AbstractIterator; import scala.collection.Iterator; import scala.math.Ordering; @@ -56,26 +58,50 @@ public abstract static class PrefixComputer { public static class Prefix { /** Key prefix value, or the null prefix value if isNull = true. **/ - long value; + public long value; /** Whether the key is null. */ - boolean isNull; + public boolean isNull; } /** * Computes prefix for the given row. For efficiency, the returned object may be reused in * further calls to a given PrefixComputer. */ - abstract Prefix computePrefix(InternalRow row); + public abstract Prefix computePrefix(InternalRow row); } - public UnsafeExternalRowSorter( + public static UnsafeExternalRowSorter createWithRecordComparator( + StructType schema, + Supplier recordComparatorSupplier, + PrefixComparator prefixComparator, + PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canUseRadixSort) throws IOException { + return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, + prefixComputer, pageSizeBytes, canUseRadixSort); + } + + public static UnsafeExternalRowSorter create( StructType schema, Ordering ordering, PrefixComparator prefixComparator, PrefixComputer prefixComputer, long pageSizeBytes, boolean canUseRadixSort) throws IOException { + Supplier recordComparatorSupplier = + () -> new RowComparator(ordering, schema.length()); + return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, + prefixComputer, pageSizeBytes, canUseRadixSort); + } + + private UnsafeExternalRowSorter( + StructType schema, + Supplier recordComparatorSupplier, + PrefixComparator prefixComparator, + PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canUseRadixSort) throws IOException { this.schema = schema; this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); @@ -85,7 +111,7 @@ public UnsafeExternalRowSorter( sparkEnv.blockManager(), sparkEnv.serializerManager(), taskContext, - () -> new RowComparator(ordering, schema.length()), + recordComparatorSupplier, prefixComparator, sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize", DEFAULT_INITIAL_SORT_BUFFER_SIZE), @@ -206,7 +232,13 @@ private static final class RowComparator extends RecordComparator { } @Override - public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + public int compare( + Object baseObj1, + long baseOff1, + int baseLen1, + Object baseObj2, + long baseOff2, + int baseLen2) { // Note that since ordering doesn't need the total length of the record, we just pass 0 // into the row. row1.pointTo(baseObj1, baseOff1, 0); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b0d18b6dced76..76b9d6f6f33bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1145,6 +1145,18 @@ object SQLConf { .checkValues(PartitionOverwriteMode.values.map(_.toString)) .createWithDefault(PartitionOverwriteMode.STATIC.toString) + val SORT_BEFORE_REPARTITION = + buildConf("spark.sql.execution.sortBeforeRepartition") + .internal() + .doc("When perform a repartition following a shuffle, the output row ordering would be " + + "nondeterministic. If some downstream stages fail and some tasks of the repartition " + + "stage retry, these tasks may generate different data, and that can lead to correctness " + + "issues. Turn on this config to insert a local sort before actually doing repartition " + + "to generate consistent repartition results. The performance of repartition() may go " + + "down since we insert extra local sort before it.") + .booleanConf + .createWithDefault(true) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1300,6 +1312,8 @@ class SQLConf extends Serializable with Logging { def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) + def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index eb2fe82007af3..b0b5383a081a0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -241,7 +241,13 @@ private static final class KVComparator extends RecordComparator { } @Override - public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + public int compare( + Object baseObj1, + long baseOff1, + int baseLen1, + Object baseObj2, + long baseOff2, + int baseLen2) { // Note that since ordering doesn't need the total length of the record, we just pass 0 // into the row. row1.pointTo(baseObj1, baseOff1 + 4, 0); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index ef1bb1c2a4468..ac1c34d41c4f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -84,7 +84,7 @@ case class SortExec( } val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - val sorter = new UnsafeExternalRowSorter( + val sorter = UnsafeExternalRowSorter.create( schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort) if (testSpillFrequency > 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 5a1e217082bc2..76c1fa65f924b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.exchange import java.util.Random +import java.util.function.Supplier import org.apache.spark._ import org.apache.spark.rdd.RDD @@ -25,13 +26,15 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType import org.apache.spark.util.MutablePair +import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} /** * Performs a shuffle that will result in the desired `newPartitioning`. @@ -247,14 +250,57 @@ object ShuffleExchangeExec { case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") } + val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { - if (needToCopyObjectsBeforeShuffle(part, serializer)) { + // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic, + // otherwise a retry task may output different rows and thus lead to data loss. + // + // Currently we following the most straight-forward way that perform a local sort before + // partitioning. + val newRdd = if (SQLConf.get.sortBeforeRepartition && + newPartitioning.isInstanceOf[RoundRobinPartitioning]) { rdd.mapPartitionsInternal { iter => + val recordComparatorSupplier = new Supplier[RecordComparator] { + override def get: RecordComparator = new RecordBinaryComparator() + } + // The comparator for comparing row hashcode, which should always be Integer. + val prefixComparator = PrefixComparators.LONG + val canUseRadixSort = SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED) + // The prefix computer generates row hashcode as the prefix, so we may decrease the + // probability that the prefixes are equal when input rows choose column values from a + // limited range. + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix + override def computePrefix(row: InternalRow): + UnsafeExternalRowSorter.PrefixComputer.Prefix = { + // The hashcode generated from the binary form of a [[UnsafeRow]] should not be null. + result.isNull = false + result.value = row.hashCode() + result + } + } + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes + + val sorter = UnsafeExternalRowSorter.createWithRecordComparator( + StructType.fromAttributes(outputAttributes), + recordComparatorSupplier, + prefixComparator, + prefixComputer, + pageSize, + canUseRadixSort) + sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + } + } else { + rdd + } + + if (needToCopyObjectsBeforeShuffle(part, serializer)) { + newRdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } } } else { - rdd.mapPartitionsInternal { iter => + newRdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() val mutablePair = new MutablePair[Int, InternalRow]() iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index aac8d56ba6201..697d7e6520713 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import scala.util.Random + +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { @@ -101,4 +104,25 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(exchange4.sameResult(exchange5)) assert(exchange5 sameResult exchange4) } + + test("SPARK-23207: Make repartition() generate consistent output") { + def assertConsistency(ds: Dataset[java.lang.Long]): Unit = { + ds.persist() + + val exchange = ds.mapPartitions { iter => + Random.shuffle(iter) + }.repartition(111) + val exchange2 = ds.repartition(111) + + assert(exchange.rdd.collectPartitions() === exchange2.rdd.collectPartitions()) + } + + withSQLConf(SQLConf.SORT_BEFORE_REPARTITION.key -> "true") { + // repartition() should generate consistent output. + assertConsistency(spark.range(10000)) + + // case when input contains duplicated rows. + assertConsistency(spark.range(10000).map(i => Random.nextInt(1000).toLong)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 44a8b25c61dfb..f3ece5b15e26a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -662,7 +662,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val v = (row.getInt(0), row.getString(1)) result += v } - assert(data == result) + assert(data.toSet == result.toSet) } finally { reader.close() } @@ -678,7 +678,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val row = reader.getCurrentValue.asInstanceOf[InternalRow] result += row.getString(0) } - assert(data.map(_._2) == result) + assert(data.map(_._2).toSet == result.toSet) } finally { reader.close() } @@ -695,7 +695,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val v = (row.getString(0), row.getInt(1)) result += v } - assert(data.map { x => (x._2, x._1) } == result) + assert(data.map { x => (x._2, x._1) }.toSet == result.toSet) } finally { reader.close() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala index 8bd736bee69de..fff0f82f9bc2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala @@ -95,7 +95,7 @@ class WholeTextFileSuite extends QueryTest with SharedSQLContext { df1.write.option("compression", "gzip").mode("overwrite").text(path) // On reading through wholetext mode, one file will be read as a single row, i.e. not // delimited by "next line" character. - val expected = Row(Range(0, 1000).mkString("", "\n", "\n")) + val expected = Row(df1.collect().map(_.getString(0)).mkString("", "\n", "\n")) Seq(10, 100, 1000).foreach { bytes => withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> bytes.toString) { val df2 = spark.read.option("wholetext", "true").format("text").load(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 9137d650e906b..1248c670df45c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -52,13 +52,13 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf var expectedEventsForPartition0 = Seq( ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 1), + ForeachSinkSuite.Process(value = 2), ForeachSinkSuite.Process(value = 3), ForeachSinkSuite.Close(None) ) var expectedEventsForPartition1 = Seq( ForeachSinkSuite.Open(partition = 1, version = 0), - ForeachSinkSuite.Process(value = 2), + ForeachSinkSuite.Process(value = 1), ForeachSinkSuite.Process(value = 4), ForeachSinkSuite.Close(None) ) @@ -162,7 +162,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf val allEvents = ForeachSinkSuite.allEvents() assert(allEvents.size === 1) assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 2)) // `close` should be called with the error val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] From 073744985f439ca90afb9bd0bbc1332c53f7b4bb Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 26 Jan 2018 16:09:57 -0800 Subject: [PATCH 215/774] [SPARK-23242][SS][TESTS] Don't run tests in KafkaSourceSuiteBase twice ## What changes were proposed in this pull request? KafkaSourceSuiteBase should be abstract class, otherwise KafkaSourceSuiteBase will also run. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #20412 from zsxwing/SPARK-23242. --- .../scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 27dbb3f7a8f31..c4cb1bc4a2e18 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -546,7 +546,7 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { } } -class KafkaSourceSuiteBase extends KafkaSourceTest { +abstract class KafkaSourceSuiteBase extends KafkaSourceTest { import testImplicits._ From 5b5447c68ac79715e2256e487e1212861cdab1fc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 26 Jan 2018 16:46:51 -0800 Subject: [PATCH 216/774] [SPARK-23214][SQL] cached data should not carry extra hint info ## What changes were proposed in this pull request? This is a regression introduced by https://github.com/apache/spark/pull/19864 When we lookup cache, we should not carry the hint info, as this cache entry might be added by a plan having hint info, while the input plan for this lookup may not have hint info, or have different hint info. ## How was this patch tested? a new test. Author: Wenchen Fan Closes #20394 from cloud-fan/cache. --- .../spark/sql/execution/CacheManager.scala | 17 +-- .../execution/columnar/InMemoryRelation.scala | 27 +++-- .../apache/spark/sql/CachedTableSuite.scala | 4 +- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../execution/joins/BroadcastJoinSuite.scala | 103 +++++++++++------- 5 files changed, 94 insertions(+), 59 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 432eb59d6fe57..d68aeb275afda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -169,14 +169,17 @@ class CacheManager extends Logging { /** Replaces segments of the given logical plan with cached versions where possible. */ def useCachedData(plan: LogicalPlan): LogicalPlan = { val newPlan = plan transformDown { + // Do not lookup the cache by hint node. Hint node is special, we should ignore it when + // canonicalizing plans, so that plans which are same except hint can hit the same cache. + // However, we also want to keep the hint info after cache lookup. Here we skip the hint + // node, so that the returned caching plan won't replace the hint node and drop the hint info + // from the original plan. + case hint: ResolvedHint => hint + case currentFragment => - lookupCachedData(currentFragment).map { cached => - val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output) - currentFragment match { - case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints) - case _ => cachedPlan - } - }.getOrElse(currentFragment) + lookupCachedData(currentFragment) + .map(_.cachedRepresentation.withOutput(currentFragment.output)) + .getOrElse(currentFragment) } newPlan transformAllExpressions { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 51928d914841e..22e16913d4da9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -62,8 +62,8 @@ case class InMemoryRelation( @transient child: SparkPlan, tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics = null) + val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, + statsOfPlanToCache: Statistics) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -73,11 +73,16 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) override def computeStats(): Statistics = { - if (batchStats.value == 0L) { - // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache - statsOfPlanToCache + if (sizeInBytesStats.value == 0L) { + // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. + // Note that we should drop the hint info here. We may cache a plan whose root node is a hint + // node. When we lookup the cache with a semantically same plan without hint info, the plan + // returned by cache lookup should not have hint info. If we lookup the cache with a + // semantically same plan with a different hint info, `CacheManager.useCachedData` will take + // care of it and retain the hint info in the lookup input plan. + statsOfPlanToCache.copy(hints = HintInfo()) } else { - Statistics(sizeInBytes = batchStats.value.longValue) + Statistics(sizeInBytes = sizeInBytesStats.value.longValue) } } @@ -122,7 +127,7 @@ case class InMemoryRelation( rowCount += 1 } - batchStats.add(totalSize) + sizeInBytesStats.add(totalSize) val stats = InternalRow.fromSeq( columnBuilders.flatMap(_.columnStats.collectedStatistics)) @@ -144,7 +149,7 @@ case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, batchStats, statsOfPlanToCache) + _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) } override def newInstance(): this.type = { @@ -156,12 +161,12 @@ case class InMemoryRelation( child, tableName)( _cachedColumnBuffers, - batchStats, + sizeInBytesStats, statsOfPlanToCache).asInstanceOf[this.type] } def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, batchStats, statsOfPlanToCache) + Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1e52445f28fc1..72fe0f42801f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -368,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.batchStats.id + case i: InMemoryRelation => i.sizeInBytesStats.id }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.batchStats.id + case i: InMemoryRelation => i.sizeInBytesStats.id }.head toBeCleanedAccIds += accId2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 2280da927cf70..dc1766fb9a785 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -336,7 +336,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(cached, expectedAnswer) // Check that the right size was calculated. - assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize) + assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) } test("access primitive-type columns in CachedBatch without whole stage codegen") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 1704bc8376f0d..bcdee792f4c70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,7 +22,8 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} -import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -70,8 +71,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { private def testBroadcastJoin[T: ClassTag]( joinType: String, forceBroadcast: Boolean = false): SparkPlan = { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") @@ -109,30 +110,58 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } - test("broadcast hint is retained after using the cached data") { + test("SPARK-23192: broadcast hint should be retained after using the cached data") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") - df2.cache() - val df3 = df1.join(broadcast(df2), Seq("key"), "inner") - val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { - case b: BroadcastHashJoinExec => b - }.size - assert(numBroadCastHashJoin === 1) + try { + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + df2.cache() + val df3 = df1.join(broadcast(df2), Seq("key"), "inner") + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + assert(numBroadCastHashJoin === 1) + } finally { + spark.catalog.clearCache() + } + } + } + + test("SPARK-23214: cached data should not carry extra hint info") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + try { + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + broadcast(df2).cache() + + val df3 = df1.join(df2, Seq("key"), "inner") + val numCachedPlan = df3.queryExecution.executedPlan.collect { + case i: InMemoryTableScanExec => i + }.size + // df2 should be cached. + assert(numCachedPlan === 1) + + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + // df2 should not be broadcasted. + assert(numBroadCastHashJoin === 0) + } finally { + spark.catalog.clearCache() + } } } test("broadcast hint isn't propagated after a join") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key")) - val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value") + val df4 = Seq((1, "5"), (2, "5")).toDF("key", "value") val df5 = df4.join(df3, Seq("key"), "inner") - val plan = - EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan) + val plan = EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan) assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1) @@ -140,30 +169,30 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } private def assertBroadcastJoin(df : Dataset[Row]) : Unit = { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") val joined = df1.join(df, Seq("key"), "inner") - val plan = - EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan) + val plan = EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan) assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) } test("broadcast hint programming API") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "2")).toDF("key", "value") val broadcasted = broadcast(df2) - val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value") - - val cases = Seq(broadcasted.limit(2), - broadcasted.filter("value < 10"), - broadcasted.sample(true, 0.5), - broadcasted.distinct(), - broadcasted.groupBy("value").agg(min($"key").as("key")), - // except and intersect are semi/anti-joins which won't return more data then - // their left argument, so the broadcast hint should be propagated here - broadcasted.except(df3), - broadcasted.intersect(df3)) + val df3 = Seq((2, "2"), (3, "3")).toDF("key", "value") + + val cases = Seq( + broadcasted.limit(2), + broadcasted.filter("value < 10"), + broadcasted.sample(true, 0.5), + broadcasted.distinct(), + broadcasted.groupBy("value").agg(min($"key").as("key")), + // except and intersect are semi/anti-joins which won't return more data then + // their left argument, so the broadcast hint should be propagated here + broadcasted.except(df3), + broadcasted.intersect(df3)) cases.foreach(assertBroadcastJoin) } @@ -240,9 +269,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { test("Shouldn't change broadcast join buildSide if user clearly specified") { withTempView("t1", "t2") { - spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") - spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") - .createTempView("t2") + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes @@ -292,9 +320,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { test("Shouldn't bias towards build right if user didn't specify") { withTempView("t1", "t2") { - spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") - spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") - .createTempView("t2") + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes From e7bc9f0524822a08d857c3a5ba57119644ceae85 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 26 Jan 2018 18:57:32 -0600 Subject: [PATCH 217/774] [MINOR][SS][DOC] Fix `Trigger` Scala/Java doc examples ## What changes were proposed in this pull request? This PR fixes Scala/Java doc examples in `Trigger.java`. ## How was this patch tested? N/A. Author: Dongjoon Hyun Closes #20401 from dongjoon-hyun/SPARK-TRIGGER. --- .../src/main/java/org/apache/spark/sql/streaming/Trigger.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java index 33ae9a9e87668..5371a23230c98 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java +++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java @@ -50,7 +50,7 @@ public static Trigger ProcessingTime(long intervalMs) { * * {{{ * import java.util.concurrent.TimeUnit - * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.writeStream().trigger(Trigger.ProcessingTime(10, TimeUnit.SECONDS)) * }}} * * @since 2.2.0 @@ -66,7 +66,7 @@ public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) { * * {{{ * import scala.concurrent.duration._ - * df.writeStream.trigger(ProcessingTime(10.seconds)) + * df.writeStream.trigger(Trigger.ProcessingTime(10.seconds)) * }}} * @since 2.2.0 */ From 6328868e524121bd00595959d6d059f74e038a6b Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 26 Jan 2018 23:06:03 -0800 Subject: [PATCH 218/774] [SPARK-23245][SS][TESTS] Don't access `lastExecution.executedPlan` in StreamTest ## What changes were proposed in this pull request? `lastExecution.executedPlan` is lazy val so accessing it in StreamTest may need to acquire the lock of `lastExecution`. It may be waiting forever when the streaming thread is holding it and running a continuous Spark job. This PR changes to check if `s.lastExecution` is null to avoid accessing `lastExecution.executedPlan`. ## How was this patch tested? Jenkins Author: Jose Torres Closes #20413 from zsxwing/SPARK-23245. --- .../test/scala/org/apache/spark/sql/streaming/StreamTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index efdb0e0e7cf1c..d6433562fb29b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -472,7 +472,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be currentStream.awaitInitialization(streamingTimeout.toMillis) currentStream match { case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - s.lastExecution.executedPlan // will fail if lastExecution is null + assert(s.lastExecution != null) } case _ => } From 3227d14feb1a65e95a2bf326cff6ac95615cc5ac Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 27 Jan 2018 11:26:09 -0800 Subject: [PATCH 219/774] [SPARK-23233][PYTHON] Reset the cache in asNondeterministic to set deterministic properly ## What changes were proposed in this pull request? Reproducer: ```python from pyspark.sql.functions import udf f = udf(lambda x: x) spark.range(1).select(f("id")) # cache JVM UDF instance. f = f.asNondeterministic() spark.range(1).select(f("id"))._jdf.logicalPlan().projectList().head().deterministic() ``` It should return `False` but the current master returns `True`. Seems it's because we cache the JVM UDF instance and then we reuse it even after setting `deterministic` disabled once it's called. ## How was this patch tested? Manually tested. I am not sure if I should add the test with a lot of JVM accesses with the intetnal stuff .. Let me know if anyone feels so. I will add. Author: hyukjinkwon Closes #20409 from HyukjinKwon/SPARK-23233. --- python/pyspark/sql/tests.py | 13 +++++++++++++ python/pyspark/sql/udf.py | 3 +++ 2 files changed, 16 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a466ab87d882d..ca7bbf8ffe71c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -441,6 +441,19 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(random_udf1) pydoc.render_doc(udf(lambda x: x).asNondeterministic) + def test_nondeterministic_udf3(self): + # regression test for SPARK-23233 + from pyspark.sql.functions import udf + f = udf(lambda x: x) + # Here we cache the JVM UDF instance. + self.spark.range(1).select(f("id")) + # This should reset the cache to set the deterministic status correctly. + f = f.asNondeterministic() + # Check the deterministic status of udf. + df = self.spark.range(1).select(f("id")) + deterministic = df._jdf.logicalPlan().projectList().head().deterministic() + self.assertFalse(deterministic) + def test_nondeterministic_udf_in_aggregate(self): from pyspark.sql.functions import udf, sum import random diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index de96846c5c774..4f303304e5600 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -188,6 +188,9 @@ def asNondeterministic(self): .. versionadded:: 2.3 """ + # Here, we explicitly clean the cache to create a JVM UDF instance + # with 'deterministic' updated. See SPARK-23233. + self._judf_placeholder = None self.deterministic = False return self From b8c32dc57368e49baaacf660b7e8836eedab2df7 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 28 Jan 2018 10:33:06 +0900 Subject: [PATCH 220/774] [SPARK-23248][PYTHON][EXAMPLES] Relocate module docstrings to the top in PySpark examples ## What changes were proposed in this pull request? This PR proposes to relocate the docstrings in modules of examples to the top. Seems these are mistakes. So, for example, the below codes ```python >>> help(aft_survival_regression) ``` shows the module docstrings for examples as below: **Before** ``` Help on module aft_survival_regression: NAME aft_survival_regression ... DESCRIPTION # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ... (END) ``` **After** ``` Help on module aft_survival_regression: NAME aft_survival_regression ... DESCRIPTION An example demonstrating aft survival regression. Run with: bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py (END) ``` ## How was this patch tested? Manually checked. Author: hyukjinkwon Closes #20416 from HyukjinKwon/module-docstring-example. --- examples/src/main/python/avro_inputformat.py | 14 +++++++------- .../src/main/python/ml/aft_survival_regression.py | 11 +++++------ .../main/python/ml/bisecting_k_means_example.py | 11 +++++------ .../ml/bucketed_random_projection_lsh_example.py | 12 +++++------- .../src/main/python/ml/chi_square_test_example.py | 10 +++++----- .../src/main/python/ml/correlation_example.py | 10 +++++----- examples/src/main/python/ml/cross_validator.py | 15 +++++++-------- examples/src/main/python/ml/fpgrowth_example.py | 9 ++++----- .../main/python/ml/gaussian_mixture_example.py | 11 +++++------ .../ml/generalized_linear_regression_example.py | 11 +++++------ examples/src/main/python/ml/imputer_example.py | 9 ++++----- .../main/python/ml/isotonic_regression_example.py | 9 +++------ examples/src/main/python/ml/kmeans_example.py | 15 +++++++-------- examples/src/main/python/ml/lda_example.py | 12 +++++------- .../ml/logistic_regression_summary_example.py | 11 +++++------ .../src/main/python/ml/min_hash_lsh_example.py | 12 +++++------- .../src/main/python/ml/one_vs_rest_example.py | 13 ++++++------- .../src/main/python/ml/train_validation_split.py | 13 ++++++------- examples/src/main/python/parquet_inputformat.py | 12 ++++++------ examples/src/main/python/sql/basic.py | 11 +++++------ examples/src/main/python/sql/datasource.py | 11 +++++------ examples/src/main/python/sql/hive.py | 11 +++++------ 22 files changed, 115 insertions(+), 138 deletions(-) diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index 4422f9e7a9589..6286ba6541fbd 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -15,13 +15,6 @@ # limitations under the License. # -from __future__ import print_function - -import sys - -from functools import reduce -from pyspark.sql import SparkSession - """ Read data file users.avro in local Spark distro: @@ -50,6 +43,13 @@ {u'favorite_color': None, u'name': u'Alyssa'} {u'favorite_color': u'red', u'name': u'Ben'} """ +from __future__ import print_function + +import sys + +from functools import reduce +from pyspark.sql import SparkSession + if __name__ == "__main__": if len(sys.argv) != 2 and len(sys.argv) != 3: print(""" diff --git a/examples/src/main/python/ml/aft_survival_regression.py b/examples/src/main/python/ml/aft_survival_regression.py index 2f0ca995e55c7..0a71f76418ea6 100644 --- a/examples/src/main/python/ml/aft_survival_regression.py +++ b/examples/src/main/python/ml/aft_survival_regression.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating aft survival regression. +Run with: + bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py +""" from __future__ import print_function # $example on$ @@ -23,12 +28,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating aft survival regression. -Run with: - bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/bisecting_k_means_example.py b/examples/src/main/python/ml/bisecting_k_means_example.py index 1263cb5d177a8..7842d2009e238 100644 --- a/examples/src/main/python/ml/bisecting_k_means_example.py +++ b/examples/src/main/python/ml/bisecting_k_means_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating bisecting k-means clustering. +Run with: + bin/spark-submit examples/src/main/python/ml/bisecting_k_means_example.py +""" from __future__ import print_function # $example on$ @@ -22,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating bisecting k-means clustering. -Run with: - bin/spark-submit examples/src/main/python/ml/bisecting_k_means_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py b/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py index 1b7a458125cef..610176ea596ca 100644 --- a/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py +++ b/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py @@ -15,7 +15,11 @@ # limitations under the License. # - +""" +An example demonstrating BucketedRandomProjectionLSH. +Run with: + bin/spark-submit examples/src/main/python/ml/bucketed_random_projection_lsh_example.py +""" from __future__ import print_function # $example on$ @@ -25,12 +29,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating BucketedRandomProjectionLSH. -Run with: - bin/spark-submit examples/src/main/python/ml/bucketed_random_projection_lsh_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/chi_square_test_example.py b/examples/src/main/python/ml/chi_square_test_example.py index 8f25318ded00a..2af7e683cdb72 100644 --- a/examples/src/main/python/ml/chi_square_test_example.py +++ b/examples/src/main/python/ml/chi_square_test_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example for Chi-square hypothesis testing. +Run with: + bin/spark-submit examples/src/main/python/ml/chi_square_test_example.py +""" from __future__ import print_function from pyspark.sql import SparkSession @@ -23,11 +28,6 @@ from pyspark.ml.stat import ChiSquareTest # $example off$ -""" -An example for Chi-square hypothesis testing. -Run with: - bin/spark-submit examples/src/main/python/ml/chi_square_test_example.py -""" if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/correlation_example.py b/examples/src/main/python/ml/correlation_example.py index 0a9d30da5a42e..1f4e402ac1a51 100644 --- a/examples/src/main/python/ml/correlation_example.py +++ b/examples/src/main/python/ml/correlation_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example for computing correlation matrix. +Run with: + bin/spark-submit examples/src/main/python/ml/correlation_example.py +""" from __future__ import print_function # $example on$ @@ -23,11 +28,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example for computing correlation matrix. -Run with: - bin/spark-submit examples/src/main/python/ml/correlation_example.py -""" if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py index db7054307c2e3..6256d11504afb 100644 --- a/examples/src/main/python/ml/cross_validator.py +++ b/examples/src/main/python/ml/cross_validator.py @@ -15,6 +15,13 @@ # limitations under the License. # +""" +A simple example demonstrating model selection using CrossValidator. +This example also demonstrates how Pipelines are Estimators. +Run with: + + bin/spark-submit examples/src/main/python/ml/cross_validator.py +""" from __future__ import print_function # $example on$ @@ -26,14 +33,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -A simple example demonstrating model selection using CrossValidator. -This example also demonstrates how Pipelines are Estimators. -Run with: - - bin/spark-submit examples/src/main/python/ml/cross_validator.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/fpgrowth_example.py b/examples/src/main/python/ml/fpgrowth_example.py index c92c3c27abb21..39092e616d429 100644 --- a/examples/src/main/python/ml/fpgrowth_example.py +++ b/examples/src/main/python/ml/fpgrowth_example.py @@ -15,16 +15,15 @@ # limitations under the License. # -# $example on$ -from pyspark.ml.fpm import FPGrowth -# $example off$ -from pyspark.sql import SparkSession - """ An example demonstrating FPGrowth. Run with: bin/spark-submit examples/src/main/python/ml/fpgrowth_example.py """ +# $example on$ +from pyspark.ml.fpm import FPGrowth +# $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ diff --git a/examples/src/main/python/ml/gaussian_mixture_example.py b/examples/src/main/python/ml/gaussian_mixture_example.py index e4a0d314e9d91..4938a904189f9 100644 --- a/examples/src/main/python/ml/gaussian_mixture_example.py +++ b/examples/src/main/python/ml/gaussian_mixture_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating Gaussian Mixture Model (GMM). +Run with: + bin/spark-submit examples/src/main/python/ml/gaussian_mixture_example.py +""" from __future__ import print_function # $example on$ @@ -22,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -A simple example demonstrating Gaussian Mixture Model (GMM). -Run with: - bin/spark-submit examples/src/main/python/ml/gaussian_mixture_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/generalized_linear_regression_example.py b/examples/src/main/python/ml/generalized_linear_regression_example.py index 796752a60f3ab..a52f4650c1c6f 100644 --- a/examples/src/main/python/ml/generalized_linear_regression_example.py +++ b/examples/src/main/python/ml/generalized_linear_regression_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating generalized linear regression. +Run with: + bin/spark-submit examples/src/main/python/ml/generalized_linear_regression_example.py +""" from __future__ import print_function from pyspark.sql import SparkSession @@ -22,12 +27,6 @@ from pyspark.ml.regression import GeneralizedLinearRegression # $example off$ -""" -An example demonstrating generalized linear regression. -Run with: - bin/spark-submit examples/src/main/python/ml/generalized_linear_regression_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/imputer_example.py b/examples/src/main/python/ml/imputer_example.py index b8437f827e56d..9ba0147763618 100644 --- a/examples/src/main/python/ml/imputer_example.py +++ b/examples/src/main/python/ml/imputer_example.py @@ -15,16 +15,15 @@ # limitations under the License. # -# $example on$ -from pyspark.ml.feature import Imputer -# $example off$ -from pyspark.sql import SparkSession - """ An example demonstrating Imputer. Run with: bin/spark-submit examples/src/main/python/ml/imputer_example.py """ +# $example on$ +from pyspark.ml.feature import Imputer +# $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ diff --git a/examples/src/main/python/ml/isotonic_regression_example.py b/examples/src/main/python/ml/isotonic_regression_example.py index 6ae15f1b4b0dd..89cba9dfc7e8f 100644 --- a/examples/src/main/python/ml/isotonic_regression_example.py +++ b/examples/src/main/python/ml/isotonic_regression_example.py @@ -17,6 +17,9 @@ """ Isotonic Regression Example. + +Run with: + bin/spark-submit examples/src/main/python/ml/isotonic_regression_example.py """ from __future__ import print_function @@ -25,12 +28,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating isotonic regression. -Run with: - bin/spark-submit examples/src/main/python/ml/isotonic_regression_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/kmeans_example.py b/examples/src/main/python/ml/kmeans_example.py index 5f77843e3743a..80a878af679f4 100644 --- a/examples/src/main/python/ml/kmeans_example.py +++ b/examples/src/main/python/ml/kmeans_example.py @@ -15,6 +15,13 @@ # limitations under the License. # +""" +An example demonstrating k-means clustering. +Run with: + bin/spark-submit examples/src/main/python/ml/kmeans_example.py + +This example requires NumPy (http://www.numpy.org/). +""" from __future__ import print_function # $example on$ @@ -24,14 +31,6 @@ from pyspark.sql import SparkSession -""" -An example demonstrating k-means clustering. -Run with: - bin/spark-submit examples/src/main/python/ml/kmeans_example.py - -This example requires NumPy (http://www.numpy.org/). -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/lda_example.py b/examples/src/main/python/ml/lda_example.py index a8b346f72cd6f..97d1a042d1479 100644 --- a/examples/src/main/python/ml/lda_example.py +++ b/examples/src/main/python/ml/lda_example.py @@ -15,7 +15,11 @@ # limitations under the License. # - +""" +An example demonstrating LDA. +Run with: + bin/spark-submit examples/src/main/python/ml/lda_example.py +""" from __future__ import print_function # $example on$ @@ -23,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating LDA. -Run with: - bin/spark-submit examples/src/main/python/ml/lda_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/logistic_regression_summary_example.py b/examples/src/main/python/ml/logistic_regression_summary_example.py index bd440a1fbe8df..2274ff707b2a3 100644 --- a/examples/src/main/python/ml/logistic_regression_summary_example.py +++ b/examples/src/main/python/ml/logistic_regression_summary_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating Logistic Regression Summary. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression_summary_example.py +""" from __future__ import print_function # $example on$ @@ -22,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating Logistic Regression Summary. -Run with: - bin/spark-submit examples/src/main/python/ml/logistic_regression_summary_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/min_hash_lsh_example.py b/examples/src/main/python/ml/min_hash_lsh_example.py index 7b1dd611a865b..93136e6ae3cae 100644 --- a/examples/src/main/python/ml/min_hash_lsh_example.py +++ b/examples/src/main/python/ml/min_hash_lsh_example.py @@ -15,7 +15,11 @@ # limitations under the License. # - +""" +An example demonstrating MinHashLSH. +Run with: + bin/spark-submit examples/src/main/python/ml/min_hash_lsh_example.py +""" from __future__ import print_function # $example on$ @@ -25,12 +29,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating MinHashLSH. -Run with: - bin/spark-submit examples/src/main/python/ml/min_hash_lsh_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/one_vs_rest_example.py b/examples/src/main/python/ml/one_vs_rest_example.py index 8e00c25d9342e..956e94ae4ab62 100644 --- a/examples/src/main/python/ml/one_vs_rest_example.py +++ b/examples/src/main/python/ml/one_vs_rest_example.py @@ -15,6 +15,12 @@ # limitations under the License. # +""" +An example of Multiclass to Binary Reduction with One Vs Rest, +using Logistic Regression as the base classifier. +Run with: + bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py +""" from __future__ import print_function # $example on$ @@ -23,13 +29,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example of Multiclass to Binary Reduction with One Vs Rest, -using Logistic Regression as the base classifier. -Run with: - bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/train_validation_split.py b/examples/src/main/python/ml/train_validation_split.py index d104f7d30a1bf..d4f9184bf576e 100644 --- a/examples/src/main/python/ml/train_validation_split.py +++ b/examples/src/main/python/ml/train_validation_split.py @@ -15,13 +15,6 @@ # limitations under the License. # -# $example on$ -from pyspark.ml.evaluation import RegressionEvaluator -from pyspark.ml.regression import LinearRegression -from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit -# $example off$ -from pyspark.sql import SparkSession - """ This example demonstrates applying TrainValidationSplit to split data and preform model selection. @@ -29,6 +22,12 @@ bin/spark-submit examples/src/main/python/ml/train_validation_split.py """ +# $example on$ +from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.regression import LinearRegression +from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit +# $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index 52e9662d528d8..a3f86cf8999cf 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -15,12 +15,6 @@ # limitations under the License. # -from __future__ import print_function - -import sys - -from pyspark.sql import SparkSession - """ Read data file users.parquet in local Spark distro: @@ -35,6 +29,12 @@ {u'favorite_color': u'red', u'name': u'Ben', u'favorite_numbers': []} <...more log output...> """ +from __future__ import print_function + +import sys + +from pyspark.sql import SparkSession + if __name__ == "__main__": if len(sys.argv) != 2: print(""" diff --git a/examples/src/main/python/sql/basic.py b/examples/src/main/python/sql/basic.py index c07fa8f2752b3..c8fb25d0533b5 100644 --- a/examples/src/main/python/sql/basic.py +++ b/examples/src/main/python/sql/basic.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating basic Spark SQL features. +Run with: + ./bin/spark-submit examples/src/main/python/sql/basic.py +""" from __future__ import print_function # $example on:init_session$ @@ -30,12 +35,6 @@ from pyspark.sql.types import * # $example off:programmatic_schema$ -""" -A simple example demonstrating basic Spark SQL features. -Run with: - ./bin/spark-submit examples/src/main/python/sql/basic.py -""" - def basic_df_example(spark): # $example on:create_df$ diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index b375fa775de39..d8c879dfe02ed 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating Spark SQL data sources. +Run with: + ./bin/spark-submit examples/src/main/python/sql/datasource.py +""" from __future__ import print_function from pyspark.sql import SparkSession @@ -22,12 +27,6 @@ from pyspark.sql import Row # $example off:schema_merging$ -""" -A simple example demonstrating Spark SQL data sources. -Run with: - ./bin/spark-submit examples/src/main/python/sql/datasource.py -""" - def basic_datasource_example(spark): # $example on:generic_load_save_functions$ diff --git a/examples/src/main/python/sql/hive.py b/examples/src/main/python/sql/hive.py index 1f83a6fb48b97..33fc2dfbeefa2 100644 --- a/examples/src/main/python/sql/hive.py +++ b/examples/src/main/python/sql/hive.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating Spark SQL Hive integration. +Run with: + ./bin/spark-submit examples/src/main/python/sql/hive.py +""" from __future__ import print_function # $example on:spark_hive$ @@ -24,12 +29,6 @@ from pyspark.sql import Row # $example off:spark_hive$ -""" -A simple example demonstrating Spark SQL Hive integration. -Run with: - ./bin/spark-submit examples/src/main/python/sql/hive.py -""" - if __name__ == "__main__": # $example on:spark_hive$ From c40fda9e4cf32d6cd17af2ace959bbbbe7c782a4 Mon Sep 17 00:00:00 2001 From: Yacine Mazari Date: Sun, 28 Jan 2018 10:27:59 -0600 Subject: [PATCH 221/774] [SPARK-23166][ML] Add maxDF Parameter to CountVectorizer ## What changes were proposed in this pull request? Currently, the CountVectorizer has a minDF parameter. It might be useful to also have a maxDF parameter. It will be used as a threshold for filtering all the terms that occur very frequently in a text corpus, because they are not very informative or could even be stop-words. This is analogous to scikit-learn, CountVectorizer, max_df. Other changes: - Refactored code to invoke "filter()" conditioned on maxDF or minDF set. - Refactored code to unpersist input after counting is done. ## How was this patch tested? Unit tests. Author: Yacine Mazari Closes #20367 from ymazari/SPARK-23166. --- .../spark/ml/feature/CountVectorizer.scala | 67 ++++++++++++++--- .../ml/feature/CountVectorizerSuite.scala | 72 +++++++++++++++++++ 2 files changed, 131 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 1ebe29703bc47..60a4f918790a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -69,6 +69,25 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** @group getParam */ def getMinDF: Double = $(minDF) + /** + * Specifies the maximum number of different documents a term must appear in to be included + * in the vocabulary. + * If this is an integer greater than or equal to 1, this specifies the number of documents + * the term must appear in; if this is a double in [0,1), then this specifies the fraction of + * documents. + * + * Default: (2^64^) - 1 + * @group param + */ + val maxDF: DoubleParam = new DoubleParam(this, "maxDF", "Specifies the maximum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents.", + ParamValidators.gtEq(0.0)) + + /** @group getParam */ + def getMaxDF: Double = $(maxDF) + /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) @@ -113,7 +132,11 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** @group getParam */ def getBinary: Boolean = $(binary) - setDefault(vocabSize -> (1 << 18), minDF -> 1.0, minTF -> 1.0, binary -> false) + setDefault(vocabSize -> (1 << 18), + minDF -> 1.0, + maxDF -> Long.MaxValue, + minTF -> 1.0, + binary -> false) } /** @@ -142,6 +165,10 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String) @Since("1.5.0") def setMinDF(value: Double): this.type = set(minDF, value) + /** @group setParam */ + @Since("2.4.0") + def setMaxDF(value: Double): this.type = set(maxDF, value) + /** @group setParam */ @Since("1.5.0") def setMinTF(value: Double): this.type = set(minTF, value) @@ -155,12 +182,24 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String) transformSchema(dataset.schema, logging = true) val vocSize = $(vocabSize) val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) + val countingRequired = $(minDF) < 1.0 || $(maxDF) < 1.0 + val maybeInputSize = if (countingRequired) { + Some(input.cache().count()) + } else { + None + } val minDf = if ($(minDF) >= 1.0) { $(minDF) } else { - $(minDF) * input.cache().count() + $(minDF) * maybeInputSize.get } - val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) => + val maxDf = if ($(maxDF) >= 1.0) { + $(maxDF) + } else { + $(maxDF) * maybeInputSize.get + } + require(maxDf >= minDf, "maxDF must be >= minDF.") + val allWordCounts = input.flatMap { case (tokens) => val wc = new OpenHashMap[String, Long] tokens.foreach { w => wc.changeValue(w, 1L, _ + 1L) @@ -168,11 +207,23 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String) wc.map { case (word, count) => (word, (count, 1)) } }.reduceByKey { case ((wc1, df1), (wc2, df2)) => (wc1 + wc2, df1 + df2) - }.filter { case (word, (wc, df)) => - df >= minDf - }.map { case (word, (count, dfCount)) => - (word, count) - }.cache() + } + + val filteringRequired = isSet(minDF) || isSet(maxDF) + val maybeFilteredWordCounts = if (filteringRequired) { + allWordCounts.filter { case (_, (_, df)) => df >= minDf && df <= maxDf } + } else { + allWordCounts + } + + val wordCounts = maybeFilteredWordCounts + .map { case (word, (count, _)) => (word, count) } + .cache() + + if (countingRequired) { + input.unpersist() + } + val fullVocabSize = wordCounts.count() val vocab = wordCounts diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index f213145f1ba0a..1784c07ca23e3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -119,6 +119,78 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } } + test("CountVectorizer maxDF") { + val df = Seq( + (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0), (2, 1.0)))), + (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0)))), + (3, split("a"), Vectors.sparse(3, Seq())) + ).toDF("id", "words", "expected") + + // maxDF: ignore terms with count more than 3 + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMaxDF(3) + .fit(df) + assert(cvModel.vocabulary === Array("b", "c", "d")) + + cvModel.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + + // maxDF: ignore terms with freq > 0.75 + val cvModel2 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMaxDF(0.75) + .fit(df) + assert(cvModel2.vocabulary === Array("b", "c", "d")) + + cvModel2.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer using both minDF and maxDF") { + // Ignore terms with count more than 3 AND less than 2 + val df = Seq( + (0, split("a b c d"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))), + (1, split("a b c"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))), + (2, split("a b"), Vectors.sparse(2, Seq((0, 1.0)))), + (3, split("a"), Vectors.sparse(2, Seq())) + ).toDF("id", "words", "expected") + + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(2) + .setMaxDF(3) + .fit(df) + assert(cvModel.vocabulary === Array("b", "c")) + + cvModel.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + + // Ignore terms with frequency higher than 0.75 AND less than 0.5 + val cvModel2 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(0.5) + .setMaxDF(0.75) + .fit(df) + assert(cvModel2.vocabulary === Array("b", "c")) + + cvModel2.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + test("CountVectorizer throws exception when vocab is empty") { intercept[IllegalArgumentException] { val df = Seq( From 686a622c93207564635569f054e1e6c921624e96 Mon Sep 17 00:00:00 2001 From: CCInCharge Date: Sun, 28 Jan 2018 14:55:43 -0600 Subject: [PATCH 222/774] [SPARK-23250][DOCS] Typo in JavaDoc/ScalaDoc for DataFrameWriter ## What changes were proposed in this pull request? Fix typo in ScalaDoc for DataFrameWriter - originally stated "This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0", should be "starting with Spark 2.1.0". ## How was this patch tested? Check of correct spelling in ScalaDoc Please review http://spark.apache.org/contributing.html before opening a pull request. Author: CCInCharge Closes #20417 from CCInCharge/master. --- .../scala/org/apache/spark/sql/DataFrameWriter.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5f3d4448e4e54..5c02eae05304b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -174,7 +174,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * predicates on the partitioned columns. In order for partitioning to work well, the number * of distinct values in each column should typically be less than tens of thousands. * - * This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. * * @since 1.4.0 */ @@ -188,7 +189,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * Buckets the output by the given columns. If specified, the output is laid out on the file * system similar to Hive's bucketing scheme. * - * This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. * * @since 2.0 */ @@ -202,7 +204,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { /** * Sorts the output in each bucket by the given columns. * - * This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. * * @since 2.0 */ From 49b0207dc9327989c72700b4d04d2a714c92e159 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 29 Jan 2018 13:10:38 +0800 Subject: [PATCH 223/774] [SPARK-23196] Unify continuous and microbatch V2 sinks ## What changes were proposed in this pull request? Replace streaming V2 sinks with a unified StreamWriteSupport interface, with a shim to use it with microbatch execution. Add a new SQL config to use for disabling V2 sinks, falling back to the V1 sink implementation. ## How was this patch tested? Existing tests, which in the case of Kafka (the only existing continuous V2 sink) now use V2 for microbatch. Author: Jose Torres Closes #20369 from jose-torres/streaming-sink. --- .../sql/kafka010/KafkaSourceProvider.scala | 16 +-- ...usWriter.scala => KafkaStreamWriter.scala} | 30 ++--- .../kafka010/KafkaContinuousSinkSuite.scala | 8 +- .../spark/sql/kafka010/KafkaSinkSuite.scala | 14 ++- .../spark/sql/kafka010/KafkaSourceSuite.scala | 8 +- .../apache/spark/sql/internal/SQLConf.scala | 9 ++ .../v2/streaming/MicroBatchWriteSupport.java | 60 ---------- ...teSupport.java => StreamWriteSupport.java} | 12 +- ...ontinuousWriter.java => StreamWriter.java} | 34 +++++- .../sources/v2/writer/DataSourceV2Writer.java | 4 +- .../datasources/v2/WriteToDataSourceV2.scala | 11 +- .../streaming/MicroBatchExecution.scala | 19 +-- .../sql/execution/streaming/console.scala | 27 ++--- .../continuous/ContinuousExecution.scala | 19 ++- .../continuous/EpochCoordinator.scala | 9 +- .../streaming/sources/ConsoleWriter.scala | 59 ++------- .../streaming/sources/MicroBatchWriter.scala | 54 +++++++++ .../streaming/sources/memoryV2.scala | 29 ++--- .../sql/streaming/DataStreamWriter.scala | 10 +- .../sql/streaming/StreamingQueryManager.scala | 9 +- ...pache.spark.sql.sources.DataSourceRegister | 7 +- .../streaming/MemorySinkV2Suite.scala | 2 +- .../sources/StreamingDataSourceV2Suite.scala | 112 +++++++++--------- 23 files changed, 265 insertions(+), 297 deletions(-) rename external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/{KafkaContinuousWriter.scala => KafkaStreamWriter.scala} (78%) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/{ContinuousWriteSupport.java => StreamWriteSupport.java} (85%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/{ContinuousWriter.java => StreamWriter.java} (50%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 62a998fbfb30b..2deb7fa2cdf1e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -28,11 +28,11 @@ import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySe import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} -import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -46,7 +46,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with ContinuousWriteSupport + with StreamWriteSupport with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -223,11 +223,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createContinuousWriter( + override def createStreamWriter( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): Optional[ContinuousWriter] = { + options: DataSourceV2Options): StreamWriter = { import scala.collection.JavaConverters._ val spark = SparkSession.getActiveSession.get @@ -238,7 +238,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister KafkaWriter.validateQuery( schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) - Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) + new KafkaStreamWriter(topic, producerParams, schema) } private def strategy(caseInsensitiveParams: Map[String, String]) = diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala similarity index 78% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index 9843f469c5b25..a24efdefa4464 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -17,19 +17,14 @@ package org.apache.spark.sql.kafka010 -import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} import scala.collection.JavaConverters._ -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} -import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{BinaryType, StringType, StructType} +import org.apache.spark.sql.types.StructType /** * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we @@ -38,23 +33,24 @@ import org.apache.spark.sql.types.{BinaryType, StringType, StructType} case object KafkaWriterCommitMessage extends WriterCommitMessage /** - * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. + * A [[StreamWriter]] for Kafka writing. Responsible for generating the writer factory. + * * @param topic The topic this writer is responsible for. If None, topic will be inferred from * a `topic` field in the incoming data. * @param producerParams Parameters for Kafka producers in each task. * @param schema The schema of the input data. */ -class KafkaContinuousWriter( +class KafkaStreamWriter( topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends ContinuousWriter with SupportsWriteInternalRow { + extends StreamWriter with SupportsWriteInternalRow { validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) - override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = - KafkaContinuousWriterFactory(topic, producerParams, schema) + override def createInternalRowWriterFactory(): KafkaStreamWriterFactory = + KafkaStreamWriterFactory(topic, producerParams, schema) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} } /** @@ -65,12 +61,12 @@ class KafkaContinuousWriter( * @param producerParams Parameters for Kafka producers in each task. * @param schema The schema of the input data. */ -case class KafkaContinuousWriterFactory( +case class KafkaStreamWriterFactory( topic: Option[String], producerParams: Map[String, String], schema: StructType) extends DataWriterFactory[InternalRow] { override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { - new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) + new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } } @@ -83,7 +79,7 @@ case class KafkaContinuousWriterFactory( * @param producerParams Parameters to use for the Kafka producer. * @param inputSchema The attributes in the input data. */ -class KafkaContinuousDataWriter( +class KafkaStreamDataWriter( targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { import scala.collection.JavaConverters._ diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index 8487a69851237..fc890a0cfdac3 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -18,16 +18,14 @@ package org.apache.spark.sql.kafka010 import java.util.Locale -import java.util.concurrent.atomic.AtomicInteger import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.ByteArraySerializer import org.scalatest.time.SpanSugar._ import scala.collection.JavaConverters._ -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} -import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{BinaryType, DataType} import org.apache.spark.util.Utils @@ -362,7 +360,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + assert(ex.getCause.getCause.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) } test("streaming - exception on config serializer") { @@ -424,7 +422,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) val inputSchema = Seq(AttributeReference("value", BinaryType)()) val data = new Array[Byte](15000) // large value - val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) + val writeTask = new KafkaStreamDataWriter(Some(topic), options.asScala.toMap, inputSchema) try { val fieldTypes: Array[DataType] = Array(BinaryType) val converter = UnsafeProjection.create(fieldTypes) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 2ab336c7ac476..42f8b4c7657e2 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -336,27 +336,31 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + assert(ex.getCause.getCause.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) } test("streaming - exception on config serializer") { val input = MemoryStream[String] var writer: StreamingQuery = null var ex: Exception = null - ex = intercept[IllegalArgumentException] { + ex = intercept[StreamingQueryException] { writer = createKafkaWriter( input.toDF(), withOptions = Map("kafka.key.serializer" -> "foo"))() + input.addData("1") + writer.processAllAvailable() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + assert(ex.getCause.getMessage.toLowerCase(Locale.ROOT).contains( "kafka option 'key.serializer' is not supported")) - ex = intercept[IllegalArgumentException] { + ex = intercept[StreamingQueryException] { writer = createKafkaWriter( input.toDF(), withOptions = Map("kafka.value.serializer" -> "foo"))() + input.addData("1") + writer.processAllAvailable() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + assert(ex.getCause.getMessage.toLowerCase(Locale.ROOT).contains( "kafka option 'value.serializer' is not supported")) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index c4cb1bc4a2e18..02c87643568bd 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -29,19 +29,17 @@ import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata import org.apache.kafka.common.TopicPartition -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} +import org.apache.spark.sql.{Dataset, ForeachWriter} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 76b9d6f6f33bd..2c70b004bcff9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1127,6 +1127,13 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(100) + val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") + .internal() + .doc("A comma-separated list of fully qualified data source register class names for which" + + " StreamWriteSupport is disabled. Writes to these sources will fail back to the V1 Sink.") + .stringConf + .createWithDefault("") + object PartitionOverwriteMode extends Enumeration { val STATIC, DYNAMIC = Value } @@ -1494,6 +1501,8 @@ class SQLConf extends Serializable with Logging { def continuousStreamingExecutorPollIntervalMs: Long = getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS) + def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS) + def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java deleted file mode 100644 index 53ffa95ae0f4c..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.streaming; - -import java.util.Optional; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability and save the data from a microbatch to the data source. - */ -@InterfaceStability.Evolving -public interface MicroBatchWriteSupport extends BaseStreamingSink { - - /** - * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data - * sources can return None if there is no writing needed to be done. - * - * @param queryId A unique string for the writing query. It's possible that there are many writing - * queries running at the same time, and the returned {@link DataSourceV2Writer} - * can use this id to distinguish itself from others. - * @param epochId The unique numeric ID of the batch within this writing query. This is an - * incrementing counter representing a consistent set of data; the same batch may - * be started multiple times in failure recovery scenarios, but it will always - * contain the same records. - * @param schema the schema of the data to be written. - * @param mode the output mode which determines what successive batch output means to this - * sink, please refer to {@link OutputMode} for more details. - * @param options the options for the returned data source writer, which is an immutable - * case-insensitive string-to-string map. - */ - Optional createMicroBatchWriter( - String queryId, - long epochId, - StructType schema, - OutputMode mode, - DataSourceV2Options options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java similarity index 85% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java index dee493cadb71e..6cd219c67109a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java @@ -17,26 +17,24 @@ package org.apache.spark.sql.sources.v2.streaming; -import java.util.Optional; - import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceV2Options; -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter; +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter; import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability for continuous stream processing. + * provide data writing ability for structured streaming. */ @InterfaceStability.Evolving -public interface ContinuousWriteSupport extends BaseStreamingSink { +public interface StreamWriteSupport extends BaseStreamingSink { /** - * Creates an optional {@link ContinuousWriter} to save the data to this data source. Data + * Creates an optional {@link StreamWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done. * * @param queryId A unique string for the writing query. It's possible that there are many @@ -48,7 +46,7 @@ public interface ContinuousWriteSupport extends BaseStreamingSink { * @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. */ - Optional createContinuousWriter( + StreamWriter createStreamWriter( String queryId, StructType schema, OutputMode mode, diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java similarity index 50% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java index 723395bd1e963..3156c88933e5e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java @@ -23,10 +23,14 @@ import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * A {@link DataSourceV2Writer} for use with continuous stream processing. + * A {@link DataSourceV2Writer} for use with structured streaming. This writer handles commits and + * aborts relative to an epoch ID determined by the execution engine. + * + * {@link DataWriter} implementations generated by a StreamWriter may be reused for multiple epochs, + * and so must reset any internal state after a successful commit. */ @InterfaceStability.Evolving -public interface ContinuousWriter extends DataSourceV2Writer { +public interface StreamWriter extends DataSourceV2Writer { /** * Commits this writing job for the specified epoch with a list of commit messages. The commit * messages are collected from successful data writers and are produced by @@ -34,11 +38,35 @@ public interface ContinuousWriter extends DataSourceV2Writer { * * If this method fails (by throwing an exception), this writing job is considered to have been * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}. + * + * To support exactly-once processing, writer implementations should ensure that this method is + * idempotent. The execution engine may call commit() multiple times for the same epoch + * in some circumstances. */ void commit(long epochId, WriterCommitMessage[] messages); + /** + * Aborts this writing job because some data writers are failed and keep failing when retry, or + * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * + * If this method fails (by throwing an exception), the underlying data source may require manual + * cleanup. + * + * Unless the abort is triggered by the failure of commit, the given messages should have some + * null slots as there maybe only a few data writers that are committed before the abort + * happens, or some data writers were committed but their commit messages haven't reached the + * driver when the abort is triggered. So this is just a "best effort" for data sources to + * clean up the data left by data writers. + */ + void abort(long epochId, WriterCommitMessage[] messages); + default void commit(WriterCommitMessage[] messages) { throw new UnsupportedOperationException( - "Commit without epoch should not be called with ContinuousWriter"); + "Commit without epoch should not be called with StreamWriter"); + } + + default void abort(WriterCommitMessage[] messages) { + throw new UnsupportedOperationException( + "Abort without epoch should not be called with StreamWriter"); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index f1ef411423162..8048f507a1dca 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -28,9 +28,7 @@ /** * A data source writer that is returned by * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter( - * String, long, StructType, OutputMode, DataSourceV2Options)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter( + * {@link org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport#createStreamWriter( * String, StructType, OutputMode, DataSourceV2Options)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 3dbdae7b4df9f..cd6b3e99b6bcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -26,9 +26,8 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -62,7 +61,9 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) try { val runTask = writer match { - case w: ContinuousWriter => + // This case means that we're doing continuous processing. In microbatch streaming, the + // StreamWriter is wrapped in a MicroBatchWriter, which is executed as a normal batch. + case w: StreamWriter => EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) @@ -82,13 +83,13 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - if (!writer.isInstanceOf[ContinuousWriter]) { + if (!writer.isInstanceOf[StreamWriter]) { logInfo(s"Data source writer $writer is committing.") writer.commit(messages) logInfo(s"Data source writer $writer committed.") } } catch { - case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => + case _: InterruptedException if writer.isInstanceOf[StreamWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. case cause: Throwable => logError(s"Data source writer $writer is aborting.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 7c3804547b736..975975243a3d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -28,9 +28,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Curre import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -440,15 +442,18 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: MicroBatchWriteSupport => - val writer = s.createMicroBatchWriter( + case s: StreamWriteSupport => + val writer = s.createStreamWriter( s"$runId", - currentBatchId, newAttributePlan.schema, outputMode, new DataSourceV2Options(extraOptions.asJava)) - assert(writer.isPresent, "microbatch writer must always be present") - WriteToDataSourceV2(writer.get, newAttributePlan) + if (writer.isInstanceOf[SupportsWriteInternalRow]) { + WriteToDataSourceV2( + new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan) + } else { + WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) + } case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } @@ -471,7 +476,7 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case s: MicroBatchWriteSupport => + case _: StreamWriteSupport => // This doesn't accumulate any data - it just forces execution of the microbatch writer. nextBatch.collect() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index f2aa3259731d1..d5ac0bd1df52b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.execution.streaming -import java.util.Optional - import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.{ConsoleContinuousWriter, ConsoleMicroBatchWriter, ConsoleWriter} +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer +import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -35,26 +32,16 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) } class ConsoleSinkProvider extends DataSourceV2 - with MicroBatchWriteSupport - with ContinuousWriteSupport + with StreamWriteSupport with DataSourceRegister with CreatableRelationProvider { - override def createMicroBatchWriter( - queryId: String, - batchId: Long, - schema: StructType, - mode: OutputMode, - options: DataSourceV2Options): Optional[DataSourceV2Writer] = { - Optional.of(new ConsoleMicroBatchWriter(batchId, schema, options)) - } - - override def createContinuousWriter( + override def createStreamWriter( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): Optional[ContinuousWriter] = { - Optional.of(new ConsoleContinuousWriter(schema, options)) + options: DataSourceV2Options): StreamWriter = { + new ConsoleWriter(schema, options) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 462e7d9721d28..60f880f9c73b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -24,17 +24,16 @@ import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} -import org.apache.spark.{SparkEnv, SparkException} -import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.SparkEnv +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} @@ -44,7 +43,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: ContinuousWriteSupport, + sink: StreamWriteSupport, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -195,12 +194,12 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createContinuousWriter( + val writer = sink.createStreamWriter( s"$runId", triggerLogicalPlan.schema, outputMode, new DataSourceV2Options(extraOptions.asJava)) - val withSink = WriteToDataSourceV2(writer.get(), triggerLogicalPlan) + val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { case DataSourceV2Relation(_, r: ContinuousReader) => r @@ -230,7 +229,7 @@ class ContinuousExecution( // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - writer.get(), reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + writer, reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 90b3584aa0436..84d262116cb46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -17,17 +17,14 @@ package org.apache.spark.sql.execution.streaming.continuous -import java.util.concurrent.atomic.AtomicLong - import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage import org.apache.spark.util.RpcUtils @@ -85,7 +82,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * Create a reference to a new [[EpochCoordinator]]. */ def create( - writer: ContinuousWriter, + writer: StreamWriter, reader: ContinuousReader, query: ContinuousExecution, epochCoordinatorId: String, @@ -118,7 +115,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. */ private[continuous] class EpochCoordinator( - writer: ContinuousWriter, + writer: StreamWriter, reader: ContinuousReader, query: ContinuousExecution, startEpoch: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index 6fb61dff60045..7c1700f1de48c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -20,14 +20,13 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter -import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ -trait ConsoleWriter extends Logging { - - def options: DataSourceV2Options +class ConsoleWriter(schema: StructType, options: DataSourceV2Options) + extends StreamWriter with Logging { // Number of rows to display, by default 20 rows protected val numRowsToShow = options.getInt("numRows", 20) @@ -40,14 +39,20 @@ trait ConsoleWriter extends Logging { def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory - def abort(messages: Array[WriterCommitMessage]): Unit = {} + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2 + // behavior. + printRows(messages, schema, s"Batch: $epochId") + } + + def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} protected def printRows( commitMessages: Array[WriterCommitMessage], schema: StructType, printMessage: String): Unit = { val rows = commitMessages.collect { - case PackedRowCommitMessage(rows) => rows + case PackedRowCommitMessage(rs) => rs }.flatten // scalastyle:off println @@ -59,46 +64,8 @@ trait ConsoleWriter extends Logging { .createDataFrame(spark.sparkContext.parallelize(rows), schema) .show(numRowsToShow, isTruncated) } -} - - -/** - * A [[DataSourceV2Writer]] that collects results from a micro-batch query to the driver and - * prints them in the console. Created by - * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. - * - * This sink should not be used for production, as it requires sending all rows to the driver - * and does not support recovery. - */ -class ConsoleMicroBatchWriter(batchId: Long, schema: StructType, val options: DataSourceV2Options) - extends DataSourceV2Writer with ConsoleWriter { - - override def commit(messages: Array[WriterCommitMessage]): Unit = { - printRows(messages, schema, s"Batch: $batchId") - } - - override def toString(): String = { - s"ConsoleMicroBatchWriter[numRows=$numRowsToShow, truncate=$isTruncated]" - } -} - - -/** - * A [[DataSourceV2Writer]] that collects results from a continuous query to the driver and - * prints them in the console. Created by - * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. - * - * This sink should not be used for production, as it requires sending all rows to the driver - * and does not support recovery. - */ -class ConsoleContinuousWriter(schema: StructType, val options: DataSourceV2Options) - extends ContinuousWriter with ConsoleWriter { - - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { - printRows(messages, schema, s"Continuous processing epoch $epochId") - } override def toString(): String = { - s"ConsoleContinuousWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala new file mode 100644 index 0000000000000..d7f3ba8856982 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} + +/** + * A [[DataSourceV2Writer]] used to hook V2 stream writers into a microbatch plan. It implements + * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped + * streaming writer. + */ +class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceV2Writer { + override def commit(messages: Array[WriterCommitMessage]): Unit = { + writer.commit(batchId, messages) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) + + override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory() +} + +class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter) + extends DataSourceV2Writer with SupportsWriteInternalRow { + override def commit(messages: Array[WriterCommitMessage]): Unit = { + writer.commit(batchId, messages) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = + writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => throw new IllegalStateException( + "InternalRowMicroBatchWriter should only be created with base writer support") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index da7c31cf62428..ce55e44d932bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -30,8 +30,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -40,24 +40,13 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 - with MicroBatchWriteSupport with ContinuousWriteSupport with Logging { - - override def createMicroBatchWriter( +class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { + override def createStreamWriter( queryId: String, - batchId: Long, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): java.util.Optional[DataSourceV2Writer] = { - java.util.Optional.of(new MemoryWriter(this, batchId, mode)) - } - - override def createContinuousWriter( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceV2Options): java.util.Optional[ContinuousWriter] = { - java.util.Optional.of(new ContinuousMemoryWriter(this, mode)) + options: DataSourceV2Options): StreamWriter = { + new MemoryStreamWriter(this, mode) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -141,8 +130,8 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) } } -class ContinuousMemoryWriter(val sink: MemorySinkV2, outputMode: OutputMode) - extends ContinuousWriter { +class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) + extends StreamWriter { override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) @@ -153,7 +142,7 @@ class ContinuousMemoryWriter(val sink: MemorySinkV2, outputMode: OutputMode) sink.write(epochId, outputMode, newRows) } - override def abort(messages: Array[WriterCommitMessage]): Unit = { + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { // Don't accept any of the new input. } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index d24f0ddeab4de..3b5b30d77945c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -281,11 +281,9 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { trigger = trigger) } else { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) - val sink = (ds.newInstance(), trigger) match { - case (w: ContinuousWriteSupport, _: ContinuousTrigger) => w - case (_, _: ContinuousTrigger) => throw new UnsupportedOperationException( - s"Data source $source does not support continuous writing") - case (w: MicroBatchWriteSupport, _) => w + val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") + val sink = ds.newInstance() match { + case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) => w case _ => val ds = DataSource( df.sparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 4b27e0d4ef47b..fdd709cdb1f38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -241,7 +241,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: ContinuousWriteSupport, trigger: ContinuousTrigger) => + case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) new StreamingQueryWrapper(new ContinuousExecution( sparkSession, @@ -254,7 +254,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo outputMode, extraOptions, deleteCheckpointOnStop)) - case (_: MicroBatchWriteSupport, _) | (_: Sink, _) => + case _ => new StreamingQueryWrapper(new MicroBatchExecution( sparkSession, userSpecifiedName.orNull, @@ -266,9 +266,6 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo outputMode, extraOptions, deleteCheckpointOnStop)) - case (_: ContinuousWriteSupport, t) if !t.isInstanceOf[ContinuousTrigger] => - throw new AnalysisException( - "Sink only supports continuous writes, but a continuous trigger was not specified.") } } diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index a0b25b4e82364..46b38bed1c0fb 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -9,7 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly org.apache.spark.sql.streaming.sources.FakeReadBothModes org.apache.spark.sql.streaming.sources.FakeReadNeitherMode -org.apache.spark.sql.streaming.sources.FakeWriteMicroBatchOnly -org.apache.spark.sql.streaming.sources.FakeWriteContinuousOnly -org.apache.spark.sql.streaming.sources.FakeWriteBothModes -org.apache.spark.sql.streaming.sources.FakeWriteNeitherMode +org.apache.spark.sql.streaming.sources.FakeWrite +org.apache.spark.sql.streaming.sources.FakeNoWrite +org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 00d4f0b8503d8..9be22d94b5654 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -40,7 +40,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - val writer = new ContinuousMemoryWriter(sink, OutputMode.Append()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append()) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index f152174b0a7f0..d4f8bae96695d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -19,18 +19,18 @@ package org.apache.spark.sql.streaming.sources import java.util.Optional -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.{LongOffset, RateStreamOffset} +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} +import org.apache.spark.sql.sources.v2.DataSourceV2Options import org.apache.spark.sql.sources.v2.reader.ReadTask -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport, MicroBatchReadSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer -import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest, Trigger} +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -64,23 +64,12 @@ trait FakeContinuousReadSupport extends ContinuousReadSupport { options: DataSourceV2Options): ContinuousReader = FakeReader() } -trait FakeMicroBatchWriteSupport extends MicroBatchWriteSupport { - def createMicroBatchWriter( +trait FakeStreamWriteSupport extends StreamWriteSupport { + override def createStreamWriter( queryId: String, - epochId: Long, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): Optional[DataSourceV2Writer] = { - throw new IllegalStateException("fake sink - cannot actually write") - } -} - -trait FakeContinuousWriteSupport extends ContinuousWriteSupport { - def createContinuousWriter( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceV2Options): Optional[ContinuousWriter] = { + options: DataSourceV2Options): StreamWriter = { throw new IllegalStateException("fake sink - cannot actually write") } } @@ -102,23 +91,36 @@ class FakeReadNeitherMode extends DataSourceRegister { override def shortName(): String = "fake-read-neither-mode" } -class FakeWriteMicroBatchOnly extends DataSourceRegister with FakeMicroBatchWriteSupport { - override def shortName(): String = "fake-write-microbatch-only" +class FakeWrite extends DataSourceRegister with FakeStreamWriteSupport { + override def shortName(): String = "fake-write-microbatch-continuous" } -class FakeWriteContinuousOnly extends DataSourceRegister with FakeContinuousWriteSupport { - override def shortName(): String = "fake-write-continuous-only" +class FakeNoWrite extends DataSourceRegister { + override def shortName(): String = "fake-write-neither-mode" } -class FakeWriteBothModes extends DataSourceRegister - with FakeMicroBatchWriteSupport with FakeContinuousWriteSupport { - override def shortName(): String = "fake-write-microbatch-continuous" + +case class FakeWriteV1FallbackException() extends Exception + +class FakeSink extends Sink { + override def addBatch(batchId: Long, data: DataFrame): Unit = {} } -class FakeWriteNeitherMode extends DataSourceRegister { - override def shortName(): String = "fake-write-neither-mode" +class FakeWriteV1Fallback extends DataSourceRegister + with FakeStreamWriteSupport with StreamSinkProvider { + + override def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { + new FakeSink() + } + + override def shortName(): String = "fake-write-v1-fallback" } + class StreamingDataSourceV2Suite extends StreamTest { override def beforeAll(): Unit = { @@ -133,8 +135,6 @@ class StreamingDataSourceV2Suite extends StreamTest { "fake-read-microbatch-continuous", "fake-read-neither-mode") val writeFormats = Seq( - "fake-write-microbatch-only", - "fake-write-continuous-only", "fake-write-microbatch-continuous", "fake-write-neither-mode") val triggers = Seq( @@ -151,6 +151,7 @@ class StreamingDataSourceV2Suite extends StreamTest { .trigger(trigger) .start() query.stop() + query } private def testNegativeCase( @@ -184,6 +185,24 @@ class StreamingDataSourceV2Suite extends StreamTest { } } + test("disabled v2 write") { + // Ensure the V2 path works normally and generates a V2 sink.. + val v2Query = testPositiveCase( + "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) + assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink + .isInstanceOf[FakeWriteV1Fallback]) + + // Ensure we create a V1 sink with the config. Note the config is a comma separated + // list, including other fake entries. + val fullSinkName = "org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback" + withSQLConf(SQLConf.DISABLED_V2_STREAMING_WRITERS.key -> s"a,b,c,test,$fullSinkName,d,e") { + val v1Query = testPositiveCase( + "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) + assert(v1Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink + .isInstanceOf[FakeSink]) + } + } + // Get a list of (read, write, trigger) tuples for test cases. val cases = readFormats.flatMap { read => writeFormats.flatMap { write => @@ -199,12 +218,12 @@ class StreamingDataSourceV2Suite extends StreamTest { val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() (readSource, writeSource, trigger) match { // Valid microbatch queries. - case (_: MicroBatchReadSupport, _: MicroBatchWriteSupport, t) + case (_: MicroBatchReadSupport, _: StreamWriteSupport, t) if !t.isInstanceOf[ContinuousTrigger] => testPositiveCase(read, write, trigger) // Valid continuous queries. - case (_: ContinuousReadSupport, _: ContinuousWriteSupport, _: ContinuousTrigger) => + case (_: ContinuousReadSupport, _: StreamWriteSupport, _: ContinuousTrigger) => testPositiveCase(read, write, trigger) // Invalid - can't read at all @@ -214,31 +233,18 @@ class StreamingDataSourceV2Suite extends StreamTest { testNegativeCase(read, write, trigger, s"Data source $read does not support streamed reading") - // Invalid - trigger is continuous but writer is not - case (_, w, _: ContinuousTrigger) if !w.isInstanceOf[ContinuousWriteSupport] => - testNegativeCase(read, write, trigger, - s"Data source $write does not support continuous writing") - - // Invalid - can't write at all - case (_, w, _) - if !w.isInstanceOf[MicroBatchWriteSupport] - && !w.isInstanceOf[ContinuousWriteSupport] => + // Invalid - can't write + case (_, w, _) if !w.isInstanceOf[StreamWriteSupport] => testNegativeCase(read, write, trigger, s"Data source $write does not support streamed writing") - // Invalid - trigger and writer are continuous but reader is not - case (r, _: ContinuousWriteSupport, _: ContinuousTrigger) + // Invalid - trigger is continuous but reader is not + case (r, _: StreamWriteSupport, _: ContinuousTrigger) if !r.isInstanceOf[ContinuousReadSupport] => testNegativeCase(read, write, trigger, s"Data source $read does not support continuous processing") - // Invalid - trigger is microbatch but writer is not - case (_, w, t) - if !w.isInstanceOf[MicroBatchWriteSupport] && !t.isInstanceOf[ContinuousTrigger] => - testNegativeCase(read, write, trigger, - s"Data source $write does not support streamed writing") - - // Invalid - trigger and writer are microbatch but reader is not + // Invalid - trigger is microbatch but reader is not case (r, _, t) if !r.isInstanceOf[MicroBatchReadSupport] && !t.isInstanceOf[ContinuousTrigger] => testPostCreationNegativeCase(read, write, trigger, From 39d2c6b03488895a0acb1dd3c46329db00fdd357 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 29 Jan 2018 21:09:05 +0900 Subject: [PATCH 224/774] [SPARK-23238][SQL] Externalize SQLConf configurations exposed in documentation ## What changes were proposed in this pull request? This PR proposes to expose few internal configurations found in the documentation. Also it fixes the description for `spark.sql.execution.arrow.enabled`. It's quite self-explanatory. ## How was this patch tested? N/A Author: hyukjinkwon Closes #20403 from HyukjinKwon/minor-doc-arrow. --- .../org/apache/spark/sql/internal/SQLConf.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2c70b004bcff9..61ea03d395afc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -123,14 +123,12 @@ object SQLConf { .createWithDefault(10) val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed") - .internal() .doc("When set to true Spark SQL will automatically select a compression codec for each " + "column based on statistics of the data.") .booleanConf .createWithDefault(true) val COLUMN_BATCH_SIZE = buildConf("spark.sql.inMemoryColumnarStorage.batchSize") - .internal() .doc("Controls the size of batches for columnar caching. Larger batch sizes can improve " + "memory utilization and compression, but risk OOMs when caching data.") .intConf @@ -1043,11 +1041,11 @@ object SQLConf { val ARROW_EXECUTION_ENABLE = buildConf("spark.sql.execution.arrow.enabled") - .internal() - .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + - "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + - "StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " + - "LongType, ShortType") + .doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " + + "for use with pyspark.sql.DataFrame.toPandas, and " + + "pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame. " + + "The following data types are unsupported: " + + "MapType, ArrayType of TimestampType, and nested StructType.") .booleanConf .createWithDefault(false) From badf0d0e0d1d9aa169ed655176ce9ae684d3905d Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Tue, 30 Jan 2018 00:50:49 +0800 Subject: [PATCH 225/774] [SPARK-23219][SQL] Rename ReadTask to DataReaderFactory in data source v2 ## What changes were proposed in this pull request? Currently we have `ReadTask` in data source v2 reader, while in writer we have `DataWriterFactory`. To make the naming consistent and better, renaming `ReadTask` to `DataReaderFactory`. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20397 from gengliangwang/rename. --- .../sql/kafka010/KafkaContinuousReader.scala | 16 ++--- .../execution/UnsafeExternalRowSorter.java | 1 - .../v2/reader/ClusteredDistribution.java | 2 +- .../sql/sources/v2/reader/DataReader.java | 2 +- .../{ReadTask.java => DataReaderFactory.java} | 22 +++---- .../sources/v2/reader/DataSourceV2Reader.java | 11 ++-- .../sql/sources/v2/reader/Distribution.java | 6 +- .../sql/sources/v2/reader/Partitioning.java | 2 +- .../v2/reader/SupportsScanColumnarBatch.java | 11 ++-- .../v2/reader/SupportsScanUnsafeRow.java | 9 +-- .../v2/streaming/MicroBatchReadSupport.java | 4 +- .../v2/streaming/reader/ContinuousReader.java | 14 ++--- .../v2/streaming/reader/MicroBatchReader.java | 6 +- .../datasources/v2/DataSourceRDD.scala | 14 ++--- .../datasources/v2/DataSourceV2ScanExec.scala | 25 ++++---- .../ContinuousDataSourceRDDIter.scala | 11 ++-- .../ContinuousRateStreamSource.scala | 10 ++-- .../sources/RateStreamSourceV2.scala | 6 +- .../sources/v2/JavaAdvancedDataSourceV2.java | 20 +++---- .../sql/sources/v2/JavaBatchDataSourceV2.java | 10 ++-- .../v2/JavaPartitionAwareDataSource.java | 10 ++-- .../v2/JavaSchemaRequiredDataSource.java | 4 +- .../sources/v2/JavaSimpleDataSourceV2.java | 14 ++--- .../sources/v2/JavaUnsafeRowDataSourceV2.java | 13 ++-- .../streaming/RateSourceV2Suite.scala | 10 ++-- .../sql/sources/v2/DataSourceV2Suite.scala | 59 ++++++++++--------- .../sources/v2/SimpleWritableDataSource.scala | 12 ++-- .../sources/StreamingDataSourceV2Suite.scala | 4 +- 28 files changed, 172 insertions(+), 156 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ReadTask.java => DataReaderFactory.java} (65%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index fc977977504f7..9125cf5799d74 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -63,7 +63,7 @@ class KafkaContinuousReader( private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong - // Initialized when creating read tasks. If this diverges from the partitions at the latest + // Initialized when creating reader factories. If this diverges from the partitions at the latest // offsets, we need to reconfigure. // Exposed outside this object only for unit tests. private[sql] var knownPartitions: Set[TopicPartition] = _ @@ -89,7 +89,7 @@ class KafkaContinuousReader( KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { import scala.collection.JavaConverters._ val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) @@ -109,9 +109,9 @@ class KafkaContinuousReader( startOffsets.toSeq.map { case (topicPartition, start) => - KafkaContinuousReadTask( + KafkaContinuousDataReaderFactory( topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) - .asInstanceOf[ReadTask[UnsafeRow]] + .asInstanceOf[DataReaderFactory[UnsafeRow]] }.asJava } @@ -149,8 +149,8 @@ class KafkaContinuousReader( } /** - * A read task for continuous Kafka processing. This will be serialized and transformed into a - * full reader on executors. + * A data reader factory for continuous Kafka processing. This will be serialized and transformed + * into a full reader on executors. * * @param topicPartition The (topic, partition) pair this task is responsible for. * @param startOffset The offset to start reading from within the partition. @@ -159,12 +159,12 @@ class KafkaContinuousReader( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -case class KafkaContinuousReadTask( +case class KafkaContinuousDataReaderFactory( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { + failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { override def createDataReader(): KafkaContinuousDataReader = { new KafkaContinuousDataReader( topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 78647b56d621f..1b2f5eee5ccdd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -20,7 +20,6 @@ import java.io.IOException; import java.util.function.Supplier; -import org.apache.spark.sql.catalyst.util.TypeUtils; import scala.collection.AbstractIterator; import scala.collection.Iterator; import scala.math.Ordering; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java index 7346500de45b6..27905e325df87 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java @@ -22,7 +22,7 @@ /** * A concrete implementation of {@link Distribution}. Represents a distribution where records that * share the same values for the {@link #clusteredColumns} will be produced by the same - * {@link ReadTask}. + * {@link DataReader}. */ @InterfaceStability.Evolving public class ClusteredDistribution implements Distribution { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java index 8f58c865b6201..bb9790a1c819e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -23,7 +23,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link ReadTask#createDataReader()} and is responsible for + * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for * outputting data for a RDD partition. * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java similarity index 65% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java index fa161cdb8b347..077b95b837964 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java @@ -22,21 +22,23 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A read task returned by {@link DataSourceV2Reader#createReadTasks()} and is responsible for - * creating the actual data reader. The relationship between {@link ReadTask} and {@link DataReader} + * A reader factory returned by {@link DataSourceV2Reader#createDataReaderFactories()} and is + * responsible for creating the actual data reader. The relationship between + * {@link DataReaderFactory} and {@link DataReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * - * Note that, the read task will be serialized and sent to executors, then the data reader will be - * created on executors and do the actual reading. So {@link ReadTask} must be serializable and - * {@link DataReader} doesn't need to be. + * Note that, the reader factory will be serialized and sent to executors, then the data reader + * will be created on executors and do the actual reading. So {@link DataReaderFactory} must be + * serializable and {@link DataReader} doesn't need to be. */ @InterfaceStability.Evolving -public interface ReadTask extends Serializable { +public interface DataReaderFactory extends Serializable { /** - * The preferred locations where this read task can run faster, but Spark does not guarantee that - * this task will always run on these locations. The implementations should make sure that it can - * be run on any location. The location is a string representing the host name. + * The preferred locations where the data reader returned by this reader factory can run faster, + * but Spark does not guarantee to run the data reader on these locations. + * The implementations should make sure that it can be run on any location. + * The location is a string representing the host name. * * Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in * the returned locations. By default this method returns empty string array, which means this @@ -50,7 +52,7 @@ default String[] preferredLocations() { } /** - * Returns a data reader to do the actual reading work for this read task. + * Returns a data reader to do the actual reading work. * * If this method fails (by throwing an exception), the corresponding Spark task would fail and * get retried until hitting the maximum retry times. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index f23c3842bf1b1..0180cd9ea47f8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -30,7 +30,8 @@ * {@link org.apache.spark.sql.sources.v2.ReadSupportWithSchema#createReader( * StructType, org.apache.spark.sql.sources.v2.DataSourceV2Options)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}. + * logic is delegated to {@link DataReaderFactory}s that are returned by + * {@link #createDataReaderFactories()}. * * There are mainly 3 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column @@ -63,9 +64,9 @@ public interface DataSourceV2Reader { StructType readSchema(); /** - * Returns a list of read tasks. Each task is responsible for outputting data for one RDD - * partition. That means the number of tasks returned here is same as the number of RDD - * partitions this scan outputs. + * Returns a list of reader factories. Each factory is responsible for creating a data reader to + * output data for one RDD partition. That means the number of factories returned here is same as + * the number of RDD partitions this scan outputs. * * Note that, this may not be a full scan if the data source reader mixes in other optimization * interfaces like column pruning, filter push-down, etc. These optimizations are applied before @@ -74,5 +75,5 @@ public interface DataSourceV2Reader { * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. */ - List> createReadTasks(); + List> createDataReaderFactories(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java index a6201a222f541..b37562167d9ef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java @@ -21,9 +21,9 @@ /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the {@link ReadTask}s that are returned by - * {@link DataSourceV2Reader#createReadTasks()}. Note that this interface has nothing to do with - * the data ordering inside one partition(the output records of a single {@link ReadTask}). + * be distributed among the data partitions(one {@link DataReader} outputs data for one partition). + * Note that this interface has nothing to do with the data ordering inside one + * partition(the output records of a single {@link DataReader}). * * The instance of this interface is created and provided by Spark, then consumed by * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java index 199e45d4a02ab..5e334d13a1215 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java @@ -29,7 +29,7 @@ public interface Partitioning { /** - * Returns the number of partitions(i.e., {@link ReadTask}s) the data source outputs. + * Returns the number of partitions(i.e., {@link DataReaderFactory}s) the data source outputs. */ int numPartitions(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java index 27cf3a77724f0..67da55554bbf3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -30,21 +30,22 @@ @InterfaceStability.Evolving public interface SupportsScanColumnarBatch extends DataSourceV2Reader { @Override - default List> createReadTasks() { + default List> createDataReaderFactories() { throw new IllegalStateException( - "createReadTasks not supported by default within SupportsScanColumnarBatch."); + "createDataReaderFactories not supported by default within SupportsScanColumnarBatch."); } /** - * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns columnar data in batches. + * Similar to {@link DataSourceV2Reader#createDataReaderFactories()}, but returns columnar data + * in batches. */ - List> createBatchReadTasks(); + List> createBatchDataReaderFactories(); /** * Returns true if the concrete data source reader can read data in batch according to the scan * properties like required columns, pushes filters, etc. It's possible that the implementation * can only support some certain columns with certain types. Users can overwrite this method and - * {@link #createReadTasks()} to fallback to normal read path under some conditions. + * {@link #createDataReaderFactories()} to fallback to normal read path under some conditions. */ default boolean enableBatchRead() { return true; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index 2d3ad0eee65ff..156af69520f77 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -33,13 +33,14 @@ public interface SupportsScanUnsafeRow extends DataSourceV2Reader { @Override - default List> createReadTasks() { + default List> createDataReaderFactories() { throw new IllegalStateException( - "createReadTasks not supported by default within SupportsScanUnsafeRow"); + "createDataReaderFactories not supported by default within SupportsScanUnsafeRow"); } /** - * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns data in unsafe row format. + * Similar to {@link DataSourceV2Reader#createDataReaderFactories()}, + * but returns data in unsafe row format. */ - List> createUnsafeRowReadTasks(); + List> createUnsafeRowReaderFactories(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java index 3c87a3db68243..3b357c01a29fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java @@ -36,8 +36,8 @@ public interface MicroBatchReadSupport extends DataSourceV2 { * streaming query. * * The execution engine will create a micro-batch reader at the start of a streaming query, - * alternate calls to setOffsetRange and createReadTasks for each batch to process, and then - * call stop() when the execution is complete. Note that a single query may have multiple + * alternate calls to setOffsetRange and createDataReaderFactories for each batch to process, and + * then call stop() when the execution is complete. Note that a single query may have multiple * executions due to restart or failure recovery. * * @param schema the user provided schema, or empty() if none was provided diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java index 745f1ce502443..3ac979cb0b7b4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java @@ -27,7 +27,7 @@ * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this * interface to allow reading in a continuous processing mode stream. * - * Implementations must ensure each read task output is a {@link ContinuousDataReader}. + * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}. * * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. @@ -47,9 +47,9 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reade Offset deserializeOffset(String json); /** - * Set the desired start offset for read tasks created from this reader. The scan will start - * from the first record after the provided offset, or from an implementation-defined inferred - * starting point if no offset is provided. + * Set the desired start offset for reader factories created from this reader. The scan will + * start from the first record after the provided offset, or from an implementation-defined + * inferred starting point if no offset is provided. */ void setOffset(Optional start); @@ -61,9 +61,9 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reade Offset getStartOffset(); /** - * The execution engine will call this method in every epoch to determine if new read tasks need - * to be generated, which may be required if for example the underlying source system has had - * partitions added or removed. + * The execution engine will call this method in every epoch to determine if new reader + * factories need to be generated, which may be required if for example the underlying + * source system has had partitions added or removed. * * If true, the query will be shut down and restarted with a new reader. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java index 02f37cebc7484..68887e569fc1d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java @@ -33,9 +33,9 @@ @InterfaceStability.Evolving public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSource { /** - * Set the desired offset range for read tasks created from this reader. Read tasks will - * generate only data within (`start`, `end`]; that is, from the first record after `start` to - * the record with offset `end`. + * Set the desired offset range for reader factories created from this reader. Reader factories + * will generate only data within (`start`, `end`]; that is, from the first record after `start` + * to the record with offset `end`. * * @param start The initial offset to scan from. If not specified, scan from an * implementation-specified start point, such as the earliest available record. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index ac104d7cd0cb3..5ed0ba71e94c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -22,24 +22,24 @@ import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.sources.v2.reader.ReadTask +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -class DataSourceRDDPartition[T : ClassTag](val index: Int, val readTask: ReadTask[T]) +class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T]) extends Partition with Serializable class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readTasks: java.util.List[ReadTask[T]]) + @transient private val readerFactories: java.util.List[DataReaderFactory[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readTasks.asScala.zipWithIndex.map { - case (readTask, index) => new DataSourceRDDPartition(index, readTask) + readerFactories.asScala.zipWithIndex.map { + case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readTask.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader() context.addTaskCompletionListener(_ => reader.close()) val iter = new Iterator[T] { private[this] var valuePrepared = false @@ -63,6 +63,6 @@ class DataSourceRDD[T: ClassTag]( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[T]].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 2c22239e81869..3f808fbb40932 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -51,11 +51,11 @@ case class DataSourceV2ScanExec( case _ => super.outputPartitioning } - private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() + private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories() case _ => - reader.createReadTasks().asScala.map { - new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] + reader.createDataReaderFactories().asScala.map { + new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow] }.asJava } @@ -63,18 +63,19 @@ case class DataSourceV2ScanExec( case r: SupportsScanColumnarBatch if r.enableBatchRead() => assert(!reader.isInstanceOf[ContinuousReader], "continuous stream reader does not support columnar read yet.") - new DataSourceRDD(sparkContext, r.createBatchReadTasks()).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories()) + .asInstanceOf[RDD[InternalRow]] case _: ContinuousReader => EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readTasks.size())) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + .askSync[Unit](SetReaderPartitions(readerFactories.size())) + new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories) .asInstanceOf[RDD[InternalRow]] case _ => - new DataSourceRDD(sparkContext, readTasks).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]] } override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) @@ -99,14 +100,14 @@ case class DataSourceV2ScanExec( } } -class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) - extends ReadTask[UnsafeRow] { +class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType) + extends DataReaderFactory[UnsafeRow] { - override def preferredLocations: Array[String] = rowReadTask.preferredLocations + override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations override def createDataReader: DataReader[UnsafeRow] = { new RowToUnsafeDataReader( - rowReadTask.createDataReader, RowEncoder.apply(schema).resolveAndBind()) + rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index cd7065f5e6601..8a7a38b22caca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -39,15 +39,15 @@ import org.apache.spark.util.{SystemClock, ThreadUtils} class ContinuousDataSourceRDD( sc: SparkContext, sqlContext: SQLContext, - @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]]) + @transient private val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]]) extends RDD[UnsafeRow](sc, Nil) { private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs override protected def getPartitions: Array[Partition] = { - readTasks.asScala.zipWithIndex.map { - case (readTask, index) => new DataSourceRDDPartition(index, readTask) + readerFactories.asScala.zipWithIndex.map { + case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } @@ -57,7 +57,8 @@ class ContinuousDataSourceRDD( throw new ContinuousTaskRetryException() } - val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]] + .readerFactory.createDataReader() val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) @@ -136,7 +137,7 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index b4b21e7d2052f..61304480f4721 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -68,7 +68,7 @@ class RateStreamContinuousReader(options: DataSourceV2Options) override def getStartOffset(): Offset = offset - override def createReadTasks(): java.util.List[ReadTask[Row]] = { + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { val partitionStartMap = offset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => @@ -86,13 +86,13 @@ class RateStreamContinuousReader(options: DataSourceV2Options) val start = partitionStartMap(i) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. - RateStreamContinuousReadTask( + RateStreamContinuousDataReaderFactory( start.value, start.runTimeMs, i, numPartitions, perPartitionRate) - .asInstanceOf[ReadTask[Row]] + .asInstanceOf[DataReaderFactory[Row]] }.asJava } @@ -101,13 +101,13 @@ class RateStreamContinuousReader(options: DataSourceV2Options) } -case class RateStreamContinuousReadTask( +case class RateStreamContinuousDataReaderFactory( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ReadTask[Row] { + extends DataReaderFactory[Row] { override def createDataReader(): DataReader[Row] = new RateStreamContinuousDataReader( startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index c0ed12cec25ef..a25cc4f3b06f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -123,7 +123,7 @@ class RateStreamMicroBatchReader(options: DataSourceV2Options) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def createReadTasks(): java.util.List[ReadTask[Row]] = { + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { val startMap = start.partitionToValueAndRunTimeMs val endMap = end.partitionToValueAndRunTimeMs endMap.keys.toSeq.map { part => @@ -139,7 +139,7 @@ class RateStreamMicroBatchReader(options: DataSourceV2Options) outTimeMs += msPerPartitionBetweenRows } - RateStreamBatchTask(packedRows).asInstanceOf[ReadTask[Row]] + RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] }.toList.asJava } @@ -147,7 +147,7 @@ class RateStreamMicroBatchReader(options: DataSourceV2Options) override def stop(): Unit = {} } -case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends ReadTask[Row] { +case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 1cfdc08217e6e..4026ee44bfdb7 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -60,8 +60,8 @@ public Filter[] pushedFilters() { } @Override - public List> createReadTasks() { - List> res = new ArrayList<>(); + public List> createDataReaderFactories() { + List> res = new ArrayList<>(); Integer lowerBound = null; for (Filter filter : filters) { @@ -75,25 +75,25 @@ public List> createReadTasks() { } if (lowerBound == null) { - res.add(new JavaAdvancedReadTask(0, 5, requiredSchema)); - res.add(new JavaAdvancedReadTask(5, 10, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(0, 5, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); } else if (lowerBound < 4) { - res.add(new JavaAdvancedReadTask(lowerBound + 1, 5, requiredSchema)); - res.add(new JavaAdvancedReadTask(5, 10, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 5, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); } else if (lowerBound < 9) { - res.add(new JavaAdvancedReadTask(lowerBound + 1, 10, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 10, requiredSchema)); } return res; } } - static class JavaAdvancedReadTask implements ReadTask, DataReader { + static class JavaAdvancedDataReaderFactory implements DataReaderFactory, DataReader { private int start; private int end; private StructType requiredSchema; - JavaAdvancedReadTask(int start, int end, StructType requiredSchema) { + JavaAdvancedDataReaderFactory(int start, int end, StructType requiredSchema) { this.start = start; this.end = end; this.requiredSchema = requiredSchema; @@ -101,7 +101,7 @@ static class JavaAdvancedReadTask implements ReadTask, DataReader { @Override public DataReader createDataReader() { - return new JavaAdvancedReadTask(start - 1, end, requiredSchema); + return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index a5d77a90ece42..34e6c63801064 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -42,12 +42,14 @@ public StructType readSchema() { } @Override - public List> createBatchReadTasks() { - return java.util.Arrays.asList(new JavaBatchReadTask(0, 50), new JavaBatchReadTask(50, 90)); + public List> createBatchDataReaderFactories() { + return java.util.Arrays.asList( + new JavaBatchDataReaderFactory(0, 50), new JavaBatchDataReaderFactory(50, 90)); } } - static class JavaBatchReadTask implements ReadTask, DataReader { + static class JavaBatchDataReaderFactory + implements DataReaderFactory, DataReader { private int start; private int end; @@ -57,7 +59,7 @@ static class JavaBatchReadTask implements ReadTask, DataReader> createReadTasks() { + public List> createDataReaderFactories() { return java.util.Arrays.asList( - new SpecificReadTask(new int[]{1, 1, 3}, new int[]{4, 4, 6}), - new SpecificReadTask(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + new SpecificDataReaderFactory(new int[]{1, 1, 3}, new int[]{4, 4, 6}), + new SpecificDataReaderFactory(new int[]{2, 4, 4}, new int[]{6, 2, 2})); } @Override @@ -70,12 +70,12 @@ public boolean satisfy(Distribution distribution) { } } - static class SpecificReadTask implements ReadTask, DataReader { + static class SpecificDataReaderFactory implements DataReaderFactory, DataReader { private int[] i; private int[] j; private int current = -1; - SpecificReadTask(int[] i, int[] j) { + SpecificDataReaderFactory(int[] i, int[] j) { assert i.length == j.length; this.i = i; this.j = j; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index a174bd8092cbd..f997366af1a64 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -24,7 +24,7 @@ import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; -import org.apache.spark.sql.sources.v2.reader.ReadTask; +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; import org.apache.spark.sql.types.StructType; public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { @@ -42,7 +42,7 @@ public StructType readSchema() { } @Override - public List> createReadTasks() { + public List> createDataReaderFactories() { return java.util.Collections.emptyList(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 2d458b7f7e906..2beed431d301f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -26,7 +26,7 @@ import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.DataReader; -import org.apache.spark.sql.sources.v2.reader.ReadTask; +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; import org.apache.spark.sql.types.StructType; @@ -41,25 +41,25 @@ public StructType readSchema() { } @Override - public List> createReadTasks() { + public List> createDataReaderFactories() { return java.util.Arrays.asList( - new JavaSimpleReadTask(0, 5), - new JavaSimpleReadTask(5, 10)); + new JavaSimpleDataReaderFactory(0, 5), + new JavaSimpleDataReaderFactory(5, 10)); } } - static class JavaSimpleReadTask implements ReadTask, DataReader { + static class JavaSimpleDataReaderFactory implements DataReaderFactory, DataReader { private int start; private int end; - JavaSimpleReadTask(int start, int end) { + JavaSimpleDataReaderFactory(int start, int end) { this.start = start; this.end = end; } @Override public DataReader createDataReader() { - return new JavaSimpleReadTask(start - 1, end); + return new JavaSimpleDataReaderFactory(start - 1, end); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index f6aa00869a681..e8187524ea871 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -38,19 +38,20 @@ public StructType readSchema() { } @Override - public List> createUnsafeRowReadTasks() { + public List> createUnsafeRowReaderFactories() { return java.util.Arrays.asList( - new JavaUnsafeRowReadTask(0, 5), - new JavaUnsafeRowReadTask(5, 10)); + new JavaUnsafeRowDataReaderFactory(0, 5), + new JavaUnsafeRowDataReaderFactory(5, 10)); } } - static class JavaUnsafeRowReadTask implements ReadTask, DataReader { + static class JavaUnsafeRowDataReaderFactory + implements DataReaderFactory, DataReader { private int start; private int end; private UnsafeRow row; - JavaUnsafeRowReadTask(int start, int end) { + JavaUnsafeRowDataReaderFactory(int start, int end) { this.start = start; this.end = end; this.row = new UnsafeRow(2); @@ -59,7 +60,7 @@ static class JavaUnsafeRowReadTask implements ReadTask, DataReader createDataReader() { - return new JavaUnsafeRowReadTask(start - 1, end); + return new JavaUnsafeRowDataReaderFactory(start - 1, end); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index 85085d43061bd..d2cfe7905f6fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -78,7 +78,7 @@ class RateSourceV2Suite extends StreamTest { val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) reader.setOffsetRange(Optional.empty(), Optional.empty()) - val tasks = reader.createReadTasks() + val tasks = reader.createDataReaderFactories() assert(tasks.size == 11) } @@ -118,7 +118,7 @@ class RateSourceV2Suite extends StreamTest { val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createReadTasks() + val tasks = reader.createDataReaderFactories() assert(tasks.size == 1) assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) } @@ -133,7 +133,7 @@ class RateSourceV2Suite extends StreamTest { }.toMap) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createReadTasks() + val tasks = reader.createDataReaderFactories() assert(tasks.size == 11) val readData = tasks.asScala @@ -161,12 +161,12 @@ class RateSourceV2Suite extends StreamTest { val reader = new RateStreamContinuousReader( new DataSourceV2Options(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setOffset(Optional.empty()) - val tasks = reader.createReadTasks() + val tasks = reader.createDataReaderFactories() assert(tasks.size == 2) val data = scala.collection.mutable.ListBuffer[Row]() tasks.asScala.foreach { - case t: RateStreamContinuousReadTask => + case t: RateStreamContinuousDataReaderFactory => val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 0620693b35d16..42c5d3bcea44b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -204,18 +204,20 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceV2Reader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createReadTasks(): JList[ReadTask[Row]] = { - java.util.Arrays.asList(new SimpleReadTask(0, 5), new SimpleReadTask(5, 10)) + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5), new SimpleDataReaderFactory(5, 10)) } } override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader[Row] { +class SimpleDataReaderFactory(start: Int, end: Int) + extends DataReaderFactory[Row] + with DataReader[Row] { private var current = start - 1 - override def createDataReader(): DataReader[Row] = new SimpleReadTask(start, end) + override def createDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end) override def next(): Boolean = { current += 1 @@ -252,21 +254,21 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { requiredSchema } - override def createReadTasks(): JList[ReadTask[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { val lowerBound = filters.collect { case GreaterThan("i", v: Int) => v }.headOption - val res = new ArrayList[ReadTask[Row]] + val res = new ArrayList[DataReaderFactory[Row]] if (lowerBound.isEmpty) { - res.add(new AdvancedReadTask(0, 5, requiredSchema)) - res.add(new AdvancedReadTask(5, 10, requiredSchema)) + res.add(new AdvancedDataReaderFactory(0, 5, requiredSchema)) + res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) } else if (lowerBound.get < 4) { - res.add(new AdvancedReadTask(lowerBound.get + 1, 5, requiredSchema)) - res.add(new AdvancedReadTask(5, 10, requiredSchema)) + res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 5, requiredSchema)) + res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) } else if (lowerBound.get < 9) { - res.add(new AdvancedReadTask(lowerBound.get + 1, 10, requiredSchema)) + res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 10, requiredSchema)) } res @@ -276,13 +278,13 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class AdvancedReadTask(start: Int, end: Int, requiredSchema: StructType) - extends ReadTask[Row] with DataReader[Row] { +class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType) + extends DataReaderFactory[Row] with DataReader[Row] { private var current = start - 1 override def createDataReader(): DataReader[Row] = { - new AdvancedReadTask(start, end, requiredSchema) + new AdvancedDataReaderFactory(start, end, requiredSchema) } override def close(): Unit = {} @@ -307,16 +309,17 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceV2Reader with SupportsScanUnsafeRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createUnsafeRowReadTasks(): JList[ReadTask[UnsafeRow]] = { - java.util.Arrays.asList(new UnsafeRowReadTask(0, 5), new UnsafeRowReadTask(5, 10)) + override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = { + java.util.Arrays.asList(new UnsafeRowDataReaderFactory(0, 5), + new UnsafeRowDataReaderFactory(5, 10)) } } override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class UnsafeRowReadTask(start: Int, end: Int) - extends ReadTask[UnsafeRow] with DataReader[UnsafeRow] { +class UnsafeRowDataReaderFactory(start: Int, end: Int) + extends DataReaderFactory[UnsafeRow] with DataReader[UnsafeRow] { private val row = new UnsafeRow(2) row.pointTo(new Array[Byte](8 * 3), 8 * 3) @@ -341,7 +344,7 @@ class UnsafeRowReadTask(start: Int, end: Int) class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { class Reader(val readSchema: StructType) extends DataSourceV2Reader { - override def createReadTasks(): JList[ReadTask[Row]] = + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = java.util.Collections.emptyList() } @@ -354,16 +357,16 @@ class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] = { - java.util.Arrays.asList(new BatchReadTask(0, 50), new BatchReadTask(50, 90)) + override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = { + java.util.Arrays.asList(new BatchDataReaderFactory(0, 50), new BatchDataReaderFactory(50, 90)) } } override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class BatchReadTask(start: Int, end: Int) - extends ReadTask[ColumnarBatch] with DataReader[ColumnarBatch] { +class BatchDataReaderFactory(start: Int, end: Int) + extends DataReaderFactory[ColumnarBatch] with DataReader[ColumnarBatch] { private final val BATCH_SIZE = 20 private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) @@ -406,11 +409,11 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { class Reader extends DataSourceV2Reader with SupportsReportPartitioning { override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") - override def createReadTasks(): JList[ReadTask[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { // Note that we don't have same value of column `a` across partitions. java.util.Arrays.asList( - new SpecificReadTask(Array(1, 1, 3), Array(4, 4, 6)), - new SpecificReadTask(Array(2, 4, 4), Array(6, 2, 2))) + new SpecificDataReaderFactory(Array(1, 1, 3), Array(4, 4, 6)), + new SpecificDataReaderFactory(Array(2, 4, 4), Array(6, 2, 2))) } override def outputPartitioning(): Partitioning = new MyPartitioning @@ -428,7 +431,9 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class SpecificReadTask(i: Array[Int], j: Array[Int]) extends ReadTask[Row] with DataReader[Row] { +class SpecificDataReaderFactory(i: Array[Int], j: Array[Int]) + extends DataReaderFactory[Row] + with DataReader[Row] { assert(i.length == j.length) private var current = -1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index cd7252eb2e3d6..3310d6dd199d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataSourceV2Reader, ReadTask} +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceV2Reader} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -45,7 +45,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS class Reader(path: String, conf: Configuration) extends DataSourceV2Reader { override def readSchema(): StructType = schema - override def createReadTasks(): JList[ReadTask[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -54,7 +54,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS name.startsWith("_") || name.startsWith(".") }.map { f => val serializableConf = new SerializableConfiguration(conf) - new SimpleCSVReadTask(f.getPath.toUri.toString, serializableConf): ReadTask[Row] + new SimpleCSVDataReaderFactory( + f.getPath.toUri.toString, + serializableConf): DataReaderFactory[Row] }.toList.asJava } else { Collections.emptyList() @@ -149,8 +151,8 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } -class SimpleCSVReadTask(path: String, conf: SerializableConfiguration) - extends ReadTask[Row] with DataReader[Row] { +class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) + extends DataReaderFactory[Row] with DataReader[Row] { @transient private var lines: Iterator[String] = _ @transient private var currentLine: String = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index d4f8bae96695d..dc8c857018457 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.reader.ReadTask +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory import org.apache.spark.sql.sources.v2.streaming._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter @@ -45,7 +45,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader { def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) def setOffset(start: Optional[Offset]): Unit = {} - def createReadTasks(): java.util.ArrayList[ReadTask[Row]] = { + def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = { throw new IllegalStateException("fake source - cannot actually read") } } From 54dd7cf4ef921bc9dc12f99cfb90d1da57939901 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Mon, 29 Jan 2018 08:56:42 -0800 Subject: [PATCH 226/774] [SPARK-23199][SQL] improved Removes repetition from group expressions in Aggregate ## What changes were proposed in this pull request? Currently, all Aggregate operations will go into RemoveRepetitionFromGroupExpressions, but there is no group expression or there is no duplicate group expression in group expression, we not need copy for logic plan. ## How was this patch tested? the existed test case. Author: caoxuewen Closes #20375 from heary-cao/RepetitionGroupExpressions. --- .../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 8 ++++++-- .../sql/catalyst/optimizer/AggregateOptimizeSuite.scala | 5 ++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8d207708c12ad..a28b6a0feb8f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1302,8 +1302,12 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { */ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a @ Aggregate(grouping, _, _) => + case a @ Aggregate(grouping, _, _) if grouping.size > 1 => val newGrouping = ExpressionSet(grouping).toSeq - a.copy(groupingExpressions = newGrouping) + if (newGrouping.size == grouping.size) { + a + } else { + a.copy(groupingExpressions = newGrouping) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index a3184a4266c7c..f8ddc93597070 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -67,10 +67,9 @@ class AggregateOptimizeSuite extends PlanTest { } test("remove repetition in grouping expression") { - val input = LocalRelation('a.int, 'b.int, 'c.int) - val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) + val query = testRelation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze + val correctAnswer = testRelation.groupBy('a + 1, 'b + 2)(sum('c)).analyze comparePlans(optimized, correctAnswer) } From fbce2ed0fa5c3e9fb2bdf9d9741eb3ff0760f88c Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Mon, 29 Jan 2018 08:58:14 -0800 Subject: [PATCH 227/774] [SPARK-23059][SQL][TEST] Correct some improper with view related method usage ## What changes were proposed in this pull request? Correct some improper with view related method usage Only change test cases like: ``` test("list global temp views") { try { sql("CREATE GLOBAL TEMP VIEW v1 AS SELECT 3, 4") sql("CREATE TEMP VIEW v2 AS SELECT 1, 2") checkAnswer(sql(s"SHOW TABLES IN $globalTempDB"), Row(globalTempDB, "v1", true) :: Row("", "v2", true) :: Nil) assert(spark.catalog.listTables(globalTempDB).collect().toSeq.map(_.name) == Seq("v1", "v2")) } finally { spark.catalog.dropTempView("v1") spark.catalog.dropGlobalTempView("v2") } } ``` other change please review the code. ## How was this patch tested? See test case. Author: xubo245 <601450868@qq.com> Closes #20250 from xubo245/DropTempViewError. --- .../org/apache/spark/sql/SQLQuerySuite.scala | 48 ++++++++++--------- .../sql/execution/GlobalTempViewSuite.scala | 4 +- .../spark/sql/execution/SQLViewSuite.scala | 36 ++++++++------ .../sql/execution/command/DDLSuite.scala | 2 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../sql/hive/execution/HiveSQLViewSuite.scala | 26 +++++----- .../sql/hive/execution/SQLQuerySuite.scala | 44 +++++++++-------- 7 files changed, 88 insertions(+), 74 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a79ab47f0197e..ffd736d2ebbb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1565,36 +1565,38 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("specifying database name for a temporary view is not allowed") { withTempPath { dir => - val path = dir.toURI.toString - val df = - sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") - df - .write - .format("parquet") - .save(path) - - // We don't support creating a temporary table while specifying a database - intercept[AnalysisException] { + withTempView("db.t") { + val path = dir.toURI.toString + val df = + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + // We don't support creating a temporary table while specifying a database + intercept[AnalysisException] { + spark.sql( + s""" + |CREATE TEMPORARY VIEW db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + + // If you use backticks to quote the name then it's OK. spark.sql( s""" - |CREATE TEMPORARY VIEW db.t + |CREATE TEMPORARY VIEW `db.t` |USING parquet |OPTIONS ( | path '$path' |) """.stripMargin) - }.getMessage - - // If you use backticks to quote the name then it's OK. - spark.sql( - s""" - |CREATE TEMPORARY VIEW `db.t` - |USING parquet - |OPTIONS ( - | path '$path' - |) - """.stripMargin) - checkAnswer(spark.table("`db.t`"), df) + checkAnswer(spark.table("`db.t`"), df) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index dcc6fa6403f31..972b47e96fe06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -134,8 +134,8 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { assert(spark.catalog.listTables(globalTempDB).collect().toSeq.map(_.name) == Seq("v1", "v2")) } finally { - spark.catalog.dropTempView("v1") - spark.catalog.dropGlobalTempView("v2") + spark.catalog.dropGlobalTempView("v1") + spark.catalog.dropTempView("v2") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index ce8fde28a941c..8269d4d3a285d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -53,15 +53,17 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("create a temp view on a permanent view") { - withView("jtv1", "temp_jtv1") { - sql("CREATE VIEW jtv1 AS SELECT * FROM jt WHERE id > 3") - sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jtv1 WHERE id < 6") - checkAnswer(sql("select count(*) FROM temp_jtv1"), Row(2)) + withView("jtv1") { + withTempView("temp_jtv1") { + sql("CREATE VIEW jtv1 AS SELECT * FROM jt WHERE id > 3") + sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jtv1 WHERE id < 6") + checkAnswer(sql("select count(*) FROM temp_jtv1"), Row(2)) + } } } test("create a temp view on a temp view") { - withView("temp_jtv1", "temp_jtv2") { + withTempView("temp_jtv1", "temp_jtv2") { sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") sql("CREATE TEMPORARY VIEW temp_jtv2 AS SELECT * FROM temp_jtv1 WHERE id < 6") checkAnswer(sql("select count(*) FROM temp_jtv2"), Row(2)) @@ -222,10 +224,12 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("error handling: disallow IF NOT EXISTS for CREATE TEMPORARY VIEW") { - val e = intercept[AnalysisException] { - sql("CREATE TEMPORARY VIEW IF NOT EXISTS myabcdview AS SELECT * FROM jt") + withTempView("myabcdview") { + val e = intercept[AnalysisException] { + sql("CREATE TEMPORARY VIEW IF NOT EXISTS myabcdview AS SELECT * FROM jt") + } + assert(e.message.contains("It is not allowed to define a TEMPORARY view with IF NOT EXISTS")) } - assert(e.message.contains("It is not allowed to define a TEMPORARY view with IF NOT EXISTS")) } test("error handling: fail if the temp view sql itself is invalid") { @@ -274,7 +278,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("correctly parse CREATE TEMPORARY VIEW statement") { - withView("testView") { + withTempView("testView") { sql( """CREATE TEMPORARY VIEW |testView (c1 COMMENT 'blabla', c2 COMMENT 'blabla') @@ -286,7 +290,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("should NOT allow CREATE TEMPORARY VIEW when TEMPORARY VIEW with same name exists") { - withView("testView") { + withTempView("testView") { sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") val e = intercept[AnalysisException] { @@ -299,15 +303,19 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { test("should allow CREATE TEMPORARY VIEW when a permanent VIEW with same name exists") { withView("testView", "default.testView") { - sql("CREATE VIEW testView AS SELECT id FROM jt") - sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + withTempView("testView") { + sql("CREATE VIEW testView AS SELECT id FROM jt") + sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + } } } test("should allow CREATE permanent VIEW when a TEMPORARY VIEW with same name exists") { withView("testView", "default.testView") { - sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") - sql("CREATE VIEW testView AS SELECT id FROM jt") + withTempView("testView") { + sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 6ca21b5aa1595..ee3674ba17821 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -739,7 +739,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { // starts with 'jar:', and it is an illegal parameter for Path, so here we copy it // to a temp file by withResourceTempPath withResourceTempPath("test-data/cars.csv") { tmpFile => - withView("testview") { + withTempView("testview") { sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + s"OPTIONS (PATH '${tmpFile.toURI}')") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index fade143a1755e..859099a321bf7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1151,7 +1151,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv test("create a temp view using hive") { val tableName = "tab1" - withTable(tableName) { + withTempView(tableName) { val e = intercept[AnalysisException] { sql( s""" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala index 97e4c2b6b2db8..5e6e114fc3fdc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala @@ -67,20 +67,22 @@ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton { classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper].getCanonicalName withUserDefinedFunction(tempFunctionName -> true) { sql(s"CREATE TEMPORARY FUNCTION $tempFunctionName AS '$functionClass'") - withView("view1", "tempView1") { - withTable("tab1") { - (1 to 10).map(i => s"$i").toDF("id").write.saveAsTable("tab1") + withView("view1") { + withTempView("tempView1") { + withTable("tab1") { + (1 to 10).map(i => s"$i").toDF("id").write.saveAsTable("tab1") - // temporary view - sql(s"CREATE TEMPORARY VIEW tempView1 AS SELECT $tempFunctionName(id) from tab1") - checkAnswer(sql("select count(*) FROM tempView1"), Row(10)) + // temporary view + sql(s"CREATE TEMPORARY VIEW tempView1 AS SELECT $tempFunctionName(id) from tab1") + checkAnswer(sql("select count(*) FROM tempView1"), Row(10)) - // permanent view - val e = intercept[AnalysisException] { - sql(s"CREATE VIEW view1 AS SELECT $tempFunctionName(id) from tab1") - }.getMessage - assert(e.contains("Not allowed to create a permanent view `view1` by referencing " + - s"a temporary function `$tempFunctionName`")) + // permanent view + val e = intercept[AnalysisException] { + sql(s"CREATE VIEW view1 AS SELECT $tempFunctionName(id) from tab1") + }.getMessage + assert(e.contains("Not allowed to create a permanent view `view1` by referencing " + + s"a temporary function `$tempFunctionName`")) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 33bcae91fdaf4..baabc4a3bca2c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1203,35 +1203,37 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("specifying database name for a temporary view is not allowed") { withTempPath { dir => - val path = dir.toURI.toString - val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") - df - .write - .format("parquet") - .save(path) - - // We don't support creating a temporary table while specifying a database - intercept[AnalysisException] { + withTempView("db.t") { + val path = dir.toURI.toString + val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + // We don't support creating a temporary table while specifying a database + intercept[AnalysisException] { + spark.sql( + s""" + |CREATE TEMPORARY VIEW db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + } + + // If you use backticks to quote the name then it's OK. spark.sql( s""" - |CREATE TEMPORARY VIEW db.t + |CREATE TEMPORARY VIEW `db.t` |USING parquet |OPTIONS ( | path '$path' |) """.stripMargin) + checkAnswer(spark.table("`db.t`"), df) } - - // If you use backticks to quote the name then it's OK. - spark.sql( - s""" - |CREATE TEMPORARY VIEW `db.t` - |USING parquet - |OPTIONS ( - | path '$path' - |) - """.stripMargin) - checkAnswer(spark.table("`db.t`"), df) } } From 2d903cf9d3a827e54217dfc9f1e4be99d8204387 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 29 Jan 2018 09:00:54 -0800 Subject: [PATCH 228/774] [SPARK-23223][SQL] Make stacking dataset transforms more performant ## What changes were proposed in this pull request? It is a common pattern to apply multiple transforms to a `Dataset` (using `Dataset.withColumn` for example. This is currently quite expensive because we run `CheckAnalysis` on the full plan and create an encoder for each intermediate `Dataset`. This PR extends the usage of the `AnalysisBarrier` to include `CheckAnalysis`. By doing this we hide the already analyzed plan from `CheckAnalysis` because barrier is a `LeafNode`. The `AnalysisBarrier` is in the `FinishAnalysis` phase of the optimizer. We also make binding the `Dataset` encoder lazy. The bound encoder is only needed when we materialize the dataset. ## How was this patch tested? Existing test should cover this. Author: Herman van Hovell Closes #20402 from hvanhovell/SPARK-23223. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 16 ++++++++++++++-- .../sql/catalyst/analysis/CheckAnalysis.scala | 1 + .../sql/catalyst/analysis/AnalysisTest.scala | 3 +-- .../scala/org/apache/spark/sql/Dataset.scala | 8 ++++++-- .../spark/sql/execution/QueryExecution.scala | 16 ++-------------- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- 6 files changed, 25 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2b14c8220d43b..91cb0365a0856 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -98,6 +98,19 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } + def executeAndCheck(plan: LogicalPlan): LogicalPlan = { + val analyzed = execute(plan) + try { + checkAnalysis(analyzed) + EliminateBarriers(analyzed) + } catch { + case e: AnalysisException => + val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) + ae.setStackTrace(e.getStackTrace) + throw ae + } + } + override def execute(plan: LogicalPlan): LogicalPlan = { AnalysisContext.reset() try { @@ -178,8 +191,7 @@ class Analyzer( Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, - CleanupAliases, - EliminateBarriers) + CleanupAliases) ) /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index ef91d79f3302c..90bda2a72ad82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -356,6 +356,7 @@ trait CheckAnalysis extends PredicateHelper { } extendedCheckRules.foreach(_(plan)) plan.foreachUp { + case AnalysisBarrier(child) if !child.resolved => checkAnalysis(child) case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}") case _ => } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 549a4355dfba3..3d7c91870133b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -54,8 +54,7 @@ trait AnalysisTest extends PlanTest { expectedPlan: LogicalPlan, caseSensitive: Boolean = true): Unit = { val analyzer = getAnalyzer(caseSensitive) - val actualPlan = analyzer.execute(inputPlan) - analyzer.checkAnalysis(actualPlan) + val actualPlan = analyzer.executeAndCheck(inputPlan) comparePlans(actualPlan, expectedPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index edb6644ed5ac0..cc5b647b3f037 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -62,7 +62,11 @@ import org.apache.spark.util.Utils private[sql] object Dataset { def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = { - new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) + val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) + // Eagerly bind the encoder so we verify that the encoder matches the underlying + // schema. The user will get an error if this is not the case. + dataset.deserializer + dataset } def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = { @@ -204,7 +208,7 @@ class Dataset[T] private[sql]( // The deserializer expression which can be used to build a projection and turn rows to objects // of type T, after collecting rows to the driver side. - private val deserializer = + private lazy val deserializer = exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer private implicit def classTag = exprEnc.clsTag diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 8bfe3eff0c3b3..7cae24bf5976c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -44,19 +44,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { // TODO: Move the planner an optimizer into here from SessionState. protected def planner = sparkSession.sessionState.planner - def assertAnalyzed(): Unit = { - // Analyzer is invoked outside the try block to avoid calling it again from within the - // catch block below. - analyzed - try { - sparkSession.sessionState.analyzer.checkAnalysis(analyzed) - } catch { - case e: AnalysisException => - val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) - ae.setStackTrace(e.getStackTrace) - throw ae - } - } + def assertAnalyzed(): Unit = analyzed def assertSupported(): Unit = { if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { @@ -66,7 +54,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { lazy val analyzed: LogicalPlan = { SparkSession.setActiveSession(sparkSession) - sparkSession.sessionState.analyzer.execute(logical) + sparkSession.sessionState.analyzer.executeAndCheck(logical) } lazy val withCachedData: LogicalPlan = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 7287e20d55bbe..59708e7a0f2ff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -575,7 +575,7 @@ private[hive] class TestHiveQueryExecution( logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) // Proceed with analysis. - sparkSession.sessionState.analyzer.execute(logical) + sparkSession.sessionState.analyzer.executeAndCheck(logical) } } From 0d60b3213fe9a7ae5e9b208639f92011fdb2ca32 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 29 Jan 2018 10:25:25 -0800 Subject: [PATCH 229/774] [SPARK-22221][DOCS] Adding User Documentation for Arrow ## What changes were proposed in this pull request? Adding user facing documentation for working with Arrow in Spark Author: Bryan Cutler Author: Li Jin Author: hyukjinkwon Closes #19575 from BryanCutler/arrow-user-docs-SPARK-2221. --- docs/sql-programming-guide.md | 134 +++++++++++++++++++++++++- examples/src/main/python/sql/arrow.py | 129 +++++++++++++++++++++++++ 2 files changed, 262 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/python/sql/arrow.py diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 502c0a8c37e01..d49c8d869cba6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1640,6 +1640,138 @@ Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` a You may run `./bin/spark-sql --help` for a complete list of all available options. +# PySpark Usage Guide for Pandas with Apache Arrow + +## Apache Arrow in Spark + +Apache Arrow is an in-memory columnar data format that is used in Spark to efficiently transfer +data between JVM and Python processes. This currently is most beneficial to Python users that +work with Pandas/NumPy data. Its usage is not automatic and might require some minor +changes to configuration or code to take full advantage and ensure compatibility. This guide will +give a high-level description of how to use Arrow in Spark and highlight any differences when +working with Arrow-enabled data. + +### Ensure PyArrow Installed + +If you install PySpark using pip, then PyArrow can be brought in as an extra dependency of the +SQL module with the command `pip install pyspark[sql]`. Otherwise, you must ensure that PyArrow +is installed and available on all cluster nodes. The current supported version is 0.8.0. +You can install using pip or conda from the conda-forge channel. See PyArrow +[installation](https://arrow.apache.org/docs/python/install.html) for details. + +## Enabling for Conversion to/from Pandas + +Arrow is available as an optimization when converting a Spark DataFrame to a Pandas DataFrame +using the call `toPandas()` and when creating a Spark DataFrame from a Pandas DataFrame with +`createDataFrame(pandas_df)`. To use Arrow when executing these calls, users need to first set +the Spark configuration 'spark.sql.execution.arrow.enabled' to 'true'. This is disabled by default. + +
+
+{% include_example dataframe_with_arrow python/sql/arrow.py %} +
+
+ +Using the above optimizations with Arrow will produce the same results as when Arrow is not +enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the +DataFrame to the driver program and should be done on a small subset of the data. Not all Spark +data types are currently supported and an error can be raised if a column has an unsupported type, +see [Supported Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, +Spark will fall back to create the DataFrame without Arrow. + +## Pandas UDFs (a.k.a. Vectorized UDFs) + +Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and +Pandas to work with the data. A Pandas UDF is defined using the keyword `pandas_udf` as a decorator +or to wrap the function, no additional configuration is required. Currently, there are two types of +Pandas UDF: Scalar and Group Map. + +### Scalar + +Scalar Pandas UDFs are used for vectorizing scalar operations. They can be used with functions such +as `select` and `withColumn`. The Python function should take `pandas.Series` as inputs and return +a `pandas.Series` of the same length. Internally, Spark will execute a Pandas UDF by splitting +columns into batches and calling the function for each batch as a subset of the data, then +concatenating the results together. + +The following example shows how to create a scalar Pandas UDF that computes the product of 2 columns. + +
+
+{% include_example scalar_pandas_udf python/sql/arrow.py %} +
+
+ +### Group Map +Group map Pandas UDFs are used with `groupBy().apply()` which implements the "split-apply-combine" pattern. +Split-apply-combine consists of three steps: +* Split the data into groups by using `DataFrame.groupBy`. +* Apply a function on each group. The input and output of the function are both `pandas.DataFrame`. The + input data contains all the rows and columns for each group. +* Combine the results into a new `DataFrame`. + +To use `groupBy().apply()`, the user needs to define the following: +* A Python function that defines the computation for each group. +* A `StructType` object or a string that defines the schema of the output `DataFrame`. + +Note that all data for a group will be loaded into memory before the function is applied. This can +lead to out of memory exceptons, especially if the group sizes are skewed. The configuration for +[maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user +to ensure that the grouped data will fit into the available memory. + +The following example shows how to use `groupby().apply()` to subtract the mean from each value in the group. + +
+
+{% include_example group_map_pandas_udf python/sql/arrow.py %} +
+
+ +For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) and +[`pyspark.sql.GroupedData.apply`](api/python/pyspark.sql.html#pyspark.sql.GroupedData.apply). + +## Usage Notes + +### Supported SQL Types + +Currently, all Spark SQL data types are supported by Arrow-based conversion except `MapType`, +`ArrayType` of `TimestampType`, and nested `StructType`. + +### Setting Arrow Batch Size + +Data partitions in Spark are converted into Arrow record batches, which can temporarily lead to +high memory usage in the JVM. To avoid possible out of memory exceptions, the size of the Arrow +record batches can be adjusted by setting the conf "spark.sql.execution.arrow.maxRecordsPerBatch" +to an integer that will determine the maximum number of rows for each batch. The default value is +10,000 records per batch. If the number of columns is large, the value should be adjusted +accordingly. Using this limit, each data partition will be made into 1 or more record batches for +processing. + +### Timestamp with Time Zone Semantics + +Spark internally stores timestamps as UTC values, and timestamp data that is brought in without +a specified time zone is converted as local time to UTC with microsecond resolution. When timestamp +data is exported or displayed in Spark, the session time zone is used to localize the timestamp +values. The session time zone is set with the configuration 'spark.sql.session.timeZone' and will +default to the JVM system local time zone if not set. Pandas uses a `datetime64` type with nanosecond +resolution, `datetime64[ns]`, with optional time zone on a per-column basis. + +When timestamp data is transferred from Spark to Pandas it will be converted to nanoseconds +and each column will be converted to the Spark session time zone then localized to that time +zone, which removes the time zone and displays values as local time. This will occur +when calling `toPandas()` or `pandas_udf` with timestamp columns. + +When timestamp data is transferred from Pandas to Spark, it will be converted to UTC microseconds. This +occurs when calling `createDataFrame` with a Pandas DataFrame or when returning a timestamp from a +`pandas_udf`. These conversions are done automatically to ensure Spark will have data in the +expected format, so it is not necessary to do any of these conversions yourself. Any nanosecond +values will be truncated. + +Note that a standard UDF (non-Pandas) will load timestamp data as Python datetime objects, which is +different than a Pandas timestamp. It is recommended to use Pandas time series functionality when +working with timestamps in `pandas_udf`s to get the best performance, see +[here](https://pandas.pydata.org/pandas-docs/stable/timeseries.html) for details. + # Migration Guide ## Upgrading From Spark SQL 2.2 to 2.3 @@ -1788,7 +1920,7 @@ options. Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. - - In PySpark, `na.fill()` or `fillna` also accepts boolean and replaces nulls with booleans. In prior Spark versions, PySpark just ignores it and returns the original Dataset/DataFrame. + - In PySpark, `na.fill()` or `fillna` also accepts boolean and replaces nulls with booleans. In prior Spark versions, PySpark just ignores it and returns the original Dataset/DataFrame. - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py new file mode 100644 index 0000000000000..6c0028b3f1c1f --- /dev/null +++ b/examples/src/main/python/sql/arrow.py @@ -0,0 +1,129 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A simple example demonstrating Arrow in Spark. +Run with: + ./bin/spark-submit examples/src/main/python/sql/arrow.py +""" + +from __future__ import print_function + +from pyspark.sql import SparkSession +from pyspark.sql.utils import require_minimum_pandas_version, require_minimum_pyarrow_version + +require_minimum_pandas_version() +require_minimum_pyarrow_version() + + +def dataframe_with_arrow_example(spark): + # $example on:dataframe_with_arrow$ + import numpy as np + import pandas as pd + + # Enable Arrow-based columnar data transfers + spark.conf.set("spark.sql.execution.arrow.enabled", "true") + + # Generate a Pandas DataFrame + pdf = pd.DataFrame(np.random.rand(100, 3)) + + # Create a Spark DataFrame from a Pandas DataFrame using Arrow + df = spark.createDataFrame(pdf) + + # Convert the Spark DataFrame back to a Pandas DataFrame using Arrow + result_pdf = df.select("*").toPandas() + # $example off:dataframe_with_arrow$ + print("Pandas DataFrame result statistics:\n%s\n" % str(result_pdf.describe())) + + +def scalar_pandas_udf_example(spark): + # $example on:scalar_pandas_udf$ + import pandas as pd + + from pyspark.sql.functions import col, pandas_udf + from pyspark.sql.types import LongType + + # Declare the function and create the UDF + def multiply_func(a, b): + return a * b + + multiply = pandas_udf(multiply_func, returnType=LongType()) + + # The function for a pandas_udf should be able to execute with local Pandas data + x = pd.Series([1, 2, 3]) + print(multiply_func(x, x)) + # 0 1 + # 1 4 + # 2 9 + # dtype: int64 + + # Create a Spark DataFrame, 'spark' is an existing SparkSession + df = spark.createDataFrame(pd.DataFrame(x, columns=["x"])) + + # Execute function as a Spark vectorized UDF + df.select(multiply(col("x"), col("x"))).show() + # +-------------------+ + # |multiply_func(x, x)| + # +-------------------+ + # | 1| + # | 4| + # | 9| + # +-------------------+ + # $example off:scalar_pandas_udf$ + + +def group_map_pandas_udf_example(spark): + # $example on:group_map_pandas_udf$ + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = spark.createDataFrame( + [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ("id", "v")) + + @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) + def substract_mean(pdf): + # pdf is a pandas.DataFrame + v = pdf.v + return pdf.assign(v=v - v.mean()) + + df.groupby("id").apply(substract_mean).show() + # +---+----+ + # | id| v| + # +---+----+ + # | 1|-0.5| + # | 1| 0.5| + # | 2|-3.0| + # | 2|-1.0| + # | 2| 4.0| + # +---+----+ + # $example off:group_map_pandas_udf$ + + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("Python Arrow-in-Spark example") \ + .getOrCreate() + + print("Running Pandas to/from conversion example") + dataframe_with_arrow_example(spark) + print("Running pandas_udf scalar example") + scalar_pandas_udf_example(spark) + print("Running pandas_udf group map example") + group_map_pandas_udf_example(spark) + + spark.stop() From e30b34f7bd9a687eb43d636fffeb98fe235fcbf4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 29 Jan 2018 10:29:42 -0800 Subject: [PATCH 230/774] [SPARK-22916][SQL][FOLLOW-UP] Update the Description of Join Selection ## What changes were proposed in this pull request? This PR is to update the description of the join algorithm changes. ## How was this patch tested? N/A Author: gatorsmile Closes #20420 from gatorsmile/followUp22916. --- .../spark/sql/execution/SparkStrategies.scala | 60 +++++++++++++++---- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ce512bc46563a..82b4eb9fba242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -91,23 +91,58 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Select the proper physical plan for join based on joining keys and size of logical plan. * * At first, uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the - * predicates can be evaluated by matching join keys. If found, Join implementations are chosen + * predicates can be evaluated by matching join keys. If found, join implementations are chosen * with the following precedence: * - * - Broadcast: We prefer to broadcast the join side with an explicit broadcast hint(e.g. the - * user applied the [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame). - * If both sides have the broadcast hint, we prefer to broadcast the side with a smaller - * estimated physical size. If neither one of the sides has the broadcast hint, - * we only broadcast the join side if its estimated physical size that is smaller than - * the user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold. + * - Broadcast hash join (BHJ): + * BHJ is not supported for full outer join. For right outer join, we only can broadcast the + * left side. For left outer, left semi, left anti and the internal join type ExistenceJoin, + * we only can broadcast the right side. For inner like join, we can broadcast both sides. + * Normally, BHJ can perform faster than the other join algorithms when the broadcast side is + * small. However, broadcasting tables is a network-intensive operation. It could cause OOM + * or perform worse than the other join algorithms, especially when the build/broadcast side + * is big. + * + * For the supported cases, users can specify the broadcast hint (e.g. the user applied the + * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame) and session-based + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to adjust whether BHJ is used and + * which join side is broadcast. + * + * 1) Broadcast the join side with the broadcast hint, even if the size is larger than + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (only when the type + * is inner like join), the side with a smaller estimated physical size will be broadcast. + * 2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side + * whose estimated physical size is smaller than the threshold. If both sides are below the + * threshold, broadcast the smaller side. If neither is smaller, BHJ is not used. + * * - Shuffle hash join: if the average size of a single partition is small enough to build a hash * table. + * * - Sort merge: if the matching join keys are sortable. * * If there is no joining keys, Join implementations are chosen with the following precedence: - * - BroadcastNestedLoopJoin: if one side of the join could be broadcasted - * - CartesianProduct: for Inner join - * - BroadcastNestedLoopJoin + * - BroadcastNestedLoopJoin (BNLJ): + * BNLJ supports all the join types but the impl is OPTIMIZED for the following scenarios: + * For right outer join, the left side is broadcast. For left outer, left semi, left anti + * and the internal join type ExistenceJoin, the right side is broadcast. For inner like + * joins, either side is broadcast. + * + * Like BHJ, users still can specify the broadcast hint and session-based + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to impact which side is broadcast. + * + * 1) Broadcast the join side with the broadcast hint, even if the size is larger than + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (i.e., just for + * inner-like join), the side with a smaller estimated physical size will be broadcast. + * 2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side + * whose estimated physical size is smaller than the threshold. If both sides are below the + * threshold, broadcast the smaller side. If neither is smaller, BNLJ is not used. + * + * - CartesianProduct: for inner like join, CartesianProduct is the fallback option. + * + * - BroadcastNestedLoopJoin (BNLJ): + * For the other join types, BNLJ is the fallback option. Here, we just pick the broadcast + * side with the broadcast hint. If neither side has a hint, we broadcast the side with + * the smaller estimated physical size. */ object JoinSelection extends Strategy with PredicateHelper { @@ -140,8 +175,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } private def canBuildRight(joinType: JoinType): Boolean = joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti => true - case j: ExistenceJoin => true + case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true case _ => false } @@ -244,7 +278,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- Without joining keys ------------------------------------------------------------ - // Pick BroadcastNestedLoopJoin if one side could be broadcasted + // Pick BroadcastNestedLoopJoin if one side could be broadcast case j @ logical.Join(left, right, joinType, condition) if canBroadcastByHints(joinType, left, right) => val buildSide = broadcastSideByHints(joinType, left, right) From b834446ec1338349f6d974afd96f677db3e8fd1a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 29 Jan 2018 16:09:14 -0600 Subject: [PATCH 231/774] [SPARK-23209][core] Allow credential manager to work when Hive not available. The JVM seems to be doing early binding of classes that the Hive provider depends on, causing an error to be thrown before it was caught by the code in the class. The fix wraps the creation of the provider in a try..catch so that the provider can be ignored when dependencies are missing. Added a unit test (which fails without the fix), and also tested that getting tokens still works in a real cluster. Author: Marcelo Vanzin Closes #20399 from vanzin/SPARK-23209. --- .../HadoopDelegationTokenManager.scala | 17 +++++- .../HadoopDelegationTokenManagerSuite.scala | 58 +++++++++++++++++++ 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 116a686fe1480..5151df00476f9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -64,9 +64,9 @@ private[spark] class HadoopDelegationTokenManager( } private def getDelegationTokenProviders: Map[String, HadoopDelegationTokenProvider] = { - val providers = List(new HadoopFSDelegationTokenProvider(fileSystems), - new HiveDelegationTokenProvider, - new HBaseDelegationTokenProvider) + val providers = Seq(new HadoopFSDelegationTokenProvider(fileSystems)) ++ + safeCreateProvider(new HiveDelegationTokenProvider) ++ + safeCreateProvider(new HBaseDelegationTokenProvider) // Filter out providers for which spark.security.credentials.{service}.enabled is false. providers @@ -75,6 +75,17 @@ private[spark] class HadoopDelegationTokenManager( .toMap } + private def safeCreateProvider( + createFn: => HadoopDelegationTokenProvider): Option[HadoopDelegationTokenProvider] = { + try { + Some(createFn) + } catch { + case t: Throwable => + logDebug(s"Failed to load built in provider.", t) + None + } + } + def isServiceEnabled(serviceName: String): Boolean = { val key = providerEnabledConfig.format(serviceName) diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala index eeffc36070b44..2849a10a2c81e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.security +import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.Credentials @@ -110,7 +111,64 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { creds.getAllTokens.size should be (0) } + test("SPARK-23209: obtain tokens when Hive classes are not available") { + // This test needs a custom class loader to hide Hive classes which are in the classpath. + // Because the manager code loads the Hive provider directly instead of using reflection, we + // need to drive the test through the custom class loader so a new copy that cannot find + // Hive classes is loaded. + val currentLoader = Thread.currentThread().getContextClassLoader() + val noHive = new ClassLoader() { + override def loadClass(name: String, resolve: Boolean): Class[_] = { + if (name.startsWith("org.apache.hive") || name.startsWith("org.apache.hadoop.hive")) { + throw new ClassNotFoundException(name) + } + + if (name.startsWith("java") || name.startsWith("scala")) { + currentLoader.loadClass(name) + } else { + val classFileName = name.replaceAll("\\.", "/") + ".class" + val in = currentLoader.getResourceAsStream(classFileName) + if (in != null) { + val bytes = IOUtils.toByteArray(in) + defineClass(name, bytes, 0, bytes.length) + } else { + throw new ClassNotFoundException(name) + } + } + } + } + + try { + Thread.currentThread().setContextClassLoader(noHive) + val test = noHive.loadClass(NoHiveTest.getClass.getName().stripSuffix("$")) + test.getMethod("runTest").invoke(null) + } finally { + Thread.currentThread().setContextClassLoader(currentLoader) + } + } + private[spark] def hadoopFSsToAccess(hadoopConf: Configuration): Set[FileSystem] = { Set(FileSystem.get(hadoopConf)) } } + +/** Test code for SPARK-23209 to avoid using too much reflection above. */ +private object NoHiveTest extends Matchers { + + def runTest(): Unit = { + try { + val manager = new HadoopDelegationTokenManager(new SparkConf(), new Configuration(), + _ => Set()) + manager.getServiceDelegationTokenProvider("hive") should be (None) + } catch { + case e: Throwable => + // Throw a better exception in case the test fails, since there may be a lot of nesting. + var cause = e + while (cause.getCause() != null) { + cause = cause.getCause() + } + throw cause + } + } + +} From f235df66a4754cbb64d5b7b5cfd5a52bdd243b8a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 29 Jan 2018 17:37:55 -0800 Subject: [PATCH 232/774] [SPARK-22221][SQL][FOLLOWUP] Externalize spark.sql.execution.arrow.maxRecordsPerBatch ## What changes were proposed in this pull request? This is a followup to #19575 which added a section on setting max Arrow record batches and this will externalize the conf that was referenced in the docs. ## How was this patch tested? NA Author: Bryan Cutler Closes #20423 from BryanCutler/arrow-user-doc-externalize-maxRecordsPerBatch-SPARK-22221. --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 61ea03d395afc..54a35594f505e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1051,7 +1051,6 @@ object SQLConf { val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") - .internal() .doc("When using Apache Arrow, limit the maximum number of records that can be written " + "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.") .intConf From 31bd1dab1301d27a16c9d5d1b0b3301d618b0516 Mon Sep 17 00:00:00 2001 From: Paul Mackles Date: Tue, 30 Jan 2018 11:15:27 +0800 Subject: [PATCH 233/774] [SPARK-23088][CORE] History server not showing incomplete/running applications ## What changes were proposed in this pull request? History server not showing incomplete/running applications when spark.history.ui.maxApplications property is set to a value that is smaller than the total number of applications. ## How was this patch tested? Verified manually against master and 2.2.2 branch. Author: Paul Mackles Closes #20335 from pmackles/SPARK-23088. --- .../resources/org/apache/spark/ui/static/historypage.js | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 2cde66b081a1c..f0b2a5a833a99 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -108,7 +108,12 @@ $(document).ready(function() { requestedIncomplete = getParameterByName("showIncomplete", searchString); requestedIncomplete = (requestedIncomplete == "true" ? true : false); - $.getJSON("api/v1/applications?limit=" + appLimit, function(response,status,jqXHR) { + appParams = { + limit: appLimit, + status: (requestedIncomplete ? "running" : "completed") + }; + + $.getJSON("api/v1/applications", appParams, function(response,status,jqXHR) { var array = []; var hasMultipleAttempts = false; for (i in response) { From b375397b1678b7fe20a0b7f87a7e8b37ae5646ef Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 30 Jan 2018 11:40:42 +0800 Subject: [PATCH 234/774] [SPARK-23207][SQL][FOLLOW-UP] Don't perform local sort for DataFrame.repartition(1) ## What changes were proposed in this pull request? In `ShuffleExchangeExec`, we don't need to insert extra local sort before round-robin partitioning, if the new partitioning has only 1 partition, because under that case all output rows go to the same partition. ## How was this patch tested? The existing test cases. Author: Xingbo Jiang Closes #20426 from jiangxb1987/repartition1. --- .../spark/sql/execution/exchange/ShuffleExchangeExec.scala | 4 ++++ .../spark/sql/execution/streaming/ForeachSinkSuite.scala | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 76c1fa65f924b..4d95ee34f30de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -257,7 +257,11 @@ object ShuffleExchangeExec { // // Currently we following the most straight-forward way that perform a local sort before // partitioning. + // + // Note that we don't perform local sort if the new partitioning has only 1 partition, under + // that case all output rows go to the same partition. val newRdd = if (SQLConf.get.sortBeforeRepartition && + newPartitioning.numPartitions > 1 && newPartitioning.isInstanceOf[RoundRobinPartitioning]) { rdd.mapPartitionsInternal { iter => val recordComparatorSupplier = new Supplier[RecordComparator] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 1248c670df45c..41434e6d8b974 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -162,7 +162,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf val allEvents = ForeachSinkSuite.allEvents() assert(allEvents.size === 1) assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 2)) + assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) // `close` should be called with the error val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] From 8b983243e45dfe2617c043a3229a7d87f4c4b44b Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Mon, 29 Jan 2018 22:19:59 -0800 Subject: [PATCH 235/774] [SPARK-23157][SQL] Explain restriction on column expression in withColumn() ## What changes were proposed in this pull request? It's not obvious from the comments that any added column must be a function of the dataset that we are adding it to. Add a comment to that effect to Scala, Python and R Data* methods. Author: Henry Robinson Closes #20429 from henryr/SPARK-23157. --- R/pkg/R/DataFrame.R | 3 ++- python/pyspark/sql/dataframe.py | 4 ++++ sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 3 +++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 29f3e986eaab6..547b5ea48a555 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2090,7 +2090,8 @@ setMethod("selectExpr", #' #' @param x a SparkDataFrame. #' @param colName a column name. -#' @param col a Column expression, or an atomic vector in the length of 1 as literal value. +#' @param col a Column expression (which must refer only to this DataFrame), or an atomic vector in +#' the length of 1 as literal value. #' @return A SparkDataFrame with the new column added or the existing column replaced. #' @family SparkDataFrame functions #' @aliases withColumn,SparkDataFrame,character-method diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ac403080acfdf..055b2c4a0ffec 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1829,11 +1829,15 @@ def withColumn(self, colName, col): Returns a new :class:`DataFrame` by adding a column or replacing the existing column that has the same name. + The column expression must be an expression over this dataframe; attempting to add + a column from some other dataframe will raise an error. + :param colName: string, name of the new column. :param col: a :class:`Column` expression for the new column. >>> df.withColumn('age2', df.age + 2).collect() [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + """ assert isinstance(col, Column), "col should be Column" return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index cc5b647b3f037..d47cd0aecf56a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2150,6 +2150,9 @@ class Dataset[T] private[sql]( * Returns a new Dataset by adding a column or replacing the existing column that has * the same name. * + * `column`'s expression must only refer to attributes supplied by this Dataset. It is an + * error to add a column that refers to some other Dataset. + * * @group untypedrel * @since 2.0.0 */ From 5056877e8bea56dd0f4dc9e3385669e1e78b2925 Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 30 Jan 2018 09:02:16 +0200 Subject: [PATCH 236/774] [SPARK-23138][ML][DOC] Multiclass logistic regression summary example and user guide ## What changes were proposed in this pull request? User guide and examples are updated to reflect multiclass logistic regression summary which was added in [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139). I did not make a separate summary example, but added the summary code to the multiclass example that already existed. I don't see the need for a separate example for the summary. ## How was this patch tested? Docs and examples only. Ran all examples locally using spark-submit. Author: sethah Closes #20332 from sethah/multiclass_summary_example. --- docs/ml-classification-regression.md | 22 +++---- .../JavaLogisticRegressionSummaryExample.java | 17 ++--- ...gisticRegressionWithElasticNetExample.java | 62 +++++++++++++++++++ ...ss_logistic_regression_with_elastic_net.py | 38 ++++++++++++ .../ml/LogisticRegressionSummaryExample.scala | 15 ++--- ...isticRegressionWithElasticNetExample.scala | 43 +++++++++++++ 6 files changed, 164 insertions(+), 33 deletions(-) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index bf979f3c73a52..ddd2f4b49ca07 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -87,7 +87,7 @@ More details on parameters can be found in the [R API documentation](api/R/spark The `spark.ml` implementation of logistic regression also supports extracting a summary of the model over the training set. Note that the predictions and metrics which are stored as `DataFrame` in -`BinaryLogisticRegressionSummary` are annotated `@transient` and hence +`LogisticRegressionSummary` are annotated `@transient` and hence only available on the driver.
@@ -97,10 +97,9 @@ only available on the driver. [`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) provides a summary for a [`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). -This will likely change when multiclass classification is supported. +In the case of binary classification, certain additional metrics are +available, e.g. ROC curve. The binary summary can be accessed via the +`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). Continuing the earlier example: @@ -111,10 +110,9 @@ Continuing the earlier example: [`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) provides a summary for a [`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). -Support for multiclass model summaries will be added in the future. +In the case of binary classification, certain additional metrics are +available, e.g. ROC curve. The binary summary can be accessed via the +`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). Continuing the earlier example: @@ -125,7 +123,8 @@ Continuing the earlier example: [`LogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionSummary) provides a summary for a [`LogisticRegressionModel`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel). -Currently, only binary classification is supported. Support for multiclass model summaries will be added in the future. +In the case of binary classification, certain additional metrics are +available, e.g. ROC curve. See [`BinaryLogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary). Continuing the earlier example: @@ -162,7 +161,8 @@ For a detailed derivation please see [here](https://en.wikipedia.org/wiki/Multin **Examples** The following example shows how to train a multiclass logistic regression -model with elastic net regularization. +model with elastic net regularization, as well as extract the multiclass +training summary for evaluating the model.
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java index dee56799d8aee..1529da16f051f 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java @@ -18,10 +18,9 @@ package org.apache.spark.examples.ml; // $example on$ -import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary; +import org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; -import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; @@ -50,7 +49,7 @@ public static void main(String[] args) { // $example on$ // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier // example - LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); + BinaryLogisticRegressionTrainingSummary trainingSummary = lrModel.binarySummary(); // Obtain the loss per iteration. double[] objectiveHistory = trainingSummary.objectiveHistory(); @@ -58,21 +57,15 @@ public static void main(String[] args) { System.out.println(lossPerIteration); } - // Obtain the metrics useful to judge performance on test data. - // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary - // classification problem. - BinaryLogisticRegressionSummary binarySummary = - (BinaryLogisticRegressionSummary) trainingSummary; - // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. - Dataset roc = binarySummary.roc(); + Dataset roc = trainingSummary.roc(); roc.show(); roc.select("FPR").show(); - System.out.println(binarySummary.areaUnderROC()); + System.out.println(trainingSummary.areaUnderROC()); // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with // this selected threshold. - Dataset fMeasure = binarySummary.fMeasureByThreshold(); + Dataset fMeasure = trainingSummary.fMeasureByThreshold(); double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0); double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)) .select("threshold").head().getDouble(0); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java index da410cba2b3f1..801a82cd2f24f 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java @@ -20,6 +20,7 @@ // $example on$ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; @@ -48,6 +49,67 @@ public static void main(String[] args) { // Print the coefficients and intercept for multinomial logistic regression System.out.println("Coefficients: \n" + lrModel.coefficientMatrix() + " \nIntercept: " + lrModel.interceptVector()); + LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); + + // Obtain the loss per iteration. + double[] objectiveHistory = trainingSummary.objectiveHistory(); + for (double lossPerIteration : objectiveHistory) { + System.out.println(lossPerIteration); + } + + // for multiclass, we can inspect metrics on a per-label basis + System.out.println("False positive rate by label:"); + int i = 0; + double[] fprLabel = trainingSummary.falsePositiveRateByLabel(); + for (double fpr : fprLabel) { + System.out.println("label " + i + ": " + fpr); + i++; + } + + System.out.println("True positive rate by label:"); + i = 0; + double[] tprLabel = trainingSummary.truePositiveRateByLabel(); + for (double tpr : tprLabel) { + System.out.println("label " + i + ": " + tpr); + i++; + } + + System.out.println("Precision by label:"); + i = 0; + double[] precLabel = trainingSummary.precisionByLabel(); + for (double prec : precLabel) { + System.out.println("label " + i + ": " + prec); + i++; + } + + System.out.println("Recall by label:"); + i = 0; + double[] recLabel = trainingSummary.recallByLabel(); + for (double rec : recLabel) { + System.out.println("label " + i + ": " + rec); + i++; + } + + System.out.println("F-measure by label:"); + i = 0; + double[] fLabel = trainingSummary.fMeasureByLabel(); + for (double f : fLabel) { + System.out.println("label " + i + ": " + f); + i++; + } + + double accuracy = trainingSummary.accuracy(); + double falsePositiveRate = trainingSummary.weightedFalsePositiveRate(); + double truePositiveRate = trainingSummary.weightedTruePositiveRate(); + double fMeasure = trainingSummary.weightedFMeasure(); + double precision = trainingSummary.weightedPrecision(); + double recall = trainingSummary.weightedRecall(); + System.out.println("Accuracy: " + accuracy); + System.out.println("FPR: " + falsePositiveRate); + System.out.println("TPR: " + truePositiveRate); + System.out.println("F-measure: " + fMeasure); + System.out.println("Precision: " + precision); + System.out.println("Recall: " + recall); // $example off$ spark.stop(); diff --git a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py index bb9cd82d6ba27..bec9860c79a2d 100644 --- a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py @@ -43,6 +43,44 @@ # Print the coefficients and intercept for multinomial logistic regression print("Coefficients: \n" + str(lrModel.coefficientMatrix)) print("Intercept: " + str(lrModel.interceptVector)) + + trainingSummary = lrModel.summary + + # Obtain the objective per iteration + objectiveHistory = trainingSummary.objectiveHistory + print("objectiveHistory:") + for objective in objectiveHistory: + print(objective) + + # for multiclass, we can inspect metrics on a per-label basis + print("False positive rate by label:") + for i, rate in enumerate(trainingSummary.falsePositiveRateByLabel): + print("label %d: %s" % (i, rate)) + + print("True positive rate by label:") + for i, rate in enumerate(trainingSummary.truePositiveRateByLabel): + print("label %d: %s" % (i, rate)) + + print("Precision by label:") + for i, prec in enumerate(trainingSummary.precisionByLabel): + print("label %d: %s" % (i, prec)) + + print("Recall by label:") + for i, rec in enumerate(trainingSummary.recallByLabel): + print("label %d: %s" % (i, rec)) + + print("F-measure by label:") + for i, f in enumerate(trainingSummary.fMeasureByLabel()): + print("label %d: %s" % (i, f)) + + accuracy = trainingSummary.accuracy + falsePositiveRate = trainingSummary.weightedFalsePositiveRate + truePositiveRate = trainingSummary.weightedTruePositiveRate + fMeasure = trainingSummary.weightedFMeasure() + precision = trainingSummary.weightedPrecision + recall = trainingSummary.weightedRecall + print("Accuracy: %s\nFPR: %s\nTPR: %s\nF-measure: %s\nPrecision: %s\nRecall: %s" + % (accuracy, falsePositiveRate, truePositiveRate, fMeasure, precision, recall)) # $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala index 1740a0d3f9d12..0368dcba460b5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala @@ -19,7 +19,7 @@ package org.apache.spark.examples.ml // $example on$ -import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression} +import org.apache.spark.ml.classification.LogisticRegression // $example off$ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.max @@ -47,25 +47,20 @@ object LogisticRegressionSummaryExample { // $example on$ // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier // example - val trainingSummary = lrModel.summary + val trainingSummary = lrModel.binarySummary // Obtain the objective per iteration. val objectiveHistory = trainingSummary.objectiveHistory println("objectiveHistory:") objectiveHistory.foreach(loss => println(loss)) - // Obtain the metrics useful to judge performance on test data. - // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a - // binary classification problem. - val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary] - // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. - val roc = binarySummary.roc + val roc = trainingSummary.roc roc.show() - println(s"areaUnderROC: ${binarySummary.areaUnderROC}") + println(s"areaUnderROC: ${trainingSummary.areaUnderROC}") // Set the model threshold to maximize F-Measure - val fMeasure = binarySummary.fMeasureByThreshold + val fMeasure = trainingSummary.fMeasureByThreshold val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure) .select("threshold").head().getDouble(0) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala index 3e61dbe628c20..1f7dbddd454e8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala @@ -49,6 +49,49 @@ object MulticlassLogisticRegressionWithElasticNetExample { // Print the coefficients and intercept for multinomial logistic regression println(s"Coefficients: \n${lrModel.coefficientMatrix}") println(s"Intercepts: \n${lrModel.interceptVector}") + + val trainingSummary = lrModel.summary + + // Obtain the objective per iteration + val objectiveHistory = trainingSummary.objectiveHistory + println("objectiveHistory:") + objectiveHistory.foreach(println) + + // for multiclass, we can inspect metrics on a per-label basis + println("False positive rate by label:") + trainingSummary.falsePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) => + println(s"label $label: $rate") + } + + println("True positive rate by label:") + trainingSummary.truePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) => + println(s"label $label: $rate") + } + + println("Precision by label:") + trainingSummary.precisionByLabel.zipWithIndex.foreach { case (prec, label) => + println(s"label $label: $prec") + } + + println("Recall by label:") + trainingSummary.recallByLabel.zipWithIndex.foreach { case (rec, label) => + println(s"label $label: $rec") + } + + + println("F-measure by label:") + trainingSummary.fMeasureByLabel.zipWithIndex.foreach { case (f, label) => + println(s"label $label: $f") + } + + val accuracy = trainingSummary.accuracy + val falsePositiveRate = trainingSummary.weightedFalsePositiveRate + val truePositiveRate = trainingSummary.weightedTruePositiveRate + val fMeasure = trainingSummary.weightedFMeasure + val precision = trainingSummary.weightedPrecision + val recall = trainingSummary.weightedRecall + println(s"Accuracy: $accuracy\nFPR: $falsePositiveRate\nTPR: $truePositiveRate\n" + + s"F-measure: $fMeasure\nPrecision: $precision\nRecall: $recall") // $example off$ spark.stop() From 0a9ac0248b6514a1e83ff7e4c522424f01b8b78d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 30 Jan 2018 19:43:17 +0800 Subject: [PATCH 237/774] [SPARK-23260][SPARK-23262][SQL] several data source v2 naming cleanup ## What changes were proposed in this pull request? All other classes in the reader/writer package doesn't have `V2` in their names, and the streaming reader/writer don't have `V2` either. It's more consistent to remove `V2` from `DataSourceV2Reader` and `DataSourceVWriter`. Also rename `DataSourceV2Option` to remote the `V2`, we should only have `V2` in the root interface: `DataSourceV2`. This PR also fixes some places that the mix-in interface doesn't extend the interface it aimed to mix in. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20427 from cloud-fan/ds-v2. --- .../sql/kafka010/KafkaContinuousReader.scala | 2 +- .../sql/kafka010/KafkaSourceProvider.scala | 6 ++--- ...eV2Options.java => DataSourceOptions.java} | 8 +++---- .../spark/sql/sources/v2/ReadSupport.java | 8 +++---- .../sql/sources/v2/ReadSupportWithSchema.java | 8 +++---- .../sql/sources/v2/SessionConfigSupport.java | 2 +- .../spark/sql/sources/v2/WriteSupport.java | 12 +++++----- .../sources/v2/reader/DataReaderFactory.java | 2 +- ...rceV2Reader.java => DataSourceReader.java} | 11 +++++---- .../SupportsPushDownCatalystFilters.java | 4 ++-- .../v2/reader/SupportsPushDownFilters.java | 4 ++-- .../SupportsPushDownRequiredColumns.java | 6 ++--- .../v2/reader/SupportsReportPartitioning.java | 4 ++-- .../v2/reader/SupportsReportStatistics.java | 4 ++-- .../v2/reader/SupportsScanColumnarBatch.java | 6 ++--- .../v2/reader/SupportsScanUnsafeRow.java | 6 ++--- .../v2/streaming/ContinuousReadSupport.java | 4 ++-- .../v2/streaming/MicroBatchReadSupport.java | 4 ++-- .../v2/streaming/StreamWriteSupport.java | 10 ++++---- .../v2/streaming/reader/ContinuousReader.java | 6 ++--- .../v2/streaming/reader/MicroBatchReader.java | 6 ++--- .../v2/streaming/writer/StreamWriter.java | 6 ++--- ...rceV2Writer.java => DataSourceWriter.java} | 8 +++---- .../sql/sources/v2/writer/DataWriter.java | 12 +++++----- .../sources/v2/writer/DataWriterFactory.java | 2 +- .../v2/writer/SupportsWriteInternalRow.java | 4 ++-- .../v2/writer/WriterCommitMessage.java | 4 ++-- .../apache/spark/sql/DataFrameReader.scala | 2 +- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../v2/DataSourceReaderHolder.scala | 2 +- .../datasources/v2/DataSourceV2Relation.scala | 6 ++--- .../datasources/v2/DataSourceV2ScanExec.scala | 2 +- .../datasources/v2/WriteToDataSourceV2.scala | 4 ++-- .../streaming/MicroBatchExecution.scala | 6 ++--- .../streaming/RateSourceProvider.scala | 2 +- .../sql/execution/streaming/console.scala | 4 ++-- .../continuous/ContinuousExecution.scala | 6 ++--- .../ContinuousRateStreamSource.scala | 7 +++--- .../streaming/sources/ConsoleWriter.scala | 4 ++-- .../streaming/sources/MicroBatchWriter.scala | 8 +++---- .../sources/PackedRowWriterFactory.scala | 4 ++-- .../sources/RateStreamSourceV2.scala | 6 ++--- .../streaming/sources/memoryV2.scala | 6 ++--- .../sql/streaming/DataStreamReader.scala | 4 ++-- .../sources/v2/JavaAdvancedDataSourceV2.java | 6 ++--- .../sql/sources/v2/JavaBatchDataSourceV2.java | 6 ++--- .../v2/JavaPartitionAwareDataSource.java | 6 ++--- .../v2/JavaSchemaRequiredDataSource.java | 8 +++---- .../sources/v2/JavaSimpleDataSourceV2.java | 8 +++---- .../sources/v2/JavaUnsafeRowDataSourceV2.java | 6 ++--- .../streaming/RateSourceV2Suite.scala | 18 +++++++------- ...ite.scala => DataSourceOptionsSuite.scala} | 16 ++++++------- .../sql/sources/v2/DataSourceV2Suite.scala | 24 +++++++++---------- .../sources/v2/SimpleWritableDataSource.scala | 12 +++++----- .../sources/StreamingDataSourceV2Suite.scala | 8 +++---- 55 files changed, 176 insertions(+), 176 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{DataSourceV2Options.java => DataSourceOptions.java} (94%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{DataSourceV2Reader.java => DataSourceReader.java} (91%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/{DataSourceV2Writer.java => DataSourceWriter.java} (96%) rename sql/core/src/test/scala/org/apache/spark/sql/sources/v2/{DataSourceV2OptionsSuite.scala => DataSourceOptionsSuite.scala} (80%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 9125cf5799d74..8c733426b256f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -41,7 +41,7 @@ import org.apache.spark.unsafe.types.UTF8String * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be * read by per-task consumers generated later. * @param kafkaParams String params for per-task Kafka consumers. - * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which + * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceOptions]] params which * are not Kafka consumer params. * @param metadataPath Path to a directory this reader can use for writing metadata. * @param initialOffsets The Kafka offsets to start reading data at. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 2deb7fa2cdf1e..85e96b6783327 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.streaming.OutputMode @@ -109,7 +109,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister override def createContinuousReader( schema: Optional[StructType], metadataPath: String, - options: DataSourceV2Options): KafkaContinuousReader = { + options: DataSourceOptions): KafkaContinuousReader = { val parameters = options.asMap().asScala.toMap validateStreamOptions(parameters) // Each running query should use its own group id. Otherwise, the query may be only assigned @@ -227,7 +227,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): StreamWriter = { + options: DataSourceOptions): StreamWriter = { import scala.collection.JavaConverters._ val spark = SparkSession.getActiveSession.get diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java index ddc2acca693ac..c32053580f016 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java @@ -29,18 +29,18 @@ * data source options. */ @InterfaceStability.Evolving -public class DataSourceV2Options { +public class DataSourceOptions { private final Map keyLowerCasedMap; private String toLowerCase(String key) { return key.toLowerCase(Locale.ROOT); } - public static DataSourceV2Options empty() { - return new DataSourceV2Options(new HashMap<>()); + public static DataSourceOptions empty() { + return new DataSourceOptions(new HashMap<>()); } - public DataSourceV2Options(Map originalMap) { + public DataSourceOptions(Map originalMap) { keyLowerCasedMap = new HashMap<>(originalMap.size()); for (Map.Entry entry : originalMap.entrySet()) { keyLowerCasedMap.put(toLowerCase(entry.getKey()), entry.getValue()); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index 948e20bacf4a2..0ea4dc6b5def3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -18,17 +18,17 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * provide data reading ability and scan the data from the data source. */ @InterfaceStability.Evolving -public interface ReadSupport { +public interface ReadSupport extends DataSourceV2 { /** - * Creates a {@link DataSourceV2Reader} to scan the data from this data source. + * Creates a {@link DataSourceReader} to scan the data from this data source. * * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. @@ -36,5 +36,5 @@ public interface ReadSupport { * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ - DataSourceV2Reader createReader(DataSourceV2Options options); + DataSourceReader createReader(DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index b69c6bed8d1b5..3801402268af1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.types.StructType; /** @@ -30,10 +30,10 @@ * supports both schema inference and user-specified schema. */ @InterfaceStability.Evolving -public interface ReadSupportWithSchema { +public interface ReadSupportWithSchema extends DataSourceV2 { /** - * Create a {@link DataSourceV2Reader} to scan the data from this data source. + * Create a {@link DataSourceReader} to scan the data from this data source. * * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. @@ -45,5 +45,5 @@ public interface ReadSupportWithSchema { * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ - DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options); + DataSourceReader createReader(StructType schema, DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index 3cb020d2e0836..9d66805d79b9e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -25,7 +25,7 @@ * session. */ @InterfaceStability.Evolving -public interface SessionConfigSupport { +public interface SessionConfigSupport extends DataSourceV2 { /** * Key prefix of the session configs to propagate. Spark will extract all session configs that diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index 1e3b644d8c4ae..cab56453816cc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -21,7 +21,7 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.types.StructType; /** @@ -29,17 +29,17 @@ * provide data writing ability and save the data to the data source. */ @InterfaceStability.Evolving -public interface WriteSupport { +public interface WriteSupport extends DataSourceV2 { /** - * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data + * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done according to the save mode. * * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. * * @param jobId A unique string for the writing job. It's possible that there are many writing - * jobs running at the same time, and the returned {@link DataSourceV2Writer} can + * jobs running at the same time, and the returned {@link DataSourceWriter} can * use this job id to distinguish itself from other jobs. * @param schema the schema of the data to be written. * @param mode the save mode which determines what to do when the data are already in this data @@ -47,6 +47,6 @@ public interface WriteSupport { * @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. */ - Optional createWriter( - String jobId, StructType schema, SaveMode mode, DataSourceV2Options options); + Optional createWriter( + String jobId, StructType schema, SaveMode mode, DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java index 077b95b837964..32e98e8f5d8bd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A reader factory returned by {@link DataSourceV2Reader#createDataReaderFactories()} and is + * A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is * responsible for creating the actual data reader. The relationship between * {@link DataReaderFactory} and {@link DataReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java similarity index 91% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index 0180cd9ea47f8..a470bccc5aad2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -21,14 +21,15 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.types.StructType; /** * A data source reader that is returned by - * {@link org.apache.spark.sql.sources.v2.ReadSupport#createReader( - * org.apache.spark.sql.sources.v2.DataSourceV2Options)} or - * {@link org.apache.spark.sql.sources.v2.ReadSupportWithSchema#createReader( - * StructType, org.apache.spark.sql.sources.v2.DataSourceV2Options)}. + * {@link ReadSupport#createReader(DataSourceOptions)} or + * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan * logic is delegated to {@link DataReaderFactory}s that are returned by * {@link #createDataReaderFactories()}. @@ -52,7 +53,7 @@ * issues the scan request and does the actual data reading. */ @InterfaceStability.Evolving -public interface DataSourceV2Reader { +public interface DataSourceReader { /** * Returns the actual schema of this data source reader, which may be different from the physical diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index f76c687f450c8..98224102374aa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to push down arbitrary expressions as predicates to the data source. * This is an experimental and unstable interface as {@link Expression} is not public and may get * changed in the future Spark versions. @@ -31,7 +31,7 @@ * process this interface. */ @InterfaceStability.Unstable -public interface SupportsPushDownCatalystFilters { +public interface SupportsPushDownCatalystFilters extends DataSourceReader { /** * Pushes down filters, and returns unsupported filters. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 6b0c9d417eeae..f35c711b0387a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.sources.Filter; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to push down filters to the data source and reduce the size of the data to be read. * * Note that, if data source readers implement both this interface and @@ -29,7 +29,7 @@ * {@link SupportsPushDownCatalystFilters}. */ @InterfaceStability.Evolving -public interface SupportsPushDownFilters { +public interface SupportsPushDownFilters extends DataSourceReader { /** * Pushes down filters, and returns unsupported filters. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java index fe0ac8ee0ee32..427b4d00a1128 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java @@ -21,12 +21,12 @@ import org.apache.spark.sql.types.StructType; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to push down required columns to the data source and only read these columns during * scan to reduce the size of the data to be read. */ @InterfaceStability.Evolving -public interface SupportsPushDownRequiredColumns { +public interface SupportsPushDownRequiredColumns extends DataSourceReader { /** * Applies column pruning w.r.t. the given requiredSchema. @@ -35,7 +35,7 @@ public interface SupportsPushDownRequiredColumns { * also OK to do the pruning partially, e.g., a data source may not be able to prune nested * fields, and only prune top-level columns. * - * Note that, data source readers should update {@link DataSourceV2Reader#readSchema()} after + * Note that, data source readers should update {@link DataSourceReader#readSchema()} after * applying column pruning. */ void pruneColumns(StructType requiredSchema); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index f786472ccf345..a2383a9d7d680 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -20,11 +20,11 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report data partitioning and try to avoid shuffle at Spark side. */ @InterfaceStability.Evolving -public interface SupportsReportPartitioning { +public interface SupportsReportPartitioning extends DataSourceReader { /** * Returns the output data partitioning that this reader guarantees. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index c019d2f819ab7..11bb13fd3b211 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -20,11 +20,11 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report statistics to Spark. */ @InterfaceStability.Evolving -public interface SupportsReportStatistics { +public interface SupportsReportStatistics extends DataSourceReader { /** * Returns the basic statistics of this data source. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java index 67da55554bbf3..2e5cfa78511f0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -24,11 +24,11 @@ import org.apache.spark.sql.vectorized.ColumnarBatch; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to output {@link ColumnarBatch} and make the scan faster. */ @InterfaceStability.Evolving -public interface SupportsScanColumnarBatch extends DataSourceV2Reader { +public interface SupportsScanColumnarBatch extends DataSourceReader { @Override default List> createDataReaderFactories() { throw new IllegalStateException( @@ -36,7 +36,7 @@ default List> createDataReaderFactories() { } /** - * Similar to {@link DataSourceV2Reader#createDataReaderFactories()}, but returns columnar data + * Similar to {@link DataSourceReader#createDataReaderFactories()}, but returns columnar data * in batches. */ List> createBatchDataReaderFactories(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index 156af69520f77..9cd749e8e4ce9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side. * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get * changed in the future Spark versions. */ @InterfaceStability.Unstable -public interface SupportsScanUnsafeRow extends DataSourceV2Reader { +public interface SupportsScanUnsafeRow extends DataSourceReader { @Override default List> createDataReaderFactories() { @@ -39,7 +39,7 @@ default List> createDataReaderFactories() { } /** - * Similar to {@link DataSourceV2Reader#createDataReaderFactories()}, + * Similar to {@link DataSourceReader#createDataReaderFactories()}, * but returns data in unsafe row format. */ List> createUnsafeRowReaderFactories(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java index 9a93a806b0efc..f79424e036a52 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java @@ -21,7 +21,7 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; import org.apache.spark.sql.types.StructType; @@ -44,5 +44,5 @@ public interface ContinuousReadSupport extends DataSourceV2 { ContinuousReader createContinuousReader( Optional schema, String checkpointLocation, - DataSourceV2Options options); + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java index 3b357c01a29fe..22660e42ad850 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java @@ -20,8 +20,8 @@ import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.streaming.reader.MicroBatchReader; import org.apache.spark.sql.types.StructType; @@ -50,5 +50,5 @@ public interface MicroBatchReadSupport extends DataSourceV2 { MicroBatchReader createMicroBatchReader( Optional schema, String checkpointLocation, - DataSourceV2Options options); + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java index 6cd219c67109a..7c5f304425093 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java @@ -19,10 +19,10 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; @@ -31,7 +31,7 @@ * provide data writing ability for structured streaming. */ @InterfaceStability.Evolving -public interface StreamWriteSupport extends BaseStreamingSink { +public interface StreamWriteSupport extends DataSourceV2, BaseStreamingSink { /** * Creates an optional {@link StreamWriter} to save the data to this data source. Data @@ -39,7 +39,7 @@ public interface StreamWriteSupport extends BaseStreamingSink { * * @param queryId A unique string for the writing query. It's possible that there are many * writing queries running at the same time, and the returned - * {@link DataSourceV2Writer} can use this id to distinguish itself from others. + * {@link DataSourceWriter} can use this id to distinguish itself from others. * @param schema the schema of the data to be written. * @param mode the output mode which determines what successive epoch output means to this * sink, please refer to {@link OutputMode} for more details. @@ -50,5 +50,5 @@ StreamWriter createStreamWriter( String queryId, StructType schema, OutputMode mode, - DataSourceV2Options options); + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java index 3ac979cb0b7b4..6e5177ee83a62 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java @@ -19,12 +19,12 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import java.util.Optional; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to allow reading in a continuous processing mode stream. * * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}. @@ -33,7 +33,7 @@ * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ @InterfaceStability.Evolving -public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reader { +public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { /** * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each * partition to a single global offset. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java index 68887e569fc1d..fcec446d892f5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java @@ -18,20 +18,20 @@ package org.apache.spark.sql.sources.v2.streaming.reader; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import java.util.Optional; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to indicate they allow micro-batch streaming reads. * * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ @InterfaceStability.Evolving -public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSource { +public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource { /** * Set the desired offset range for reader factories created from this reader. Reader factories * will generate only data within (`start`, `end`]; that is, from the first record after `start` diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java index 3156c88933e5e..915ee6c4fb390 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java @@ -18,19 +18,19 @@ package org.apache.spark.sql.sources.v2.streaming.writer; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.sources.v2.writer.DataWriter; import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * A {@link DataSourceV2Writer} for use with structured streaming. This writer handles commits and + * A {@link DataSourceWriter} for use with structured streaming. This writer handles commits and * aborts relative to an epoch ID determined by the execution engine. * * {@link DataWriter} implementations generated by a StreamWriter may be reused for multiple epochs, * and so must reset any internal state after a successful commit. */ @InterfaceStability.Evolving -public interface StreamWriter extends DataSourceV2Writer { +public interface StreamWriter extends DataSourceWriter { /** * Commits this writing job for the specified epoch with a list of commit messages. The commit * messages are collected from successful data writers and are produced by diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 8048f507a1dca..d89d27d0e5b1b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -20,16 +20,16 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.Row; import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.WriteSupport; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; /** * A data source writer that is returned by - * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}/ + * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceOptions)}/ * {@link org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport#createStreamWriter( - * String, StructType, OutputMode, DataSourceV2Options)}. + * String, StructType, OutputMode, DataSourceOptions)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * @@ -52,7 +52,7 @@ * Please refer to the documentation of commit/abort methods for detailed specifications. */ @InterfaceStability.Evolving -public interface DataSourceV2Writer { +public interface DataSourceWriter { /** * Creates a writer factory which will be serialized and sent to executors. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 04b03e63de500..53941a89ba94e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -33,11 +33,11 @@ * * If this data writer succeeds(all records are successfully written and {@link #commit()} * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to - * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} with commit messages from other data + * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark will retry this writing task for some times, * each time {@link DataWriterFactory#createDataWriter(int, int)} gets a different `attemptNumber`, - * and finally call {@link DataSourceV2Writer#abort(WriterCommitMessage[])} if all retry fail. + * and finally call {@link DataSourceWriter#abort(WriterCommitMessage[])} if all retry fail. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task * takes too long to finish. Different from retried tasks, which are launched one by one after the @@ -69,11 +69,11 @@ public interface DataWriter { /** * Commits this writer after all records are written successfully, returns a commit message which * will be sent back to driver side and passed to - * {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * {@link DataSourceWriter#commit(WriterCommitMessage[])}. * * The written data should only be visible to data source readers after - * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} succeeds, which means this method - * should still "hide" the written data and ask the {@link DataSourceV2Writer} at driver side to + * {@link DataSourceWriter#commit(WriterCommitMessage[])} succeeds, which means this method + * should still "hide" the written data and ask the {@link DataSourceWriter} at driver side to * do the final commit via {@link WriterCommitMessage}. * * If this method fails (by throwing an exception), {@link #abort()} will be called and this @@ -91,7 +91,7 @@ public interface DataWriter { * failed. * * If this method fails(by throwing an exception), the underlying data source may have garbage - * that need to be cleaned by {@link DataSourceV2Writer#abort(WriterCommitMessage[])} or manually, + * that need to be cleaned by {@link DataSourceWriter#abort(WriterCommitMessage[])} or manually, * but these garbage should not be visible to data source readers. * * @throws IOException if failure happens during disk/network IO like writing files. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 18ec792f5a2c9..ea95442511ce5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A factory of {@link DataWriter} returned by {@link DataSourceV2Writer#createWriterFactory()}, + * A factory of {@link DataWriter} returned by {@link DataSourceWriter#createWriterFactory()}, * which is responsible for creating and initializing the actual data writer at executor side. * * Note that, the writer factory will be serialized and sent to executors, then the data writer diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java index 3e0518814f458..d2cf7e01c08c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java @@ -22,14 +22,14 @@ import org.apache.spark.sql.catalyst.InternalRow; /** - * A mix-in interface for {@link DataSourceV2Writer}. Data source writers can implement this + * A mix-in interface for {@link DataSourceWriter}. Data source writers can implement this * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side. * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get * changed in the future Spark versions. */ @InterfaceStability.Unstable -public interface SupportsWriteInternalRow extends DataSourceV2Writer { +public interface SupportsWriteInternalRow extends DataSourceWriter { @Override default DataWriterFactory createWriterFactory() { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 082d6b5dc409f..9e38836c0edf9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -23,10 +23,10 @@ /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side - * as the input parameter of {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * as the input parameter of {@link DataSourceWriter#commit(WriterCommitMessage[])}. * * This is an empty interface, data sources should define their own message class and use it in - * their {@link DataWriter#commit()} and {@link DataSourceV2Writer#commit(WriterCommitMessage[])} + * their {@link DataWriter#commit()} and {@link DataSourceWriter#commit(WriterCommitMessage[])} * implementations. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b714a46b5f786..46b5f54a33f74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -186,7 +186,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val ds = cls.newInstance() - val options = new DataSourceV2Options((extraOptions ++ + val options = new DataSourceOptions((extraOptions ++ DataSourceV2Utils.extractSessionConfigs( ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5c02eae05304b..ed7a9100cc7f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -243,7 +243,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val ds = cls.newInstance() ds match { case ws: WriteSupport => - val options = new DataSourceV2Options((extraOptions ++ + val options = new DataSourceOptions((extraOptions ++ DataSourceV2Utils.extractSessionConfigs( ds = ds.asInstanceOf[DataSourceV2], conf = df.sparkSession.sessionState.conf)).asJava) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala index 6093df26630cd..6460c97abe344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -35,7 +35,7 @@ trait DataSourceReaderHolder { /** * The held data source reader. */ - def reader: DataSourceV2Reader + def reader: DataSourceReader /** * The metadata of this data source reader that can be used for equality test. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index cba20dd902007..3d4c64981373d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( fullOutput: Seq[AttributeReference], - reader: DataSourceV2Reader) extends LeafNode with DataSourceReaderHolder { + reader: DataSourceReader) extends LeafNode with DataSourceReaderHolder { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] @@ -41,12 +41,12 @@ case class DataSourceV2Relation( */ class StreamingDataSourceV2Relation( fullOutput: Seq[AttributeReference], - reader: DataSourceV2Reader) extends DataSourceV2Relation(fullOutput, reader) { + reader: DataSourceReader) extends DataSourceV2Relation(fullOutput, reader) { override def isStreaming: Boolean = true } object DataSourceV2Relation { - def apply(reader: DataSourceV2Reader): DataSourceV2Relation = { + def apply(reader: DataSourceReader): DataSourceV2Relation = { new DataSourceV2Relation(reader.readSchema().toAttributes, reader) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 3f808fbb40932..ee085820b0775 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( fullOutput: Seq[AttributeReference], - @transient reader: DataSourceV2Reader) + @transient reader: DataSourceReader) extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index cd6b3e99b6bcb..c544adbf32cdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.Utils /** * The logical plan for writing data into data source v2. */ -case class WriteToDataSourceV2(writer: DataSourceV2Writer, query: LogicalPlan) extends LogicalPlan { +case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } @@ -43,7 +43,7 @@ case class WriteToDataSourceV2(writer: DataSourceV2Writer, query: LogicalPlan) e /** * The physical plan for writing data into data source v2. */ -case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) extends SparkPlan { +case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) extends SparkPlan { override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 975975243a3d1..93572f7a63132 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow @@ -89,7 +89,7 @@ class MicroBatchExecution( val reader = source.createMicroBatchReader( Optional.empty(), // user specified schema metadataPath, - new DataSourceV2Options(options.asJava)) + new DataSourceOptions(options.asJava)) nextSourceId += 1 StreamingExecutionRelation(reader, output)(sparkSession) }) @@ -447,7 +447,7 @@ class MicroBatchExecution( s"$runId", newAttributePlan.schema, outputMode, - new DataSourceV2Options(extraOptions.asJava)) + new DataSourceOptions(extraOptions.asJava)) if (writer.isInstanceOf[SupportsWriteInternalRow]) { WriteToDataSourceV2( new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index 66eb0169ac1ec..5e3fee633f591 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -111,7 +111,7 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister override def createContinuousReader( schema: Optional[StructType], checkpointLocation: String, - options: DataSourceV2Options): ContinuousReader = { + options: DataSourceOptions): ContinuousReader = { new RateStreamContinuousReader(options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index d5ac0bd1df52b..3f5bb489d6528 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.streaming.OutputMode @@ -40,7 +40,7 @@ class ConsoleSinkProvider extends DataSourceV2 queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): StreamWriter = { + options: DataSourceOptions): StreamWriter = { new ConsoleWriter(schema, options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 60f880f9c73b8..9402d7c1dcefd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -160,7 +160,7 @@ class ContinuousExecution( dataSource.createContinuousReader( java.util.Optional.empty[StructType](), metadataPath, - new DataSourceV2Options(extraReaderOptions.asJava)) + new DataSourceOptions(extraReaderOptions.asJava)) } uniqueSources = continuousSources.distinct @@ -198,7 +198,7 @@ class ContinuousExecution( s"$runId", triggerLogicalPlan.schema, outputMode, - new DataSourceV2Options(extraOptions.asJava)) + new DataSourceOptions(extraOptions.asJava)) val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 61304480f4721..ff028ebc4236a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -23,19 +23,18 @@ import org.json4s.DefaultFormats import org.json4s.jackson.Serialization import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} +import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class RateStreamContinuousReader(options: DataSourceV2Options) +class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousReader { implicit val defaultFormats: DefaultFormats = DefaultFormats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index 7c1700f1de48c..d46f4d7b86360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ -class ConsoleWriter(schema: StructType, options: DataSourceV2Options) +class ConsoleWriter(schema: StructType, options: DataSourceOptions) extends StreamWriter with Logging { // Number of rows to display, by default 20 rows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala index d7f3ba8856982..d7ce9a7b84479 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter -import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} /** - * A [[DataSourceV2Writer]] used to hook V2 stream writers into a microbatch plan. It implements + * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped * streaming writer. */ -class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceV2Writer { +class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWriter { override def commit(messages: Array[WriterCommitMessage]): Unit = { writer.commit(batchId, messages) } @@ -38,7 +38,7 @@ class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceV2 } class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter) - extends DataSourceV2Writer with SupportsWriteInternalRow { + extends DataSourceWriter with SupportsWriteInternalRow { override def commit(messages: Array[WriterCommitMessage]): Unit = { writer.commit(batchId, messages) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index 9282ba05bdb7b..248295e401a0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -21,11 +21,11 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.Row -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage} /** * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery - * to a [[org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer]] on the driver. + * to a [[DataSourceWriter]] on the driver. * * Note that, because it sends all rows to the driver, this factory will generally be unsuitable * for production-quality sinks. It's intended for use in tests. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index a25cc4f3b06f8..43949e6180aaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} @@ -44,14 +44,14 @@ class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with override def createMicroBatchReader( schema: Optional[StructType], checkpointLocation: String, - options: DataSourceV2Options): MicroBatchReader = { + options: DataSourceOptions): MicroBatchReader = { new RateStreamMicroBatchReader(options) } override def shortName(): String = "ratev2" } -class RateStreamMicroBatchReader(options: DataSourceV2Options) +class RateStreamMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader { implicit val defaultFormats: DefaultFormats = DefaultFormats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index ce55e44d932bd..58767261dc684 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ @@ -45,7 +45,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): StreamWriter = { + options: DataSourceOptions): StreamWriter = { new MemoryStreamWriter(this, mode) } @@ -114,7 +114,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) - extends DataSourceV2Writer with Logging { + extends DataSourceWriter with Logging { override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 9f5ca9f914284..f1b3f93c4e1fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -158,7 +158,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance() - val options = new DataSourceV2Options(extraOptions.asJava) + val options = new DataSourceOptions(extraOptions.asJava) // We need to generate the V1 data source so we can pass it to the V2 relation as a shim. // We can't be sure at this point whether we'll actually want to use V2, since we don't know the // writer or whether the query is continuous. diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 4026ee44bfdb7..d421f7d19563f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -24,15 +24,15 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader, SupportsPushDownRequiredColumns, + class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, SupportsPushDownFilters { private StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); @@ -131,7 +131,7 @@ public void close() throws IOException { @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index 34e6c63801064..c55093768105b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -21,8 +21,8 @@ import java.util.List; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.DataTypes; @@ -33,7 +33,7 @@ public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader, SupportsScanColumnarBatch { + class Reader implements DataSourceReader, SupportsScanColumnarBatch { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -108,7 +108,7 @@ public void close() throws IOException { @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index d0c87503ab455..99cca0f6dd626 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -23,15 +23,15 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader, SupportsReportPartitioning { + class Reader implements DataSourceReader, SupportsReportPartitioning { private final StructType schema = new StructType().add("a", "int").add("b", "int"); @Override @@ -104,7 +104,7 @@ public DataReader createDataReader() { } @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index f997366af1a64..048d078dfaac4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -20,16 +20,16 @@ import java.util.List; import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; import org.apache.spark.sql.types.StructType; public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { - class Reader implements DataSourceV2Reader { + class Reader implements DataSourceReader { private final StructType schema; Reader(StructType schema) { @@ -48,7 +48,7 @@ public List> createDataReaderFactories() { } @Override - public DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options) { + public DataSourceReader createReader(StructType schema, DataSourceOptions options) { return new Reader(schema); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 2beed431d301f..96f55b8a76811 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -23,16 +23,16 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.DataReader; import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.types.StructType; public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader { + class Reader implements DataSourceReader { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -80,7 +80,7 @@ public void close() throws IOException { } @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index e8187524ea871..c3916e0b370b5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -21,15 +21,15 @@ import java.util.List; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader, SupportsScanUnsafeRow { + class Reader implements DataSourceReader, SupportsScanUnsafeRow { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -83,7 +83,7 @@ public void close() throws IOException { } @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index d2cfe7905f6fa..b060aeeef811d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock @@ -49,7 +49,7 @@ class RateSourceV2Suite extends StreamTest { test("microbatch in registry") { DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceV2Options.empty()) + val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) assert(reader.isInstanceOf[RateStreamMicroBatchReader]) case _ => throw new IllegalStateException("Could not find v2 read support for rate") @@ -76,14 +76,14 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - numPartitions propagated") { val reader = new RateStreamMicroBatchReader( - new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) reader.setOffsetRange(Optional.empty(), Optional.empty()) val tasks = reader.createDataReaderFactories() assert(tasks.size == 11) } test("microbatch - set offset") { - val reader = new RateStreamMicroBatchReader(DataSourceV2Options.empty()) + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) @@ -93,7 +93,7 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - infer offsets") { val reader = new RateStreamMicroBatchReader( - new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) reader.setOffsetRange(Optional.empty(), Optional.empty()) reader.getStartOffset() match { @@ -114,7 +114,7 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - predetermined batch size") { val reader = new RateStreamMicroBatchReader( - new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) @@ -125,7 +125,7 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - data read") { val reader = new RateStreamMicroBatchReader( - new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => @@ -150,7 +150,7 @@ class RateSourceV2Suite extends StreamTest { test("continuous in registry") { DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceV2Options.empty()) + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) assert(reader.isInstanceOf[RateStreamContinuousReader]) case _ => throw new IllegalStateException("Could not find v2 read support for rate") @@ -159,7 +159,7 @@ class RateSourceV2Suite extends StreamTest { test("continuous data") { val reader = new RateStreamContinuousReader( - new DataSourceV2Options(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setOffset(Optional.empty()) val tasks = reader.createDataReaderFactories() assert(tasks.size == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala similarity index 80% rename from sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala index 90d92864b26fa..31dfc55b23361 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala @@ -22,24 +22,24 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite /** - * A simple test suite to verify `DataSourceV2Options`. + * A simple test suite to verify `DataSourceOptions`. */ -class DataSourceV2OptionsSuite extends SparkFunSuite { +class DataSourceOptionsSuite extends SparkFunSuite { test("key is case-insensitive") { - val options = new DataSourceV2Options(Map("foo" -> "bar").asJava) + val options = new DataSourceOptions(Map("foo" -> "bar").asJava) assert(options.get("foo").get() == "bar") assert(options.get("FoO").get() == "bar") assert(!options.get("abc").isPresent) } test("value is case-sensitive") { - val options = new DataSourceV2Options(Map("foo" -> "bAr").asJava) + val options = new DataSourceOptions(Map("foo" -> "bAr").asJava) assert(options.get("foo").get == "bAr") } test("getInt") { - val options = new DataSourceV2Options(Map("numFOo" -> "1", "foo" -> "bar").asJava) + val options = new DataSourceOptions(Map("numFOo" -> "1", "foo" -> "bar").asJava) assert(options.getInt("numFOO", 10) == 1) assert(options.getInt("numFOO2", 10) == 10) @@ -49,7 +49,7 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { } test("getBoolean") { - val options = new DataSourceV2Options( + val options = new DataSourceOptions( Map("isFoo" -> "true", "isFOO2" -> "false", "foo" -> "bar").asJava) assert(options.getBoolean("isFoo", false)) assert(!options.getBoolean("isFoo2", true)) @@ -59,7 +59,7 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { } test("getLong") { - val options = new DataSourceV2Options(Map("numFoo" -> "9223372036854775807", + val options = new DataSourceOptions(Map("numFoo" -> "9223372036854775807", "foo" -> "bar").asJava) assert(options.getLong("numFOO", 0L) == 9223372036854775807L) assert(options.getLong("numFoo2", -1L) == -1L) @@ -70,7 +70,7 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { } test("getDouble") { - val options = new DataSourceV2Options(Map("numFoo" -> "922337.1", + val options = new DataSourceOptions(Map("numFoo" -> "922337.1", "foo" -> "bar").asJava) assert(options.getDouble("numFOO", 0d) == 922337.1d) assert(options.getDouble("numFoo2", -1.02d) == -1.02d) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 42c5d3bcea44b..ee50e8a92270b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -201,7 +201,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader { + class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { @@ -209,7 +209,7 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class SimpleDataReaderFactory(start: Int, end: Int) @@ -233,7 +233,7 @@ class SimpleDataReaderFactory(start: Int, end: Int) class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader + class Reader extends DataSourceReader with SupportsPushDownRequiredColumns with SupportsPushDownFilters { var requiredSchema = new StructType().add("i", "int").add("j", "int") @@ -275,7 +275,7 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType) @@ -306,7 +306,7 @@ class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader with SupportsScanUnsafeRow { + class Reader extends DataSourceReader with SupportsScanUnsafeRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = { @@ -315,7 +315,7 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class UnsafeRowDataReaderFactory(start: Int, end: Int) @@ -343,18 +343,18 @@ class UnsafeRowDataReaderFactory(start: Int, end: Int) class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { - class Reader(val readSchema: StructType) extends DataSourceV2Reader { + class Reader(val readSchema: StructType) extends DataSourceReader { override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = java.util.Collections.emptyList() } - override def createReader(schema: StructType, options: DataSourceV2Options): DataSourceV2Reader = + override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = new Reader(schema) } class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch { + class Reader extends DataSourceReader with SupportsScanColumnarBatch { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = { @@ -362,7 +362,7 @@ class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class BatchDataReaderFactory(start: Int, end: Int) @@ -406,7 +406,7 @@ class BatchDataReaderFactory(start: Int, end: Int) class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader with SupportsReportPartitioning { + class Reader extends DataSourceReader with SupportsReportPartitioning { override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { @@ -428,7 +428,7 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class SpecificDataReaderFactory(i: Array[Int], j: Array[Int]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 3310d6dd199d6..a131b16953e3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceV2Reader} +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -42,7 +42,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS private val schema = new StructType().add("i", "long").add("j", "long") - class Reader(path: String, conf: Configuration) extends DataSourceV2Reader { + class Reader(path: String, conf: Configuration) extends DataSourceReader { override def readSchema(): StructType = schema override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { @@ -64,7 +64,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } - class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceV2Writer { + class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { override def createWriterFactory(): DataWriterFactory[Row] = { new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } @@ -104,7 +104,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = { + override def createReader(options: DataSourceOptions): DataSourceReader = { val path = new Path(options.get("path").get()) val conf = SparkContext.getActive.get.hadoopConfiguration new Reader(path.toUri.toString, conf) @@ -114,7 +114,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS jobId: String, schema: StructType, mode: SaveMode, - options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + options: DataSourceOptions): Optional[DataSourceWriter] = { assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) @@ -141,7 +141,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } private def createWriter( - jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceV2Writer = { + jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceWriter = { val pathStr = path.toUri.toString if (internal) { new InternalRowWriter(jobId, pathStr, conf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index dc8c857018457..3127d664d32dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, Streami import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.DataReaderFactory import org.apache.spark.sql.sources.v2.streaming._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} @@ -54,14 +54,14 @@ trait FakeMicroBatchReadSupport extends MicroBatchReadSupport { override def createMicroBatchReader( schema: Optional[StructType], checkpointLocation: String, - options: DataSourceV2Options): MicroBatchReader = FakeReader() + options: DataSourceOptions): MicroBatchReader = FakeReader() } trait FakeContinuousReadSupport extends ContinuousReadSupport { override def createContinuousReader( schema: Optional[StructType], checkpointLocation: String, - options: DataSourceV2Options): ContinuousReader = FakeReader() + options: DataSourceOptions): ContinuousReader = FakeReader() } trait FakeStreamWriteSupport extends StreamWriteSupport { @@ -69,7 +69,7 @@ trait FakeStreamWriteSupport extends StreamWriteSupport { queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): StreamWriter = { + options: DataSourceOptions): StreamWriter = { throw new IllegalStateException("fake sink - cannot actually write") } } From 7a2ada223e14d09271a76091be0338b2d375081e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 30 Jan 2018 21:55:55 +0900 Subject: [PATCH 238/774] [SPARK-23261][PYSPARK] Rename Pandas UDFs ## What changes were proposed in this pull request? Rename the public APIs and names of pandas udfs. - `PANDAS SCALAR UDF` -> `SCALAR PANDAS UDF` - `PANDAS GROUP MAP UDF` -> `GROUPED MAP PANDAS UDF` - `PANDAS GROUP AGG UDF` -> `GROUPED AGG PANDAS UDF` ## How was this patch tested? The existing tests Author: gatorsmile Closes #20428 from gatorsmile/renamePandasUDFs. --- .../spark/api/python/PythonRunner.scala | 12 +-- docs/sql-programming-guide.md | 8 +- examples/src/main/python/sql/arrow.py | 12 +-- python/pyspark/rdd.py | 6 +- python/pyspark/sql/functions.py | 34 +++---- python/pyspark/sql/group.py | 10 +- python/pyspark/sql/tests.py | 92 +++++++++---------- python/pyspark/sql/udf.py | 25 ++--- python/pyspark/worker.py | 24 ++--- .../sql/catalyst/expressions/PythonUDF.scala | 4 +- .../sql/catalyst/planning/patterns.scala | 1 - .../spark/sql/RelationalGroupedDataset.scala | 4 +- .../python/AggregateInPandasExec.scala | 2 +- .../python/ArrowEvalPythonExec.scala | 2 +- .../execution/python/ExtractPythonUDFs.scala | 2 +- .../python/FlatMapGroupsInPandasExec.scala | 2 +- 16 files changed, 120 insertions(+), 120 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 29148a7ee558b..f075a7e0eb0b4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -37,16 +37,16 @@ private[spark] object PythonEvalType { val SQL_BATCHED_UDF = 100 - val SQL_PANDAS_SCALAR_UDF = 200 - val SQL_PANDAS_GROUP_MAP_UDF = 201 - val SQL_PANDAS_GROUP_AGG_UDF = 202 + val SQL_SCALAR_PANDAS_UDF = 200 + val SQL_GROUPED_MAP_PANDAS_UDF = 201 + val SQL_GROUPED_AGG_PANDAS_UDF = 202 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" case SQL_BATCHED_UDF => "SQL_BATCHED_UDF" - case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF" - case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF" - case SQL_PANDAS_GROUP_AGG_UDF => "SQL_PANDAS_GROUP_AGG_UDF" + case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF" + case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" + case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" } } diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d49c8d869cba6..a0e221b39cc34 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1684,7 +1684,7 @@ Spark will fall back to create the DataFrame without Arrow. Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and Pandas to work with the data. A Pandas UDF is defined using the keyword `pandas_udf` as a decorator or to wrap the function, no additional configuration is required. Currently, there are two types of -Pandas UDF: Scalar and Group Map. +Pandas UDF: Scalar and Grouped Map. ### Scalar @@ -1702,8 +1702,8 @@ The following example shows how to create a scalar Pandas UDF that computes the
-### Group Map -Group map Pandas UDFs are used with `groupBy().apply()` which implements the "split-apply-combine" pattern. +### Grouped Map +Grouped map Pandas UDFs are used with `groupBy().apply()` which implements the "split-apply-combine" pattern. Split-apply-combine consists of three steps: * Split the data into groups by using `DataFrame.groupBy`. * Apply a function on each group. The input and output of the function are both `pandas.DataFrame`. The @@ -1723,7 +1723,7 @@ The following example shows how to use `groupby().apply()` to subtract the mean
-{% include_example group_map_pandas_udf python/sql/arrow.py %} +{% include_example grouped_map_pandas_udf python/sql/arrow.py %}
diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 6c0028b3f1c1f..4c5aefb6ff4a6 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -86,15 +86,15 @@ def multiply_func(a, b): # $example off:scalar_pandas_udf$ -def group_map_pandas_udf_example(spark): - # $example on:group_map_pandas_udf$ +def grouped_map_pandas_udf_example(spark): + # $example on:grouped_map_pandas_udf$ from pyspark.sql.functions import pandas_udf, PandasUDFType df = spark.createDataFrame( [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")) - @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) + @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) def substract_mean(pdf): # pdf is a pandas.DataFrame v = pdf.v @@ -110,7 +110,7 @@ def substract_mean(pdf): # | 2|-1.0| # | 2| 4.0| # +---+----+ - # $example off:group_map_pandas_udf$ + # $example off:grouped_map_pandas_udf$ if __name__ == "__main__": @@ -123,7 +123,7 @@ def substract_mean(pdf): dataframe_with_arrow_example(spark) print("Running pandas_udf scalar example") scalar_pandas_udf_example(spark) - print("Running pandas_udf group map example") - group_map_pandas_udf_example(spark) + print("Running pandas_udf grouped map example") + grouped_map_pandas_udf_example(spark) spark.stop() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6b018c3a38444..93b8974a7e64a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -68,9 +68,9 @@ class PythonEvalType(object): SQL_BATCHED_UDF = 100 - SQL_PANDAS_SCALAR_UDF = 200 - SQL_PANDAS_GROUP_MAP_UDF = 201 - SQL_PANDAS_GROUP_AGG_UDF = 202 + SQL_SCALAR_PANDAS_UDF = 200 + SQL_GROUPED_MAP_PANDAS_UDF = 201 + SQL_GROUPED_AGG_PANDAS_UDF = 202 def portable_hash(x): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a291c9b71913f..3c8fb4c4d19e7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1737,8 +1737,8 @@ def translate(srcCol, matching, replace): def create_map(*cols): """Creates a new map column. - :param cols: list of column names (string) or list of :class:`Column` expressions that grouped - as key-value pairs, e.g. (key1, value1, key2, value2, ...). + :param cols: list of column names (string) or list of :class:`Column` expressions that are + grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...). >>> df.select(create_map('name', 'age').alias("map")).collect() [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})] @@ -2085,11 +2085,11 @@ def map_values(col): class PandasUDFType(object): """Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`. """ - SCALAR = PythonEvalType.SQL_PANDAS_SCALAR_UDF + SCALAR = PythonEvalType.SQL_SCALAR_PANDAS_UDF - GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF + GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF - GROUP_AGG = PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF @since(1.3) @@ -2193,20 +2193,20 @@ def pandas_udf(f=None, returnType=None, functionType=None): Therefore, this can be used, for example, to ensure the length of each returned `pandas.Series`, and can not be used as the column length. - 2. GROUP_MAP + 2. GROUPED_MAP - A group map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` + A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned `pandas.DataFrame`. The length of the returned `pandas.DataFrame` can be arbitrary. - Group map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. + Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) # doctest: +SKIP - >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) @@ -2223,9 +2223,9 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` - 3. GROUP_AGG + 3. GROUPED_AGG - A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar + A grouped aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar The `returnType` should be a primitive data type, e.g., :class:`DoubleType`. The returned scalar can be either a python primitive type, e.g., `int` or `float` or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. @@ -2239,7 +2239,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf("double", PandasUDFType.GROUP_AGG) # doctest: +SKIP + >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def mean_udf(v): ... return v.mean() >>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP @@ -2285,21 +2285,21 @@ def pandas_udf(f=None, returnType=None, functionType=None): eval_type = returnType else: # @pandas_udf(dataType) or @pandas_udf(returnType=dataType) - eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF + eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF else: return_type = returnType if functionType is not None: eval_type = functionType else: - eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF + eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF if return_type is None: raise ValueError("Invalid returnType: returnType can not be None") - if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, - PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF]: + if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]: raise ValueError("Invalid functionType: " "functionType must be one the values from PandasUDFType") diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index f90a909d7c2b1..ab646535c864c 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -98,7 +98,7 @@ def agg(self, *exprs): [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf('int', PandasUDFType.GROUP_AGG) # doctest: +SKIP + >>> @pandas_udf('int', PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def min_udf(v): ... return v.min() >>> sorted(gdf.agg(min_udf(df.age)).collect()) # doctest: +SKIP @@ -235,14 +235,14 @@ def apply(self, udf): into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory. - :param udf: a group map user-defined function returned by + :param udf: a grouped map user-defined function returned by :func:`pyspark.sql.functions.pandas_udf`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) @@ -262,9 +262,9 @@ def apply(self, udf): """ # Columns are special because hasattr always return True if isinstance(udf, Column) or not hasattr(udf, 'func') \ - or udf.evalType != PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + or udf.evalType != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type " - "GROUP_MAP.") + "GROUPED_MAP.") df = self._df udf_column = udf(*[df[col] for col in df.columns]) jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ca7bbf8ffe71c..dc80870d3cd9f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3621,34 +3621,34 @@ def test_pandas_udf_basic(self): udf = pandas_udf(lambda x: x, DoubleType()) self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR) self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) udf = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR) self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) udf = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]), - PandasUDFType.GROUP_MAP) + PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUP_MAP) + udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) udf = pandas_udf(lambda x: x, 'v double', - functionType=PandasUDFType.GROUP_MAP) + functionType=PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) udf = pandas_udf(lambda x: x, returnType='v double', - functionType=PandasUDFType.GROUP_MAP) + functionType=PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) def test_pandas_udf_decorator(self): from pyspark.rdd import PythonEvalType @@ -3659,45 +3659,45 @@ def test_pandas_udf_decorator(self): def foo(x): return x self.assertEqual(foo.returnType, DoubleType()) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) @pandas_udf(returnType=DoubleType()) def foo(x): return x self.assertEqual(foo.returnType, DoubleType()) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) schema = StructType([StructField("v", DoubleType())]) - @pandas_udf(schema, PandasUDFType.GROUP_MAP) + @pandas_udf(schema, PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - @pandas_udf('v double', PandasUDFType.GROUP_MAP) + @pandas_udf('v double', PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - @pandas_udf(schema, functionType=PandasUDFType.GROUP_MAP) + @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) @pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUP_MAP) + @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) def test_udf_wrong_arg(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -3724,15 +3724,15 @@ def zero_with_type(): return 1 with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): - @pandas_udf(returnType=PandasUDFType.GROUP_MAP) + @pandas_udf(returnType=PandasUDFType.GROUPED_MAP) def foo(df): return df with self.assertRaisesRegexp(ValueError, 'Invalid returnType'): - @pandas_udf(returnType='double', functionType=PandasUDFType.GROUP_MAP) + @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP) def foo(df): return df with self.assertRaisesRegexp(ValueError, 'Invalid function'): - @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUP_MAP) + @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP) def foo(k, v): return k @@ -3804,11 +3804,11 @@ def test_register_nondeterministic_vectorized_udf_basic(self): random_pandas_udf = pandas_udf( lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic() self.assertEqual(random_pandas_udf.deterministic, False) - self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) nondeterministic_pandas_udf = self.spark.catalog.registerFunction( "randomPandasUDF", random_pandas_udf) self.assertEqual(nondeterministic_pandas_udf.deterministic, False) - self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect() self.assertEqual(row[0], 7) @@ -4206,7 +4206,7 @@ def test_register_vectorized_udf_basic(self): col('id').cast('int').alias('b')) original_add = pandas_udf(lambda x, y: x + y, IntegerType()) self.assertEqual(original_add.deterministic, True) - self.assertEqual(original_add.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(original_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) new_add = self.spark.catalog.registerFunction("add1", original_add) res1 = df.select(new_add(col('a'), col('b'))) res2 = self.spark.sql( @@ -4237,20 +4237,20 @@ def test_simple(self): StructField('v', IntegerType()), StructField('v1', DoubleType()), StructField('v2', LongType())]), - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertPandasEqual(expected, result) - def test_register_group_map_udf(self): + def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType - foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUP_MAP) + foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP) with QuietTest(self.sc): with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or ' - 'SQL_PANDAS_SCALAR_UDF'): + 'SQL_SCALAR_PANDAS_UDF'): self.spark.catalog.registerFunction("foo_udf", foo_udf) def test_decorator(self): @@ -4259,7 +4259,7 @@ def test_decorator(self): @pandas_udf( 'id long, v int, v1 double, v2 long', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) def foo(pdf): return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) @@ -4275,7 +4275,7 @@ def test_coerce(self): foo = pandas_udf( lambda pdf: pdf, 'id long, v double', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) result = df.groupby('id').apply(foo).sort('id').toPandas() @@ -4289,7 +4289,7 @@ def test_complex_groupby(self): @pandas_udf( 'id long, v int, norm double', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) def normalize(pdf): v = pdf.v @@ -4308,7 +4308,7 @@ def test_empty_groupby(self): @pandas_udf( 'id long, v int, norm double', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) def normalize(pdf): v = pdf.v @@ -4328,7 +4328,7 @@ def test_datatype_string(self): foo_udf = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), 'id long, v int, v1 double, v2 long', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() @@ -4342,7 +4342,7 @@ def test_wrong_return_type(self): foo = pandas_udf( lambda pdf: pdf, 'id long, v map', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) with QuietTest(self.sc): @@ -4368,7 +4368,7 @@ def test_wrong_args(self): with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply( pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) - with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUP_MAP'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'): df.groupby('id').apply( pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]), PandasUDFType.SCALAR)) @@ -4379,7 +4379,7 @@ def test_unsupported_types(self): [StructField("id", LongType(), True), StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(1, None,)], schema=schema) - f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP) + f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUPED_MAP) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.groupby('id').apply(f).collect() @@ -4422,7 +4422,7 @@ def plus_two(v): def pandas_agg_mean_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUP_AGG) + @pandas_udf('double', PandasUDFType.GROUPED_AGG) def avg(v): return v.mean() return avg @@ -4431,7 +4431,7 @@ def avg(v): def pandas_agg_sum_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUP_AGG) + @pandas_udf('double', PandasUDFType.GROUPED_AGG) def sum(v): return v.sum() return sum @@ -4441,7 +4441,7 @@ def pandas_agg_weighted_mean_udf(self): import numpy as np from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUP_AGG) + @pandas_udf('double', PandasUDFType.GROUPED_AGG) def weighted_mean(v, w): return np.average(v, weights=w) return weighted_mean @@ -4505,19 +4505,19 @@ def test_unsupported_types(self): with QuietTest(self.sc): with self.assertRaisesRegex(NotImplementedError, 'not supported'): - @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUP_AGG) + @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return [v.mean(), v.std()] with QuietTest(self.sc): with self.assertRaisesRegex(NotImplementedError, 'not supported'): - @pandas_udf('mean double, std double', PandasUDFType.GROUP_AGG) + @pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return v.mean(), v.std() with QuietTest(self.sc): with self.assertRaisesRegex(NotImplementedError, 'not supported'): - @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUP_AGG) + @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return {v.mean(): v.std()} diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 4f303304e5600..0f759c448b8a7 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -37,9 +37,9 @@ def _wrap_function(sc, func, returnType): def _create_udf(f, returnType, evalType): - if evalType in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, - PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): + if evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): import inspect from pyspark.sql.utils import require_minimum_pyarrow_version @@ -47,16 +47,16 @@ def _create_udf(f, returnType, evalType): require_minimum_pyarrow_version() argspec = inspect.getargspec(f) - if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF and len(argspec.args) == 0 and \ + if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \ argspec.varargs is None: raise ValueError( "Invalid function: 0-arg pandas_udfs are not supported. " "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - if evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF and len(argspec.args) != 1: + if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1: raise ValueError( - "Invalid function: pandas_udfs with function type GROUP_MAP " + "Invalid function: pandas_udfs with function type GROUPED_MAP " "must take a single arg that is a pandas DataFrame." ) @@ -112,14 +112,15 @@ def returnType(self): else: self._returnType_placeholder = _parse_datatype_string(self._returnType) - if self.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF \ + if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \ and not isinstance(self._returnType_placeholder, StructType): raise ValueError("Invalid returnType: returnType must be a StructType for " - "pandas_udf with function type GROUP_MAP") - elif self.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF \ + "pandas_udf with function type GROUPED_MAP") + elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF \ and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)): raise NotImplementedError( - "ArrayType, StructType and MapType are not supported with PandasUDFType.GROUP_AGG") + "ArrayType, StructType and MapType are not supported with " + "PandasUDFType.GROUPED_AGG") return self._returnType_placeholder @@ -292,9 +293,9 @@ def register(self, name, f, returnType=None): "Invalid returnType: data type can not be specified when f is" "a user-defined function, but got %s." % returnType) if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_PANDAS_SCALAR_UDF]: + PythonEvalType.SQL_SCALAR_PANDAS_UDF]: raise ValueError( - "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_SCALAR_PANDAS_UDF") register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, evalType=f.evalType, deterministic=f.deterministic) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 173d8fb2856fa..121b3dd1aeec9 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -74,7 +74,7 @@ def wrap_udf(f, return_type): return lambda *a: f(*a) -def wrap_pandas_scalar_udf(f, return_type): +def wrap_scalar_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) def verify_result_length(*a): @@ -90,7 +90,7 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_pandas_group_map_udf(f, return_type): +def wrap_grouped_map_pandas_udf(f, return_type): def wrapped(*series): import pandas as pd @@ -110,7 +110,7 @@ def wrapped(*series): return wrapped -def wrap_pandas_group_agg_udf(f, return_type): +def wrap_grouped_agg_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) def wrapped(*series): @@ -133,12 +133,12 @@ def read_single_udf(pickleSer, infile, eval_type): row_func = chain(row_func, f) # the last returnType will be the return type of UDF - if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF: - return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) - elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: - return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) - elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF: - return arg_offsets, wrap_pandas_group_agg_udf(row_func, return_type) + if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: + return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: + return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(row_func, return_type) else: @@ -163,9 +163,9 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, - PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): + if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): timezone = utf8_deserializer.loads(infile) ser = ArrowStreamPandasSerializer(timezone) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index 4ba8ff6e3802f..efd664dde725a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.DataType object PythonUDF { private[this] val SCALAR_TYPES = Set( PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_PANDAS_SCALAR_UDF + PythonEvalType.SQL_SCALAR_PANDAS_UDF ) def isScalarPythonUDF(e: Expression): Boolean = { @@ -36,7 +36,7 @@ object PythonUDF { def isGroupAggPandasUDF(e: Expression): Boolean = { e.isInstanceOf[PythonUDF] && - e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 132241061d510..626f905707191 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index d320c1c359411..7147798d99533 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -449,8 +449,8 @@ class RelationalGroupedDataset protected[sql]( * workers. */ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { - require(expr.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, - "Must pass a group map udf") + require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + "Must pass a grouped map udf") require(expr.dataType.isInstanceOf[StructType], "The returnType of the udf must be a StructType") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 18e5f8605c60d..8e01e8e56a5bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -136,7 +136,7 @@ case class AggregateInPandasExec( val columnarBatchIter = new ArrowPythonRunner( pyFuncs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, argOffsets, aggInputSchema, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(projectedRowIter, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 47b146f076b62..c4de214679ae4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -81,7 +81,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, + PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(batchIter, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 4ae4e164830be..9d56f48249982 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -160,7 +160,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } val evaluation = validUdfs.partition( - _.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF + _.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF ) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 59db66bd7adf1..c798fe5a92c54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -96,7 +96,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(grouped, context.partitionId(), context) From 84bcf9dc88ffeae6fba4cfad9455ad75bed6e6f6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 30 Jan 2018 21:00:29 +0800 Subject: [PATCH 239/774] [SPARK-23222][SQL] Make DataFrameRangeSuite not flaky ## What changes were proposed in this pull request? It is reported that the test `Cancelling stage in a query with Range` in `DataFrameRangeSuite` fails a few times in unrelated PRs. I personally also saw it too in my PR. This test is not very flaky actually but only fails occasionally. Based on how the test works, I guess that is because `range` finishes before the listener calls `cancelStage`. I increase the range number from `1000000000L` to `100000000000L` and count the range in one partition. I also reduce the `interval` of checking stage id. Hopefully it can make the test not flaky anymore. ## How was this patch tested? The modified tests. Author: Liang-Chi Hsieh Closes #20431 from viirya/SPARK-23222. --- .../scala/org/apache/spark/sql/DataFrameRangeSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 45afbd29d1907..57a930dfaf320 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -154,7 +154,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall test("Cancelling stage in a query with Range.") { val listener = new SparkListener { override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - eventually(timeout(10.seconds)) { + eventually(timeout(10.seconds), interval(1.millis)) { assert(DataFrameRangeSuite.stageToKill > 0) } sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) @@ -166,7 +166,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { DataFrameRangeSuite.stageToKill = -1 val ex = intercept[SparkException] { - spark.range(1000000000L).map { x => + spark.range(0, 100000000000L, 1, 1).map { x => DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() x }.toDF("id").agg(sum("id")).collect() @@ -184,6 +184,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } } + sparkContext.removeSparkListener(listener) } test("SPARK-20430 Initialize Range parameters in a driver side") { From a23187f53037425c61f1180b5e7990a116f86a42 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 31 Jan 2018 00:51:00 +0900 Subject: [PATCH 240/774] [SPARK-23174][BUILD][PYTHON][FOLLOWUP] Add pycodestyle*.py to .gitignore file. ## What changes were proposed in this pull request? This is a follow-up pr of #20338 which changed the downloaded file name of the python code style checker but it's not contained in .gitignore file so the file remains as an untracked file for git after running the checker. This pr adds the file name to .gitignore file. ## How was this patch tested? Tested manually. Author: Takuya UESHIN Closes #20432 from ueshin/issues/SPARK-23174/fup1. --- dev/.gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/.gitignore b/dev/.gitignore index 4a6027429e0d3..c673922f36d23 100644 --- a/dev/.gitignore +++ b/dev/.gitignore @@ -1 +1,2 @@ pep8*.py +pycodestyle*.py From 31c00ad8b090d7eddc4622e73dc4440cd32624de Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 30 Jan 2018 11:33:30 -0800 Subject: [PATCH 241/774] [SPARK-23267][SQL] Increase spark.sql.codegen.hugeMethodLimit to 65535 ## What changes were proposed in this pull request? Still saw the performance regression introduced by `spark.sql.codegen.hugeMethodLimit` in our internal workloads. There are two major issues in the current solution. - The size of the complied byte code is not identical to the bytecode size of the method. The detection is still not accurate. - The bytecode size of a single operator (e.g., `SerializeFromObject`) could still exceed 8K limit. We saw the performance regression in such scenario. Since it is close to the release of 2.3, we decide to increase it to 64K for avoiding the perf regression. ## How was this patch tested? N/A Author: gatorsmile Closes #20434 from gatorsmile/revertConf. --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 11 ++++++----- .../spark/sql/execution/WholeStageCodegenSuite.scala | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 54a35594f505e..7394a0d7cf983 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -660,12 +660,13 @@ object SQLConf { val WHOLESTAGE_HUGE_METHOD_LIMIT = buildConf("spark.sql.codegen.hugeMethodLimit") .internal() .doc("The maximum bytecode size of a single compiled Java function generated by whole-stage " + - "codegen. When the compiled function exceeds this threshold, " + - "the whole-stage codegen is deactivated for this subtree of the current query plan. " + - s"The default value is ${CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT} and " + - "this is a limit in the OpenJDK JVM implementation.") + "codegen. When the compiled function exceeds this threshold, the whole-stage codegen is " + + "deactivated for this subtree of the current query plan. The default value is 65535, which " + + "is the largest bytecode size possible for a valid Java method. When running on HotSpot, " + + s"it may be preferable to set the value to ${CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT} " + + "to match HotSpot's implementation.") .intConf - .createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT) + .createWithDefault(65535) val WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR = buildConf("spark.sql.codegen.splitConsumeFuncByOperator") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 28ad712feaae6..6e8d5a70d5a8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -202,7 +202,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 } - test("SPARK-21871 check if we can get large code size when compiling too long functions") { + ignore("SPARK-21871 check if we can get large code size when compiling too long functions") { val codeWithShortFunctions = genGroupByCode(3) val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) @@ -211,7 +211,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } - test("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { + ignore("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { import testImplicits._ withTempPath { dir => val path = dir.getCanonicalPath From 58fcb5a95ee0b91300138cd23f3ce2165fab597f Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 30 Jan 2018 14:11:06 -0800 Subject: [PATCH 242/774] [SPARK-23275][SQL] hive/tests have been failing when run locally on the laptop (Mac) with OOM ## What changes were proposed in this pull request? hive tests have been failing when they are run locally (Mac Os) after a recent change in the trunk. After running the tests for some time, the test fails with OOM with Error: unable to create new native thread. I noticed the thread count goes all the way up to 2000+ after which we start getting these OOM errors. Most of the threads seem to be related to the connection pool in hive metastore (BoneCP-xxxxx-xxxx ). This behaviour change is happening after we made the following change to HiveClientImpl.reset() ``` SQL def reset(): Unit = withHiveState { try { // code } finally { runSqlHive("USE default") ===> this is causing the issue } ``` I am proposing to temporarily back-out part of a fix made to address SPARK-23000 to resolve this issue while we work-out the exact reason for this sudden increase in thread counts. ## How was this patch tested? Ran hive/test multiple times in different machines. (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Dilip Biswal Closes #20441 from dilipbiswal/hive_tests. --- .../sql/hive/client/HiveClientImpl.scala | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 39d839059be75..6c0f4144992ae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -825,23 +825,19 @@ private[hive] class HiveClientImpl( } def reset(): Unit = withHiveState { - try { - client.getAllTables("default").asScala.foreach { t => - logDebug(s"Deleting table $t") - val table = client.getTable("default", t) - client.getIndexes("default", t, 255).asScala.foreach { index => - shim.dropIndex(client, "default", t, index.getIndexName) - } - if (!table.isIndexTable) { - client.dropTable("default", t) - } + client.getAllTables("default").asScala.foreach { t => + logDebug(s"Deleting table $t") + val table = client.getTable("default", t) + client.getIndexes("default", t, 255).asScala.foreach { index => + shim.dropIndex(client, "default", t, index.getIndexName) } - client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => - logDebug(s"Dropping Database: $db") - client.dropDatabase(db, true, false, true) + if (!table.isIndexTable) { + client.dropTable("default", t) } - } finally { - runSqlHive("USE default") + } + client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => + logDebug(s"Dropping Database: $db") + client.dropDatabase(db, true, false, true) } } } From 9623a98248837da302ba4ec240335d1c4268ee21 Mon Sep 17 00:00:00 2001 From: Shashwat Anand Date: Wed, 31 Jan 2018 07:37:25 +0900 Subject: [PATCH 243/774] [MINOR] Fix typos in dev/* scripts. ## What changes were proposed in this pull request? Consistency in style, grammar and removal of extraneous characters. ## How was this patch tested? Manually as this is a doc change. Author: Shashwat Anand Closes #20436 from ashashwat/SPARK-23174. --- dev/appveyor-guide.md | 6 +++--- dev/lint-python | 12 ++++++------ dev/run-pip-tests | 4 ++-- dev/run-tests-jenkins | 2 +- dev/sparktestsupport/modules.py | 8 ++++---- dev/sparktestsupport/toposort.py | 6 +++--- dev/tests/pr_merge_ability.sh | 4 ++-- dev/tests/pr_public_classes.sh | 4 ++-- 8 files changed, 23 insertions(+), 23 deletions(-) diff --git a/dev/appveyor-guide.md b/dev/appveyor-guide.md index d2e00b484727d..a842f39b3049a 100644 --- a/dev/appveyor-guide.md +++ b/dev/appveyor-guide.md @@ -1,6 +1,6 @@ # AppVeyor Guides -Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor.com). This page describes how to set up AppVeyor with Spark, how to run the build, check the status and stop the build via this tool. There is the documenation for AppVeyor [here](https://www.appveyor.com/docs). Please refer this for full details. +Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor.com). This page describes how to set up AppVeyor with Spark, how to run the build, check the status and stop the build via this tool. There is the documentation for AppVeyor [here](https://www.appveyor.com/docs). Please refer this for full details. ### Setting up AppVeyor @@ -45,7 +45,7 @@ Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor 2016-08-30 12 16 35 -- Since we will use Github here, click the "GITHUB" button and then click "Authorize Github" so that AppVeyor can access to the Github logs (e.g. commits). +- Since we will use Github here, click the "GITHUB" button and then click "Authorize Github" so that AppVeyor can access the Github logs (e.g. commits). 2016-09-04 11 10 22 @@ -87,7 +87,7 @@ Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor 2016-08-30 12 29 41 -- If the build is running, "CANCEL BUILD" buttom appears. Click this button top cancel the current build. +- If the build is running, "CANCEL BUILD" button appears. Click this button to cancel the current build. 2016-08-30 1 11 13 diff --git a/dev/lint-python b/dev/lint-python index e069cafa1b8c6..f738af9c49763 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -34,8 +34,8 @@ python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYCODESTYLE_REPORT_PATH" compile_status="${PIPESTATUS[0]}" # Get pycodestyle at runtime so that we don't rely on it being installed on the build server. -#+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 -# Updated to latest official version for pep8. pep8 is formally renamed to pycodestyle. +# See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 +# Updated to the latest official version of pep8. pep8 is formally renamed to pycodestyle. PYCODESTYLE_VERSION="2.3.1" PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$PYCODESTYLE_VERSION.py" PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$PYCODESTYLE_VERSION/pycodestyle.py" @@ -60,9 +60,9 @@ export "PYLINT_HOME=$PYTHONPATH" export "PATH=$PYTHONPATH:$PATH" # There is no need to write this output to a file -#+ first, but we do so so that the check status can -#+ be output before the report, like with the -#+ scalastyle and RAT checks. +# first, but we do so so that the check status can +# be output before the report, like with the +# scalastyle and RAT checks. python "$PYCODESTYLE_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PYCODESTYLE_REPORT_PATH" pycodestyle_status="${PIPESTATUS[0]}" @@ -73,7 +73,7 @@ else fi if [ "$lint_status" -ne 0 ]; then - echo "PYCODESTYLE checks failed." + echo "pycodestyle checks failed." cat "$PYCODESTYLE_REPORT_PATH" rm "$PYCODESTYLE_REPORT_PATH" exit "$lint_status" diff --git a/dev/run-pip-tests b/dev/run-pip-tests index d51dde12a03c5..1321c2be4c192 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -25,10 +25,10 @@ shopt -s nullglob FWDIR="$(cd "$(dirname "$0")"/..; pwd)" cd "$FWDIR" -echo "Constucting virtual env for testing" +echo "Constructing virtual env for testing" VIRTUALENV_BASE=$(mktemp -d) -# Clean up the virtual env enviroment used if we created one. +# Clean up the virtual env environment used if we created one. function delete_virtualenv() { echo "Cleaning up temporary directory - $VIRTUALENV_BASE" rm -rf "$VIRTUALENV_BASE" diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 03fd6ff0fba40..5bc03e41d1f2d 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -20,7 +20,7 @@ # Wrapper script that runs the Spark tests then reports QA results # to github via its API. # Environment variables are populated by the code here: -#+ https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 +# https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 FWDIR="$( cd "$( dirname "$0" )/.." && pwd )" cd "$FWDIR" diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index b900f0bd913c3..dfea762db98c6 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -25,10 +25,10 @@ @total_ordering class Module(object): """ - A module is the basic abstraction in our test runner script. Each module consists of a set of - source files, a set of test commands, and a set of dependencies on other modules. We use modules - to define a dependency graph that lets determine which tests to run based on which files have - changed. + A module is the basic abstraction in our test runner script. Each module consists of a set + of source files, a set of test commands, and a set of dependencies on other modules. We use + modules to define a dependency graph that let us determine which tests to run based on which + files have changed. """ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, diff --git a/dev/sparktestsupport/toposort.py b/dev/sparktestsupport/toposort.py index 6c67b4504bc3b..8b2688d20039f 100644 --- a/dev/sparktestsupport/toposort.py +++ b/dev/sparktestsupport/toposort.py @@ -43,8 +43,8 @@ def toposort(data): """Dependencies are expressed as a dictionary whose keys are items and whose values are a set of dependent items. Output is a list of sets in topological order. The first set consists of items with no -dependences, each subsequent set consists of items that depend upon -items in the preceeding sets. +dependencies, each subsequent set consists of items that depend upon +items in the preceding sets. """ # Special case empty input. @@ -59,7 +59,7 @@ def toposort(data): v.discard(k) # Find all items that don't depend on anything. extra_items_in_deps = _reduce(set.union, data.values()) - set(data.keys()) - # Add empty dependences where needed. + # Add empty dependencies where needed. data.update({item: set() for item in extra_items_in_deps}) while True: ordered = set(item for item, dep in data.items() if len(dep) == 0) diff --git a/dev/tests/pr_merge_ability.sh b/dev/tests/pr_merge_ability.sh index d9a347fe24a8c..25fdbccac4dd8 100755 --- a/dev/tests/pr_merge_ability.sh +++ b/dev/tests/pr_merge_ability.sh @@ -23,9 +23,9 @@ # found at dev/run-tests-jenkins. # # Arg1: The Github Pull Request Actual Commit -#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# known as `ghprbActualCommit` in `run-tests-jenkins` # Arg2: The SHA1 hash -#+ known as `sha1` in `run-tests-jenkins` +# known as `sha1` in `run-tests-jenkins` # ghprbActualCommit="$1" diff --git a/dev/tests/pr_public_classes.sh b/dev/tests/pr_public_classes.sh index 41c5d3ee8cb3c..479d1851fe0b8 100755 --- a/dev/tests/pr_public_classes.sh +++ b/dev/tests/pr_public_classes.sh @@ -23,7 +23,7 @@ # found at dev/run-tests-jenkins. # # Arg1: The Github Pull Request Actual Commit -#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# known as `ghprbActualCommit` in `run-tests-jenkins` ghprbActualCommit="$1" @@ -31,7 +31,7 @@ ghprbActualCommit="$1" # master commit and the tip of the pull request branch. # By diffing$ghprbActualCommit^...$ghprbActualCommit and filtering to examine the diffs of only -# non-test files, we can gets us changes introduced in the PR and not anything else added to master +# non-test files, we can get changes introduced in the PR and not anything else added to master # since the PR was branched. # Handle differences between GNU and BSD sed From 77866167330a665e174ae08a2f8902ef9dc3438b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 30 Jan 2018 17:14:17 -0800 Subject: [PATCH 244/774] [SPARK-23276][SQL][TEST] Enable UDT tests in (Hive)OrcHadoopFsRelationSuite ## What changes were proposed in this pull request? Like Parquet, ORC test suites should enable UDT tests. ## How was this patch tested? Pass the Jenkins with newly enabled test cases. Author: Dongjoon Hyun Closes #20440 from dongjoon-hyun/SPARK-23276. --- .../apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index a1f054b8e3f44..3b82a6c458ce4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -34,11 +34,10 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName - // ORC does not play well with NullType and UDT. + // ORC does not play well with NullType. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { case _: NullType => false case _: CalendarIntervalType => false - case _: UserDefinedType[_] => false case _ => true } From ca04c3ff2387bf0a4308a4b010154e6761827278 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 30 Jan 2018 20:05:57 -0800 Subject: [PATCH 245/774] [SPARK-23274][SQL] Fix ReplaceExceptWithFilter when the right's Filter contains the references that are not in the left output ## What changes were proposed in this pull request? This PR is to fix the `ReplaceExceptWithFilter` rule when the right's Filter contains the references that are not in the left output. Before this PR, we got the error like ``` java.util.NoSuchElementException: key not found: a at scala.collection.MapLike$class.default(MapLike.scala:228) at scala.collection.AbstractMap.default(Map.scala:59) at scala.collection.MapLike$class.apply(MapLike.scala:141) at scala.collection.AbstractMap.apply(Map.scala:59) ``` After this PR, `ReplaceExceptWithFilter ` will not take an effect in this case. ## How was this patch tested? Added tests Author: gatorsmile Closes #20444 from gatorsmile/fixReplaceExceptWithFilter. --- .../optimizer/ReplaceExceptWithFilter.scala | 17 +++++++++++++---- .../optimizer/ReplaceOperatorSuite.scala | 15 +++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 8 ++++++++ 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala index 89bfcee078fba..45edf266bbce4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -46,18 +46,27 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { } plan.transform { - case Except(left, right) if isEligible(left, right) => - Distinct(Filter(Not(transformCondition(left, skipProject(right))), left)) + case e @ Except(left, right) if isEligible(left, right) => + val newCondition = transformCondition(left, skipProject(right)) + newCondition.map { c => + Distinct(Filter(Not(c), left)) + }.getOrElse { + e + } } } - private def transformCondition(left: LogicalPlan, right: LogicalPlan): Expression = { + private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = { val filterCondition = InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap - filterCondition.transform { case a : AttributeReference => attributeNameMap(a.name) } + if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) { + Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) }) + } else { + None + } } // TODO: This can be further extended in the future. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index e9701ffd2c54b..52dc2e9fb076c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -168,6 +168,21 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("replace Except with Filter when only right filter can be applied to the left") { + val table = LocalRelation(Seq('a.int, 'b.int)) + val left = table.where('b < 1).select('a).as("left") + val right = table.where('b < 3).select('a).as("right") + + val query = Except(left, right) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(left.output, right.output, + Join(left, right, LeftAnti, Option($"left.a" <=> $"right.a"))).analyze + + comparePlans(optimized, correctAnswer) + } + test("replace Distinct with Aggregate") { val input = LocalRelation('a.int, 'b.int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 33707080c1301..8b66f77b2f923 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -589,6 +589,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Nil) } + test("SPARK-23274: except between two projects without references used in filter") { + val df = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c") + val df1 = df.filter($"a" === 1) + val df2 = df.filter($"a" === 2) + checkAnswer(df1.select("b").except(df2.select("b")), Row(3) :: Nil) + checkAnswer(df1.select("b").except(df2.select("c")), Row(2) :: Nil) + } + test("except distinct - SQL compliance") { val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") val df_right = Seq(1, 3).toDF("id") From 8c6a9c90a36a938372f28ee8be72178192fbc313 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 31 Jan 2018 13:59:21 +0800 Subject: [PATCH 246/774] [SPARK-23279][SS] Avoid triggering distributed job for Console sink ## What changes were proposed in this pull request? Console sink will redistribute collected local data and trigger a distributed job in each batch, this is not necessary, so here change to local job. ## How was this patch tested? Existing UT and manual verification. Author: jerryshao Closes #20447 from jerryshao/console-minor. --- .../spark/sql/execution/streaming/sources/ConsoleWriter.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index d46f4d7b86360..c57bdc4a28905 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.sources +import scala.collection.JavaConverters._ + import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceOptions @@ -61,7 +63,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions) println("-------------------------------------------") // scalastyle:off println spark - .createDataFrame(spark.sparkContext.parallelize(rows), schema) + .createDataFrame(rows.toList.asJava, schema) .show(numRowsToShow, isTruncated) } From 695f7146bca342a0ee192d8c7f5ec48d4d8577a8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 31 Jan 2018 15:13:15 +0800 Subject: [PATCH 247/774] [SPARK-23272][SQL] add calendar interval type support to ColumnVector ## What changes were proposed in this pull request? `ColumnVector` is aimed to support all the data types, but `CalendarIntervalType` is missing. Actually we do support interval type for inner fields, e.g. `ColumnarRow`, `ColumnarArray` both support interval type. It's weird if we don't support interval type at the top level. This PR adds the interval type support. This PR also makes `ColumnVector.getChild` protect. We need it public because `MutableColumnaRow.getInterval` needs it. Now the interval implementation is in `ColumnVector.getInterval`. ## How was this patch tested? a new test. Author: Wenchen Fan Closes #20438 from cloud-fan/interval. --- .../vectorized/MutableColumnarRow.java | 4 +- .../sql/vectorized/ArrowColumnVector.java | 2 +- .../spark/sql/vectorized/ColumnVector.java | 26 ++++++++++- .../spark/sql/vectorized/ColumnarArray.java | 4 +- .../spark/sql/vectorized/ColumnarRow.java | 4 +- .../vectorized/ColumnarBatchSuite.scala | 45 +++++++++++++++++-- 6 files changed, 70 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 2bab095d4d951..66668f3753604 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -146,9 +146,7 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { if (columns[ordinal].isNullAt(rowId)) return null; - final int months = columns[ordinal].getChild(0).getInt(rowId); - final long microseconds = columns[ordinal].getChild(1).getLong(rowId); - return new CalendarInterval(months, microseconds); + return columns[ordinal].getInterval(rowId); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 9803c3dec6de2..a75d76bd0f82e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -28,7 +28,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * A column vector backed by Apache Arrow. Currently time interval type and map type are not + * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not * supported. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 4b955ceddd0f2..111f5d9b358d4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -20,6 +20,7 @@ import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; /** @@ -195,6 +196,7 @@ public double[] getDoubles(int rowId, int count) { * struct field. */ public final ColumnarRow getStruct(int rowId) { + if (isNullAt(rowId)) return null; return new ColumnarRow(this, rowId); } @@ -236,9 +238,29 @@ public MapData getMap(int ordinal) { public abstract byte[] getBinary(int rowId); /** - * Returns the ordinal's child column vector. + * Returns the calendar interval type value for rowId. + * + * In Spark, calendar interval type value is basically an integer value representing the number of + * months in this interval, and a long value representing the number of microseconds in this + * interval. An interval type vector is the same as a struct type vector with 2 fields: `months` + * and `microseconds`. + * + * To support interval type, implementations must implement {@link #getChild(int)} and define 2 + * child vectors: the first child vector is an int type vector, containing all the month values of + * all the interval values in this vector. The second child vector is a long type vector, + * containing all the microsecond values of all the interval values in this vector. + */ + public final CalendarInterval getInterval(int rowId) { + if (isNullAt(rowId)) return null; + final int months = getChild(0).getInt(rowId); + final long microseconds = getChild(1).getLong(rowId); + return new CalendarInterval(months, microseconds); + } + + /** + * @return child [[ColumnVector]] at the given ordinal. */ - public abstract ColumnVector getChild(int ordinal); + protected abstract ColumnVector getChild(int ordinal); /** * Data type for this column. diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 0d2c3ec8648d3..72c07ee7cad3f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -135,9 +135,7 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { - int month = data.getChild(0).getInt(offset + ordinal); - long microseconds = data.getChild(1).getLong(offset + ordinal); - return new CalendarInterval(month, microseconds); + return data.getInterval(offset + ordinal); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 25db7e09d20d0..6ca749d7c6e85 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -139,9 +139,7 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { if (data.getChild(ordinal).isNullAt(rowId)) return null; - final int months = data.getChild(ordinal).getChild(0).getInt(rowId); - final long microseconds = data.getChild(ordinal).getChild(1).getLong(rowId); - return new CalendarInterval(months, microseconds); + return data.getChild(ordinal).getInterval(rowId); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 1873c24ab063c..925c101fe1fee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -620,6 +620,39 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.arrayData().elementsAppended == 0) } + testVector("CalendarInterval APIs", 4, CalendarIntervalType) { + column => + val reference = mutable.ArrayBuffer.empty[CalendarInterval] + + val months = column.getChild(0) + val microseconds = column.getChild(1) + assert(months.dataType() == IntegerType) + assert(microseconds.dataType() == LongType) + + months.putInt(0, 1) + microseconds.putLong(0, 100) + reference += new CalendarInterval(1, 100) + + months.putInt(1, 0) + microseconds.putLong(1, 2000) + reference += new CalendarInterval(0, 2000) + + column.putNull(2) + reference += null + + months.putInt(3, 20) + microseconds.putLong(3, 0) + reference += new CalendarInterval(20, 0) + + reference.zipWithIndex.foreach { case (v, i) => + val errMsg = "VectorType=" + column.getClass.getSimpleName + assert(v == column.getInterval(i), errMsg) + if (v == null) assert(column.isNullAt(i), errMsg) + } + + column.close() + } + testVector("Int Array", 10, new ArrayType(IntegerType, true)) { column => @@ -739,14 +772,20 @@ class ColumnarBatchSuite extends SparkFunSuite { c1.putInt(0, 123) c2.putDouble(0, 3.45) - c1.putInt(1, 456) - c2.putDouble(1, 5.67) + + column.putNull(1) + + c1.putInt(2, 456) + c2.putDouble(2, 5.67) val s = column.getStruct(0) assert(s.getInt(0) == 123) assert(s.getDouble(1) == 3.45) - val s2 = column.getStruct(1) + assert(column.isNullAt(1)) + assert(column.getStruct(1) == null) + + val s2 = column.getStruct(2) assert(s2.getInt(0) == 456) assert(s2.getDouble(1) == 5.67) } From 161a3f2ae324271a601500e3d2900db9359ee2ef Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Wed, 31 Jan 2018 10:37:37 +0200 Subject: [PATCH 248/774] [SPARK-23112][DOC] Update ML migration guide with breaking and behavior changes. Add breaking changes, as well as update behavior changes, to `2.3` ML migration guide. ## How was this patch tested? Doc only Author: Nick Pentreath Closes #20421 from MLnick/SPARK-23112-ml-guide. --- docs/ml-guide.md | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index b957445579ffd..702bcf748fc74 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -108,7 +108,13 @@ and the migration guide below will explain all changes between releases. ### Breaking changes -There are no breaking changes. +* The class and trait hierarchy for logistic regression model summaries was changed to be cleaner +and better accommodate the addition of the multi-class summary. This is a breaking change for user +code that casts a `LogisticRegressionTrainingSummary` to a +` BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` +method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail +(_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which +will still work correctly for both multinomial and binary cases. ### Deprecations and changes of behavior @@ -123,8 +129,19 @@ new [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator) **Changes of behavior** * [SPARK-21027](https://issues.apache.org/jira/browse/SPARK-21027): - We are now setting the default parallelism used in `OneVsRest` to be 1 (i.e. serial). In 2.2 and + The default parallelism used in `OneVsRest` is now set to 1 (i.e. serial). In `2.2` and earlier versions, the level of parallelism was set to the default threadpool size in Scala. +* [SPARK-22156](https://issues.apache.org/jira/browse/SPARK-22156): + The learning rate update for `Word2Vec` was incorrect when `numIterations` was set greater than + `1`. This will cause training results to be different between `2.3` and earlier versions. +* [SPARK-21681](https://issues.apache.org/jira/browse/SPARK-21681): + Fixed an edge case bug in multinomial logistic regression that resulted in incorrect coefficients + when some features had zero variance. +* [SPARK-16957](https://issues.apache.org/jira/browse/SPARK-16957): + Tree algorithms now use mid-points for split values. This may change results from model training. +* [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657): + Fixed an issue where the features generated by `RFormula` without an intercept were inconsistent + with the output in R. This may change results from model training in this scenario. ## Previous Spark versions From 3d0911bbe47f76c341c090edad3737e88a67e3d7 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 31 Jan 2018 20:04:51 +0900 Subject: [PATCH 249/774] [SPARK-23228][PYSPARK] Add Python Created jsparkSession to JVM's defaultSession ## What changes were proposed in this pull request? In the current PySpark code, Python created `jsparkSession` doesn't add to JVM's defaultSession, this `SparkSession` object cannot be fetched from Java side, so the below scala code will be failed when loaded in PySpark application. ```scala class TestSparkSession extends SparkListener with Logging { override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { case CreateTableEvent(db, table) => val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) assert(session.isDefined) val tableInfo = session.get.sharedState.externalCatalog.getTable(db, table) logInfo(s"Table info ${tableInfo}") case e => logInfo(s"event $e") } } } ``` So here propose to add fresh create `jsparkSession` to `defaultSession`. ## How was this patch tested? Manual verification. Author: jerryshao Author: hyukjinkwon Author: Saisai Shao Closes #20404 from jerryshao/SPARK-23228. --- python/pyspark/sql/session.py | 10 +++++++++- python/pyspark/sql/tests.py | 28 +++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 6c84023c43fb6..1ed04298bc899 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -213,7 +213,12 @@ def __init__(self, sparkContext, jsparkSession=None): self._jsc = self._sc._jsc self._jvm = self._sc._jvm if jsparkSession is None: - jsparkSession = self._jvm.SparkSession(self._jsc.sc()) + if self._jvm.SparkSession.getDefaultSession().isDefined() \ + and not self._jvm.SparkSession.getDefaultSession().get() \ + .sparkContext().isStopped(): + jsparkSession = self._jvm.SparkSession.getDefaultSession().get() + else: + jsparkSession = self._jvm.SparkSession(self._jsc.sc()) self._jsparkSession = jsparkSession self._jwrapped = self._jsparkSession.sqlContext() self._wrapped = SQLContext(self._sc, self, self._jwrapped) @@ -225,6 +230,7 @@ def __init__(self, sparkContext, jsparkSession=None): if SparkSession._instantiatedSession is None \ or SparkSession._instantiatedSession._sc._jsc is None: SparkSession._instantiatedSession = self + self._jvm.SparkSession.setDefaultSession(self._jsparkSession) def _repr_html_(self): return """ @@ -759,6 +765,8 @@ def stop(self): """Stop the underlying :class:`SparkContext`. """ self._sc.stop() + # We should clean the default session up. See SPARK-23228. + self._jvm.SparkSession.clearDefaultSession() SparkSession._instantiatedSession = None @since(2.0) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index dc80870d3cd9f..dc26b96334c7a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -69,7 +69,7 @@ from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings from pyspark.sql.types import _merge_type -from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests +from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException @@ -2925,6 +2925,32 @@ def test_sparksession_with_stopped_sparkcontext(self): sc.stop() +class SparkSessionTests(PySparkTestCase): + + # This test is separate because it's closely related with session's start and stop. + # See SPARK-23228. + def test_set_jvm_default_session(self): + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + finally: + spark.stop() + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty()) + + def test_jvm_default_session_already_set(self): + # Here, we assume there is the default session already set in JVM. + jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc()) + self.sc._jvm.SparkSession.setDefaultSession(jsession) + + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + # The session should be the same with the exiting one. + self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get())) + finally: + spark.stop() + + class UDFInitializationTests(unittest.TestCase): def tearDown(self): if SparkSession._instantiatedSession is not None: From 48dd6a4c79e33a8f2dba8349b58aa07e4796a925 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 1 Feb 2018 00:24:42 +0800 Subject: [PATCH 250/774] revert [SPARK-22785][SQL] remove ColumnVector.anyNullsSet ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/19980 , we thought `anyNullsSet` can be simply implemented by `numNulls() > 0`. This is logically true, but may have performance problems. `OrcColumnVector` is an example. It doesn't have the `numNulls` property, only has a `noNulls` property. We will lose a lot of performance if we use `numNulls() > 0` to check null. This PR simply revert #19980, with a renaming to call it `hasNull`. Better name suggestions are welcome, e.g. `nullable`? ## How was this patch tested? existing test Author: Wenchen Fan Closes #20452 from cloud-fan/null. --- .../execution/datasources/orc/OrcColumnVector.java | 5 +++++ .../execution/vectorized/OffHeapColumnVector.java | 2 +- .../sql/execution/vectorized/OnHeapColumnVector.java | 2 +- .../execution/vectorized/WritableColumnVector.java | 7 ++++++- .../spark/sql/vectorized/ArrowColumnVector.java | 5 +++++ .../apache/spark/sql/vectorized/ColumnVector.java | 5 +++++ .../vectorized/ArrowColumnVectorSuite.scala | 12 ++++++++++++ .../execution/vectorized/ColumnarBatchSuite.scala | 9 +++++++++ 8 files changed, 44 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 5078bc7922ee2..78203e3145c62 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -77,6 +77,11 @@ public void close() { } + @Override + public boolean hasNull() { + return !baseData.noNulls; + } + @Override public int numNulls() { if (baseData.isRepeating) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 1c45b846790b6..fa52e4a354786 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -123,7 +123,7 @@ public void putNulls(int rowId, int count) { @Override public void putNotNulls(int rowId, int count) { - if (numNulls == 0) return; + if (!hasNull()) return; long offset = nulls + rowId; for (int i = 0; i < count; ++i, ++offset) { Platform.putByte(null, offset, (byte) 0); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 1d538fe4181b7..cccef78aebdc8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -119,7 +119,7 @@ public void putNulls(int rowId, int count) { @Override public void putNotNulls(int rowId, int count) { - if (numNulls == 0) return; + if (!hasNull()) return; for (int i = 0; i < count; ++i) { nulls[rowId + i] = (byte)0; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index a8ec8ef2aadf8..8ebc1adf59c8b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -59,8 +59,8 @@ public void reset() { elementsAppended = 0; if (numNulls > 0) { putNotNulls(0, capacity); + numNulls = 0; } - numNulls = 0; } @Override @@ -102,6 +102,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { throw new RuntimeException(message, cause); } + @Override + public boolean hasNull() { + return numNulls > 0; + } + @Override public int numNulls() { return numNulls; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index a75d76bd0f82e..5ff6474c161f3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -37,6 +37,11 @@ public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; private ArrowColumnVector[] childColumns; + @Override + public boolean hasNull() { + return accessor.getNullCount() > 0; + } + @Override public int numNulls() { return accessor.getNullCount(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 111f5d9b358d4..d588956208047 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -65,6 +65,11 @@ public abstract class ColumnVector implements AutoCloseable { @Override public abstract void close(); + /** + * Returns true if this column vector contains any null values. + */ + public abstract boolean hasNull(); + /** * Returns the number of nulls in this column vector. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index e794f50781ff2..b55489cb2678a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -42,6 +42,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BooleanType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -69,6 +70,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ByteType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -96,6 +98,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ShortType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -123,6 +126,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === IntegerType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -150,6 +154,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === LongType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -177,6 +182,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === FloatType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -204,6 +210,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === DoubleType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -232,6 +239,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === StringType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -258,6 +266,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BinaryType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -300,6 +309,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ArrayType(IntegerType)) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) val array0 = columnVector.getArray(0) @@ -344,6 +354,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === schema) + assert(!columnVector.hasNull) assert(columnVector.numNulls === 0) val row0 = columnVector.getStruct(0) @@ -396,6 +407,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === schema) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) val row0 = columnVector.getStruct(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 925c101fe1fee..168bc5e3e480b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -66,22 +66,27 @@ class ColumnarBatchSuite extends SparkFunSuite { column => val reference = mutable.ArrayBuffer.empty[Boolean] var idx = 0 + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNotNull() reference += false + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNotNulls(3) (1 to 3).foreach(_ => reference += false) + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNull() reference += true + assert(column.hasNull) assert(column.numNulls() == 1) column.appendNulls(3) (1 to 3).foreach(_ => reference += true) + assert(column.hasNull) assert(column.numNulls() == 4) idx = column.elementsAppended @@ -89,11 +94,13 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putNotNull(idx) reference += false idx += 1 + assert(column.hasNull) assert(column.numNulls() == 4) column.putNull(idx) reference += true idx += 1 + assert(column.hasNull) assert(column.numNulls() == 5) column.putNulls(idx, 3) @@ -101,6 +108,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += true reference += true idx += 3 + assert(column.hasNull) assert(column.numNulls() == 8) column.putNotNulls(idx, 4) @@ -109,6 +117,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += false reference += false idx += 4 + assert(column.hasNull) assert(column.numNulls() == 8) reference.zipWithIndex.foreach { v => From 8c21170decfb9ca4d3233e1ea13bd1b6e3199ed9 Mon Sep 17 00:00:00 2001 From: Glen Takahashi Date: Thu, 1 Feb 2018 01:14:01 +0800 Subject: [PATCH 251/774] [SPARK-23249][SQL] Improved block merging logic for partitions ## What changes were proposed in this pull request? Change DataSourceScanExec so that when grouping blocks together into partitions, also checks the end of the sorted list of splits to more efficiently fill out partitions. ## How was this patch tested? Updated old test to reflect the new logic, which causes the # of partitions to drop from 4 -> 3 Also, a current test exists to test large non-splittable files at https://github.com/glentakahashi/spark/blob/c575977a5952bf50b605be8079c9be1e30f3bd36/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala#L346 ## Rationale The current bin-packing method of next-fit descending for blocks into partitions is sub-optimal in a lot of cases and will result in extra partitions, un-even distribution of block-counts across partitions, and un-even distribution of partition sizes. As an example, 128 files ranging from 1MB, 2MB,...127MB,128MB. will result in 82 partitions with the current algorithm, but only 64 using this algorithm. Also in this example, the max # of blocks per partition in NFD is 13, while in this algorithm is is 2. More generally, running a simulation of 1000 runs using 128MB blocksize, between 1-1000 normally distributed file sizes between 1-500Mb, you can see an improvement of approx 5% reduction of partition counts, and a large reduction in standard deviation of blocks per partition. This algorithm also runs in O(n) time as NFD does, and in every case is strictly better results than NFD. Overall, the more even distribution of blocks across partitions and therefore reduced partition counts should result in a small but significant performance increase across the board Author: Glen Takahashi Closes #20372 from glentakahashi/feature/improved-block-merging. --- .../sql/execution/DataSourceScanExec.scala | 29 ++++++++++++++----- .../datasources/FileSourceStrategySuite.scala | 15 ++++------ 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index aa66ee7e948ea..f7732e2098c29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -445,16 +445,29 @@ case class FileSourceScanExec( currentSize = 0 } - // Assign files to partitions using "Next Fit Decreasing" - splitFiles.foreach { file => - if (currentSize + file.length > maxSplitBytes) { - closePartition() + def addFile(file: PartitionedFile): Unit = { + currentFiles += file + currentSize += file.length + openCostInBytes + } + + var frontIndex = 0 + var backIndex = splitFiles.length - 1 + + while (frontIndex <= backIndex) { + addFile(splitFiles(frontIndex)) + frontIndex += 1 + while (frontIndex <= backIndex && + currentSize + splitFiles(frontIndex).length <= maxSplitBytes) { + addFile(splitFiles(frontIndex)) + frontIndex += 1 + } + while (backIndex > frontIndex && + currentSize + splitFiles(backIndex).length <= maxSplitBytes) { + addFile(splitFiles(backIndex)) + backIndex -= 1 } - // Add the given file to the current partition. - currentSize += file.length + openCostInBytes - currentFiles += file + closePartition() } - closePartition() new FileScanRDD(fsRelation.sparkSession, readFile, partitions) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c1d61b843d899..bfccc9335b361 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -141,16 +141,17 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "4", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { checkScan(table.select('c1)) { partitions => - // Files should be laid out [(file1), (file2, file3), (file4, file5), (file6)] - assert(partitions.size == 4, "when checking partitions") - assert(partitions(0).files.size == 1, "when checking partition 1") + // Files should be laid out [(file1, file6), (file2, file3), (file4, file5)] + assert(partitions.size == 3, "when checking partitions") + assert(partitions(0).files.size == 2, "when checking partition 1") assert(partitions(1).files.size == 2, "when checking partition 2") assert(partitions(2).files.size == 2, "when checking partition 3") - assert(partitions(3).files.size == 1, "when checking partition 4") - // First partition reads (file1) + // First partition reads (file1, file6) assert(partitions(0).files(0).start == 0) assert(partitions(0).files(0).length == 2) + assert(partitions(0).files(1).start == 0) + assert(partitions(0).files(1).length == 1) // Second partition reads (file2, file3) assert(partitions(1).files(0).start == 0) @@ -163,10 +164,6 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi assert(partitions(2).files(0).length == 1) assert(partitions(2).files(1).start == 0) assert(partitions(2).files(1).length == 1) - - // Final partition reads (file6) - assert(partitions(3).files(0).start == 0) - assert(partitions(3).files(0).length == 1) } checkPartitionSchema(StructType(Nil)) From dd242bad39cc6df7ff6c6b16642bdc92dccca6ac Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 31 Jan 2018 11:48:19 -0800 Subject: [PATCH 252/774] [SPARK-21525][STREAMING] Check error code from supervisor RPC. The code was ignoring the error code from the AddBlock RPC, which means that a failure to write to the WAL was being ignored by the receiver, and would lead to the block being acked (in the case of the Flume receiver) and data potentially lost. Author: Marcelo Vanzin Closes #20161 from vanzin/SPARK-21525. --- .../spark/streaming/receiver/ReceiverSupervisorImpl.scala | 4 +++- .../apache/spark/streaming/scheduler/ReceiverTracker.scala | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 27644a645727c..5d38c56aa5873 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -159,7 +159,9 @@ private[streaming] class ReceiverSupervisorImpl( logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") val numRecords = blockStoreResult.numRecords val blockInfo = ReceivedBlockInfo(streamId, numRecords, metadataOption, blockStoreResult) - trackerEndpoint.askSync[Boolean](AddBlock(blockInfo)) + if (!trackerEndpoint.askSync[Boolean](AddBlock(blockInfo))) { + throw new SparkException("Failed to add block to receiver tracker.") + } logDebug(s"Reported block $blockId") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 6f130c803f310..c74ca1918a81d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -521,7 +521,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (active) { context.reply(addBlock(receivedBlockInfo)) } else { - throw new IllegalStateException("ReceiverTracker RpcEndpoint shut down.") + context.sendFailure( + new IllegalStateException("ReceiverTracker RpcEndpoint already shut down.")) } } }) From 9ff1d96f01e2c89acfd248db917e068b93f519a6 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 31 Jan 2018 13:52:47 -0800 Subject: [PATCH 253/774] [SPARK-23281][SQL] Query produces results in incorrect order when a composite order by clause refers to both original columns and aliases ## What changes were proposed in this pull request? Here is the test snippet. ``` SQL scala> Seq[(Integer, Integer)]( | (1, 1), | (1, 3), | (2, 3), | (3, 3), | (4, null), | (5, null) | ).toDF("key", "value").createOrReplaceTempView("src") scala> sql( | """ | |SELECT MAX(value) as value, key as col2 | |FROM src | |GROUP BY key | |ORDER BY value desc, key | """.stripMargin).show +-----+----+ |value|col2| +-----+----+ | 3| 3| | 3| 2| | 3| 1| | null| 5| | null| 4| +-----+----+ ```SQL Here is the explain output : ```SQL == Parsed Logical Plan == 'Sort ['value DESC NULLS LAST, 'key ASC NULLS FIRST], true +- 'Aggregate ['key], ['MAX('value) AS value#9, 'key AS col2#10] +- 'UnresolvedRelation `src` == Analyzed Logical Plan == value: int, col2: int Project [value#9, col2#10] +- Sort [value#9 DESC NULLS LAST, col2#10 DESC NULLS LAST], true +- Aggregate [key#5], [max(value#6) AS value#9, key#5 AS col2#10] +- SubqueryAlias src +- Project [_1#2 AS key#5, _2#3 AS value#6] +- LocalRelation [_1#2, _2#3] ``` SQL The sort direction is being wrongly changed from ASC to DSC while resolving ```Sort``` in resolveAggregateFunctions. The above testcase models TPCDS-Q71 and thus we have the same issue in Q71 as well. ## How was this patch tested? A few tests are added in SQLQuerySuite. Author: Dilip Biswal Closes #20453 from dilipbiswal/local_spark. --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 41 ++++++++++++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 91cb0365a0856..251099f750cf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1493,7 +1493,7 @@ class Analyzer( // to push down this ordering expression and can reference the original aggregate // expression instead. val needsPushDown = ArrayBuffer.empty[NamedExpression] - val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map { + val evaluatedOrderings = resolvedAliasedOrdering.zip(unresolvedSortOrders).map { case (evaluated, order) => val index = originalAggExprs.indexWhere { case Alias(child, _) => child semanticEquals evaluated.child diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ffd736d2ebbb6..8f14575c3325f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import java.io.File -import java.math.MathContext import java.net.{MalformedURLException, URL} import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean @@ -1618,6 +1617,46 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23281: verify the correctness of sort direction on composite order by clause") { + withTempView("src") { + Seq[(Integer, Integer)]( + (1, 1), + (1, 3), + (2, 3), + (3, 3), + (4, null), + (5, null) + ).toDF("key", "value").createOrReplaceTempView("src") + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value desc, key + """.stripMargin), + Seq(Row(3, 1), Row(3, 2), Row(3, 3), Row(null, 4), Row(null, 5))) + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value desc, key desc + """.stripMargin), + Seq(Row(3, 3), Row(3, 2), Row(3, 1), Row(null, 5), Row(null, 4))) + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value asc, key desc + """.stripMargin), + Seq(Row(null, 5), Row(null, 4), Row(3, 3), Row(3, 2), Row(3, 1))) + } + } + test("run sql directly on files") { val df = spark.range(100).toDF() withTempPath(f => { From f470df2fcf14e6234c577dc1bdfac27d49b441f5 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Thu, 1 Feb 2018 11:15:17 +0900 Subject: [PATCH 254/774] [SPARK-23157][SQL][FOLLOW-UP] DataFrame -> SparkDataFrame in R comment Author: Henry Robinson Closes #20443 from henryr/SPARK-23157. --- R/pkg/R/DataFrame.R | 4 ++-- python/pyspark/sql/dataframe.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 547b5ea48a555..41c3c3a89fa72 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2090,8 +2090,8 @@ setMethod("selectExpr", #' #' @param x a SparkDataFrame. #' @param colName a column name. -#' @param col a Column expression (which must refer only to this DataFrame), or an atomic vector in -#' the length of 1 as literal value. +#' @param col a Column expression (which must refer only to this SparkDataFrame), or an atomic +#' vector in the length of 1 as literal value. #' @return A SparkDataFrame with the new column added or the existing column replaced. #' @family SparkDataFrame functions #' @aliases withColumn,SparkDataFrame,character-method diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 055b2c4a0ffec..1496cba91b90e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1829,7 +1829,7 @@ def withColumn(self, colName, col): Returns a new :class:`DataFrame` by adding a column or replacing the existing column that has the same name. - The column expression must be an expression over this dataframe; attempting to add + The column expression must be an expression over this DataFrame; attempting to add a column from some other dataframe will raise an error. :param colName: string, name of the new column. From 52e00f70663a87b5837235bdf72a3e6f84e11411 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 1 Feb 2018 11:56:06 +0800 Subject: [PATCH 255/774] [SPARK-23280][SQL] add map type support to ColumnVector ## What changes were proposed in this pull request? Fill the last missing piece of `ColumnVector`: the map type support. The idea is similar to the array type support. A map is basically 2 arrays: keys and values. We ask the implementations to provide a key array, a value array, and an offset and length to specify the range of this map in the key/value array. In `WritableColumnVector`, we put the key array in first child vector, and value array in second child vector, and offsets and lengths in the current vector, which is very similar to how array type is implemented here. ## How was this patch tested? a new test Author: Wenchen Fan Closes #20450 from cloud-fan/map. --- .../datasources/orc/OrcColumnVector.java | 6 ++ .../vectorized/ColumnVectorUtils.java | 15 ++++ .../vectorized/OffHeapColumnVector.java | 4 +- .../vectorized/OnHeapColumnVector.java | 4 +- .../vectorized/WritableColumnVector.java | 13 ++++ .../sql/vectorized/ArrowColumnVector.java | 5 ++ .../spark/sql/vectorized/ColumnVector.java | 14 +++- .../spark/sql/vectorized/ColumnarArray.java | 4 +- .../spark/sql/vectorized/ColumnarMap.java | 53 ++++++++++++++ .../spark/sql/vectorized/ColumnarRow.java | 5 +- .../vectorized/ColumnarBatchSuite.scala | 70 ++++++++++++++----- 11 files changed, 166 insertions(+), 27 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 78203e3145c62..c8add4c9f486c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.types.TimestampType; import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.UTF8String; /** @@ -177,6 +178,11 @@ public ColumnarArray getArray(int rowId) { throw new UnsupportedOperationException(); } + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + @Override public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { throw new UnsupportedOperationException(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index a2853bbadc92b..829f3ce750fe6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -20,8 +20,10 @@ import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.sql.Date; +import java.util.HashMap; import java.util.Iterator; import java.util.List; +import java.util.Map; import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; @@ -30,6 +32,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -109,6 +112,18 @@ public static int[] toJavaIntArray(ColumnarArray array) { return array.toIntArray(); } + public static Map toJavaIntMap(ColumnarMap map) { + int[] keys = toJavaIntArray(map.keyArray()); + int[] values = toJavaIntArray(map.valueArray()); + assert keys.length == values.length; + + Map result = new HashMap<>(); + for (int i = 0; i < keys.length; i++) { + result.put(keys[i], values[i]); + } + return result; + } + private static void appendValue(WritableColumnVector dst, DataType t, Object o) { if (o == null) { if (t instanceof CalendarIntervalType) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index fa52e4a354786..754c26579ff08 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -60,7 +60,7 @@ public static OffHeapColumnVector[] allocateColumns(int capacity, StructField[] private long nulls; private long data; - // Set iff the type is array. + // Only set if type is Array or Map. private long lengthData; private long offsetData; @@ -530,7 +530,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { @Override protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 0 : capacity; - if (isArray()) { + if (isArray() || type instanceof MapType) { this.lengthData = Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); this.offsetData = diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index cccef78aebdc8..23dcc104e67c4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -69,7 +69,7 @@ public static OnHeapColumnVector[] allocateColumns(int capacity, StructField[] f private float[] floatData; private double[] doubleData; - // Only set if type is Array. + // Only set if type is Array or Map. private int[] arrayLengths; private int[] arrayOffsets; @@ -503,7 +503,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { // Spilt this function out since it is the slow path. @Override protected void reserveInternal(int newCapacity) { - if (isArray()) { + if (isArray() || type instanceof MapType) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 8ebc1adf59c8b..c2e595455549c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -612,6 +613,13 @@ public final ColumnarArray getArray(int rowId) { return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); } + // `WritableColumnVector` puts the key array in the first child column vector, value array in the + // second child column vector, and puts the offsets and lengths in the current column vector. + @Override + public final ColumnarMap getMap(int rowId) { + return new ColumnarMap(getChild(0), getChild(1), getArrayOffset(rowId), getArrayLength(rowId)); + } + public WritableColumnVector arrayData() { return childColumns[0]; } @@ -705,6 +713,11 @@ protected WritableColumnVector(int capacity, DataType type) { for (int i = 0; i < childColumns.length; ++i) { this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); } + } else if (type instanceof MapType) { + MapType mapType = (MapType) type; + this.childColumns = new WritableColumnVector[2]; + this.childColumns[0] = reserveNewColumn(capacity, mapType.keyType()); + this.childColumns[1] = reserveNewColumn(capacity, mapType.valueType()); } else if (type instanceof CalendarIntervalType) { // Two columns. Months as int. Microseconds as Long. this.childColumns = new WritableColumnVector[2]; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 5ff6474c161f3..f3ece538c3b80 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -119,6 +119,11 @@ public ColumnarArray getArray(int rowId) { return accessor.getArray(rowId); } + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + @Override public ArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index d588956208047..05271ec1f46ab 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -220,10 +220,18 @@ public final ColumnarRow getStruct(int rowId) { /** * Returns the map type value for rowId. + * + * In Spark, map type value is basically a key data array and a value data array. A key from the + * key array with a index and a value from the value array with the same index contribute to + * an entry of this map type value. + * + * To support map type, implementations must construct an {@link ColumnarMap} and return it in + * this method. {@link ColumnarMap} requires a {@link ColumnVector} that stores the data of all + * the keys of all the maps in this vector, and another {@link ColumnVector} that stores the data + * of all the values of all the maps in this vector, and a pair of offset and length which + * specify the range of the key/value array that belongs to the map type value at rowId. */ - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); - } + public abstract ColumnarMap getMap(int ordinal); /** * Returns the decimal type value for rowId. diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 72c07ee7cad3f..7c7a1c806a2b7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -149,8 +149,8 @@ public ColumnarArray getArray(int ordinal) { } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + return data.getMap(offset + ordinal); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java new file mode 100644 index 0000000000000..35648e386c4f1 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.vectorized; + +import org.apache.spark.sql.catalyst.util.MapData; + +/** + * Map abstraction in {@link ColumnVector}. + */ +public final class ColumnarMap extends MapData { + private final ColumnarArray keys; + private final ColumnarArray values; + private final int length; + + public ColumnarMap(ColumnVector keys, ColumnVector values, int offset, int length) { + this.length = length; + this.keys = new ColumnarArray(keys, offset, length); + this.values = new ColumnarArray(values, offset, length); + } + + @Override + public int numElements() { return length; } + + @Override + public ColumnarArray keyArray() { + return keys; + } + + @Override + public ColumnarArray valueArray() { + return values; + } + + @Override + public ColumnarMap copy() { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 6ca749d7c6e85..0c9e92ed11fbd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -155,8 +155,9 @@ public ColumnarArray getArray(int ordinal) { } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getMap(rowId); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 168bc5e3e480b..8fe2985836f2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -673,35 +673,37 @@ class ColumnarBatchSuite extends SparkFunSuite { i += 1 } - // Populate it with arrays [0], [1, 2], [], [3, 4, 5] + // Populate it with arrays [0], [1, 2], null, [], [3, 4, 5] column.putArray(0, 0, 1) column.putArray(1, 1, 2) - column.putArray(2, 2, 0) - column.putArray(3, 3, 3) + column.putNull(2) + column.putArray(3, 3, 0) + column.putArray(4, 3, 3) + + assert(column.getArray(0).numElements == 1) + assert(column.getArray(1).numElements == 2) + assert(column.isNullAt(2)) + assert(column.getArray(3).numElements == 0) + assert(column.getArray(4).numElements == 3) val a1 = ColumnVectorUtils.toJavaIntArray(column.getArray(0)) val a2 = ColumnVectorUtils.toJavaIntArray(column.getArray(1)) - val a3 = ColumnVectorUtils.toJavaIntArray(column.getArray(2)) - val a4 = ColumnVectorUtils.toJavaIntArray(column.getArray(3)) + val a3 = ColumnVectorUtils.toJavaIntArray(column.getArray(3)) + val a4 = ColumnVectorUtils.toJavaIntArray(column.getArray(4)) assert(a1 === Array(0)) assert(a2 === Array(1, 2)) assert(a3 === Array.empty[Int]) assert(a4 === Array(3, 4, 5)) - // Verify the ArrayData APIs - assert(column.getArray(0).numElements() == 1) + // Verify the ArrayData get APIs assert(column.getArray(0).getInt(0) == 0) - assert(column.getArray(1).numElements() == 2) assert(column.getArray(1).getInt(0) == 1) assert(column.getArray(1).getInt(1) == 2) - assert(column.getArray(2).numElements() == 0) - - assert(column.getArray(3).numElements() == 3) - assert(column.getArray(3).getInt(0) == 3) - assert(column.getArray(3).getInt(1) == 4) - assert(column.getArray(3).getInt(2) == 5) + assert(column.getArray(4).getInt(0) == 3) + assert(column.getArray(4).getInt(1) == 4) + assert(column.getArray(4).getInt(2) == 5) // Add a longer array which requires resizing column.reset() @@ -711,8 +713,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(data.capacity == array.length * 2) data.putInts(0, array.length, array, 0) column.putArray(0, 0, array.length) - assert(ColumnVectorUtils.toJavaIntArray(column.getArray(0)) - === array) + assert(ColumnVectorUtils.toJavaIntArray(column.getArray(0)) === array) } test("toArray for primitive types") { @@ -770,6 +771,43 @@ class ColumnarBatchSuite extends SparkFunSuite { } } + test("Int Map") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val column = allocate(10, new MapType(IntegerType, IntegerType, false), memMode) + (0 to 1).foreach { colIndex => + val data = column.getChild(colIndex) + (0 to 5).foreach {i => + data.putInt(i, i * (colIndex + 1)) + } + } + + // Populate it with maps [0->0], [1->2, 2->4], null, [], [3->6, 4->8, 5->10] + column.putArray(0, 0, 1) + column.putArray(1, 1, 2) + column.putNull(2) + column.putArray(3, 3, 0) + column.putArray(4, 3, 3) + + assert(column.getMap(0).numElements == 1) + assert(column.getMap(1).numElements == 2) + assert(column.isNullAt(2)) + assert(column.getMap(3).numElements == 0) + assert(column.getMap(4).numElements == 3) + + val a1 = ColumnVectorUtils.toJavaIntMap(column.getMap(0)) + val a2 = ColumnVectorUtils.toJavaIntMap(column.getMap(1)) + val a4 = ColumnVectorUtils.toJavaIntMap(column.getMap(3)) + val a5 = ColumnVectorUtils.toJavaIntMap(column.getMap(4)) + + assert(a1.asScala == Map(0 -> 0)) + assert(a2.asScala == Map(1 -> 2, 2 -> 4)) + assert(a4.asScala == Map()) + assert(a5.asScala == Map(3 -> 6, 4 -> 8, 5 -> 10)) + + column.close() + } + } + testVector( "Struct Column", 10, From 2ac895be909de7e58e1051dc2a1bba98a25bf4be Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Thu, 1 Feb 2018 12:05:12 +0800 Subject: [PATCH 256/774] [SPARK-23247][SQL] combines Unsafe operations and statistics operations in Scan Data Source ## What changes were proposed in this pull request? Currently, we scan the execution plan of the data source, first the unsafe operation of each row of data, and then re traverse the data for the count of rows. In terms of performance, this is not necessary. this PR combines the two operations and makes statistics on the number of rows while performing the unsafe operation. Before modified, ``` val unsafeRow = rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) proj.initialize(index) iter.map(proj) } val numOutputRows = longMetric("numOutputRows") unsafeRow.map { r => numOutputRows += 1 r } ``` After modified, val numOutputRows = longMetric("numOutputRows") rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) proj.initialize(index) iter.map( r => { numOutputRows += 1 proj(r) }) } ## How was this patch tested? the existed test cases. Author: caoxuewen Closes #20415 from heary-cao/DataSourceScanExec. --- .../sql/execution/DataSourceScanExec.scala | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index f7732e2098c29..ba1157d5b6a49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -90,16 +90,15 @@ case class RowDataSourceScanExec( Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { - val unsafeRow = rdd.mapPartitionsWithIndexInternal { (index, iter) => + val numOutputRows = longMetric("numOutputRows") + + rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) proj.initialize(index) - iter.map(proj) - } - - val numOutputRows = longMetric("numOutputRows") - unsafeRow.map { r => - numOutputRows += 1 - r + iter.map( r => { + numOutputRows += 1 + proj(r) + }) } } @@ -326,22 +325,22 @@ case class FileSourceScanExec( // 2) the number of columns should be smaller than spark.sql.codegen.maxFields WholeStageCodegenExec(this)(codegenStageId = 0).execute() } else { - val unsafeRows = { - val scan = inputRDD - if (needsUnsafeRowConversion) { - scan.mapPartitionsWithIndexInternal { (index, iter) => - val proj = UnsafeProjection.create(schema) - proj.initialize(index) - iter.map(proj) - } - } else { - scan - } - } val numOutputRows = longMetric("numOutputRows") - unsafeRows.map { r => - numOutputRows += 1 - r + + if (needsUnsafeRowConversion) { + inputRDD.mapPartitionsWithIndexInternal { (index, iter) => + val proj = UnsafeProjection.create(schema) + proj.initialize(index) + iter.map( r => { + numOutputRows += 1 + proj(r) + }) + } + } else { + inputRDD.map { r => + numOutputRows += 1 + r + } } } } From 56ae32657e9e5d1e30b62afe77d9e14eb07cf4fb Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Wed, 31 Jan 2018 20:33:51 -0800 Subject: [PATCH 257/774] [SPARK-23268][SQL] Reorganize packages in data source V2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? 1. create a new package for partitioning/distribution related classes. As Spark will add new concrete implementations of `Distribution` in new releases, it is good to have a new package for partitioning/distribution related classes. 2. move streaming related class to package `org.apache.spark.sql.sources.v2.reader/writer.streaming`, instead of `org.apache.spark.sql.sources.v2.streaming.reader/writer`. So that the there won't be package reader/writer inside package streaming, which is quite confusing. Before change: ``` v2 ├── reader ├── streaming │   ├── reader │   └── writer └── writer ``` After change: ``` v2 ├── reader │   └── streaming └── writer └── streaming ``` ## How was this patch tested? Unit test. Author: Wang Gengliang Closes #20435 from gengliangwang/new_pkg. --- .../spark/sql/kafka010/KafkaContinuousReader.scala | 2 +- .../apache/spark/sql/kafka010/KafkaSourceOffset.scala | 2 +- .../spark/sql/kafka010/KafkaSourceProvider.scala | 5 +++-- .../apache/spark/sql/kafka010/KafkaStreamWriter.scala | 2 +- .../{streaming => reader}/ContinuousReadSupport.java | 4 ++-- .../{streaming => reader}/MicroBatchReadSupport.java | 4 ++-- .../sources/v2/reader/SupportsReportPartitioning.java | 1 + .../{ => partitioning}/ClusteredDistribution.java | 3 ++- .../v2/reader/{ => partitioning}/Distribution.java | 3 ++- .../v2/reader/{ => partitioning}/Partitioning.java | 4 +++- .../streaming}/ContinuousDataReader.java | 2 +- .../reader => reader/streaming}/ContinuousReader.java | 2 +- .../reader => reader/streaming}/MicroBatchReader.java | 2 +- .../{streaming/reader => reader/streaming}/Offset.java | 2 +- .../reader => reader/streaming}/PartitionOffset.java | 2 +- .../spark/sql/sources/v2/writer/DataSourceWriter.java | 2 +- .../v2/{streaming => writer}/StreamWriteSupport.java | 5 ++--- .../writer => writer/streaming}/StreamWriter.java | 2 +- .../datasources/v2/DataSourcePartitioning.scala | 2 +- .../datasources/v2/DataSourceV2ScanExec.scala | 2 +- .../execution/datasources/v2/WriteToDataSourceV2.scala | 2 +- .../sql/execution/streaming/MicroBatchExecution.scala | 6 +++--- .../sql/execution/streaming/RateSourceProvider.scala | 5 ++--- .../sql/execution/streaming/RateStreamOffset.scala | 2 +- .../sql/execution/streaming/StreamingRelation.scala | 2 +- .../apache/spark/sql/execution/streaming/console.scala | 4 ++-- .../continuous/ContinuousDataSourceRDDIter.scala | 10 +++------- .../streaming/continuous/ContinuousExecution.scala | 5 +++-- .../continuous/ContinuousRateStreamSource.scala | 2 +- .../streaming/continuous/EpochCoordinator.scala | 4 ++-- .../execution/streaming/sources/ConsoleWriter.scala | 2 +- .../execution/streaming/sources/MicroBatchWriter.scala | 2 +- .../streaming/sources/RateStreamSourceV2.scala | 3 +-- .../sql/execution/streaming/sources/memoryV2.scala | 3 +-- .../apache/spark/sql/streaming/DataStreamReader.scala | 2 +- .../apache/spark/sql/streaming/DataStreamWriter.scala | 2 +- .../spark/sql/streaming/StreamingQueryManager.scala | 2 +- .../sql/sources/v2/JavaPartitionAwareDataSource.java | 3 +++ .../sql/execution/streaming/RateSourceV2Suite.scala | 2 +- .../spark/sql/sources/v2/DataSourceV2Suite.scala | 1 + .../streaming/sources/StreamingDataSourceV2Suite.scala | 8 ++++---- 41 files changed, 64 insertions(+), 61 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming => reader}/ContinuousReadSupport.java (94%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming => reader}/MicroBatchReadSupport.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ => partitioning}/ClusteredDistribution.java (92%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ => partitioning}/Distribution.java (93%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ => partitioning}/Partitioning.java (90%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/ContinuousDataReader.java (96%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/ContinuousReader.java (98%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/MicroBatchReader.java (98%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/Offset.java (97%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/PartitionOffset.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming => writer}/StreamWriteSupport.java (93%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/writer => writer/streaming}/StreamWriter.java (98%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 8c733426b256f..41c443bc12120 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRo import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index c82154cfbad7f..8d41c0da2b133 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} -import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 85e96b6783327..694ca76e24964 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -31,8 +31,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSessio import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index a24efdefa4464..9307bfc001c03 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -22,8 +22,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java index f79424e036a52..0c1d5d1a9577a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; +package org.apache.spark.sql.sources.v2.reader; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader; import org.apache.spark.sql.types.StructType; /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java index 22660e42ad850..5e8f0c0dafdcf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; +package org.apache.spark.sql.sources.v2.reader; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.streaming.reader.MicroBatchReader; +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader; import org.apache.spark.sql.types.StructType; /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index a2383a9d7d680..5405a916951b8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; /** * A mix in interface for {@link DataSourceReader}. Data source readers can implement this diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java similarity index 92% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 27905e325df87..2d0ee50212b56 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReader; /** * A concrete implementation of {@link Distribution}. Represents a distribution where records that diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index b37562167d9ef..f6b111fdf220d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReader; /** * An interface to represent data distribution requirement, which specifies how the records should diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index 5e334d13a1215..309d9e5de0a0f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -15,9 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; /** * An interface to represent the output data partitioning for a data source, which is returned by diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java index 3f13a4dbf5793..47d26440841fd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataReader; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java similarity index 98% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index 6e5177ee83a62..d1d1e7ffd1dd4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java similarity index 98% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java index fcec446d892f5..67ebde30d61a9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java similarity index 97% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index abba3e7188b13..e41c0351edc82 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java index 4688b85f49f5f..383e73db6762b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import java.io.Serializable; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index d89d27d0e5b1b..7096aec0d22c2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -28,7 +28,7 @@ /** * A data source writer that is returned by * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceOptions)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport#createStreamWriter( + * {@link StreamWriteSupport#createStreamWriter( * String, StructType, OutputMode, DataSourceOptions)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java index 7c5f304425093..1c0e2e12f8d51 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java @@ -15,14 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; +package org.apache.spark.sql.sources.v2.writer; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java similarity index 98% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java index 915ee6c4fb390..4913341bd505d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.writer; +package org.apache.spark.sql.sources.v2.writer.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala index 943d0100aca56..017a6737161a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression} import org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.sources.v2.reader.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Partitioning} /** * An adapter from public data source partitioning to catalyst internal `Partitioning`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index ee085820b0775..df469af2c262a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index c544adbf32cdf..6592bd72fa338 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 93572f7a63132..d9aa8573ba930 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -30,9 +30,9 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow +import org.apache.spark.sql.sources.v2.reader.MicroBatchReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.writer.{StreamWriteSupport, SupportsWriteInternalRow} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index 5e3fee633f591..ce5e63f5bde85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -30,11 +30,10 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.execution.streaming.sources.RateStreamMicroBatchReader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala index 261d69bbd9843..02fed50485b94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala @@ -23,7 +23,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.sources.v2 case class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, ValueRunTimeMsPair]) - extends v2.streaming.reader.Offset { + extends v2.reader.streaming.Offset { implicit val defaultFormats: DefaultFormats = DefaultFormats override val json = Serialization.write(partitionToValueAndRunTimeMs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index a0ee683a895d8..845c8d2c14e43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 3f5bb489d6528..db600866067bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -21,8 +21,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index 8a7a38b22caca..cf02c0dda25d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -18,23 +18,19 @@ package org.apache.spark.sql.execution.streaming.continuous import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} -import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, PartitionOffset} -import org.apache.spark.sql.streaming.ProcessingTime -import org.apache.spark.util.{SystemClock, ThreadUtils} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} +import org.apache.spark.util.ThreadUtils class ContinuousDataSourceRDD( sc: SparkContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 9402d7c1dcefd..08c81419a9d34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -32,8 +32,9 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index ff028ebc4236a..0eaaa4889ba9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamO import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 84d262116cb46..cc6808065c0cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -23,9 +23,9 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index c57bdc4a28905..d276403190b3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -22,8 +22,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala index d7ce9a7b84479..56f7ff25cbed0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter /** * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 43949e6180aaa..1315885da8a6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -31,8 +31,7 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeM import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport -import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 58767261dc684..3411edbc53412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -30,9 +30,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f1b3f93c4e1fc..116ac3da07b75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 3b5b30d77945c..9aac360fd4bbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index fdd709cdb1f38..ddb1edc433d5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 99cca0f6dd626..32fad59b97ff6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -27,6 +27,9 @@ import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; +import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; +import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; import org.apache.spark.sql.types.StructType; public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index b060aeeef811d..3158995ec62f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index ee50e8a92270b..2f49b07018aaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 3127d664d32dc..cb873ab688e96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,10 +26,10 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -import org.apache.spark.sql.sources.v2.streaming._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, DataReaderFactory, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils From b2e7677f4d3d8f47f5f148680af39d38f2b558f0 Mon Sep 17 00:00:00 2001 From: Atallah Hezbor Date: Wed, 31 Jan 2018 20:45:55 -0800 Subject: [PATCH 258/774] [SPARK-21396][SQL] Fixes MatchError when UDTs are passed through Hive Thriftserver Signed-off-by: Atallah Hezbor ## What changes were proposed in this pull request? This PR proposes modifying the match statement that gets the columns of a row in HiveThriftServer. There was previously no case for `UserDefinedType`, so querying a table that contained them would throw a match error. The changes catch that case and return the string representation. ## How was this patch tested? While I would have liked to add a unit test, I couldn't easily incorporate UDTs into the ``HiveThriftServer2Suites`` pipeline. With some guidance I would be happy to push a commit with tests. Instead I did a manual test by loading a `DataFrame` with Point UDT in a spark shell with a HiveThriftServer. Then in beeline, connecting to the server and querying that table. Here is the result before the change ``` 0: jdbc:hive2://localhost:10000> select * from chicago; Error: scala.MatchError: org.apache.spark.sql.PointUDT2d980dc3 (of class org.apache.spark.sql.PointUDT) (state=,code=0) ``` And after the change: ``` 0: jdbc:hive2://localhost:10000> select * from chicago; +---------------------------------------+--------------+------------------------+---------------------+--+ | __fid__ | case_number | dtg | geom | +---------------------------------------+--------------+------------------------+---------------------+--+ | 109602f9-54f8-414b-8c6f-42b1a337643e | 2 | 2016-01-01 19:00:00.0 | POINT (-77 38) | | 709602f9-fcff-4429-8027-55649b6fd7ed | 1 | 2015-12-31 19:00:00.0 | POINT (-76.5 38.5) | | 009602f9-fcb5-45b1-a867-eb8ba10cab40 | 3 | 2016-01-02 19:00:00.0 | POINT (-78 39) | +---------------------------------------+--------------+------------------------+---------------------+--+ ``` Author: Atallah Hezbor Closes #20385 from atallahhezbor/udts_over_hive. --- .../thriftserver/SparkExecuteStatementOperation.scala | 2 +- .../main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 1 + .../scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala | 8 +++++++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 664bc20601eaa..3cfc81b8a9579 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -102,7 +102,7 @@ private[hive] class SparkExecuteStatementOperation( to += from.getAs[Timestamp](ordinal) case BinaryType => to += from.getAs[Array[Byte]](ordinal) - case _: ArrayType | _: StructType | _: MapType => + case _: ArrayType | _: StructType | _: MapType | _: UserDefinedType[_] => val hiveString = HiveUtils.toHiveString((from.get(ordinal), dataTypes(ordinal))) to += hiveString } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c7717d70c996f..d9627eb9790eb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -460,6 +460,7 @@ private[spark] object HiveUtils extends Logging { case (decimal: java.math.BigDecimal, DecimalType()) => // Hive strips trailing zeros so use its toString HiveDecimal.create(decimal).toString + case (other, _ : UserDefinedType[_]) => other.toString case (other, tpe) if primitiveTypes contains tpe => other.toString } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index 8697d47e89e89..f2b75e4b23f02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SQLTestUtils} import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader} class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -62,4 +62,10 @@ class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton Thread.currentThread().setContextClassLoader(contextClassLoader) } } + + test("toHiveString correctly handles UDTs") { + val point = new ExamplePoint(50.0, 50.0) + val tpe = new ExamplePointUDT() + assert(HiveUtils.toHiveString((point, tpe)) === "(50.0, 50.0)") + } } From cc41245fa3f954f961541bf4b4275c28473042b8 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 1 Feb 2018 12:56:07 +0800 Subject: [PATCH 259/774] [SPARK-23188][SQL] Make vectorized columar reader batch size configurable ## What changes were proposed in this pull request? This PR include the following changes: - Make the capacity of `VectorizedParquetRecordReader` configurable; - Make the capacity of `OrcColumnarBatchReader` configurable; - Update the error message when required capacity in writable columnar vector cannot be fulfilled. ## How was this patch tested? N/A Author: Xingbo Jiang Closes #20361 from jiangxb1987/vectorCapacity. --- .../apache/spark/sql/internal/SQLConf.scala | 16 ++++++++++++++ .../orc/OrcColumnarBatchReader.java | 22 ++++++++++--------- .../VectorizedParquetRecordReader.java | 20 ++++++++--------- .../vectorized/WritableColumnVector.java | 7 ++++-- .../datasources/orc/OrcFileFormat.scala | 3 ++- .../parquet/ParquetFileFormat.scala | 3 ++- .../parquet/ParquetEncodingSuite.scala | 12 +++++++--- .../datasources/parquet/ParquetIOSuite.scala | 21 +++++++++++++----- .../parquet/ParquetReadBenchmark.scala | 11 +++++++--- 9 files changed, 78 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7394a0d7cf983..90654e67457e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -375,6 +375,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_VECTORIZED_READER_BATCH_SIZE = buildConf("spark.sql.parquet.columnarReaderBatchSize") + .doc("The number of rows to include in a parquet vectorized reader batch. The number should " + + "be carefully chosen to minimize overhead and avoid OOMs in reading data.") + .intConf + .createWithDefault(4096) + val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") .doc("Sets the compression codec used when writing ORC files. If either `compression` or " + "`orc.compress` is specified in the table-specific options/properties, the precedence " + @@ -398,6 +404,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ORC_VECTORIZED_READER_BATCH_SIZE = buildConf("spark.sql.orc.columnarReaderBatchSize") + .doc("The number of rows to include in a orc vectorized reader batch. The number should " + + "be carefully chosen to minimize overhead and avoid OOMs in reading data.") + .intConf + .createWithDefault(4096) + val ORC_COPY_BATCH_TO_SPARK = buildConf("spark.sql.orc.copyBatchToSpark") .doc("Whether or not to copy the ORC columnar batch to Spark columnar batch in the " + "vectorized ORC reader.") @@ -1250,10 +1262,14 @@ class SQLConf extends Serializable with Logging { def orcVectorizedReaderEnabled: Boolean = getConf(ORC_VECTORIZED_READER_ENABLED) + def orcVectorizedReaderBatchSize: Int = getConf(ORC_VECTORIZED_READER_BATCH_SIZE) + def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) + def parquetVectorizedReaderBatchSize: Int = getConf(PARQUET_VECTORIZED_READER_BATCH_SIZE) + def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 5e7cad470e1d1..dcebdc39f0aa2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -49,8 +49,9 @@ * After creating, `initialize` and `initBatch` should be called sequentially. */ public class OrcColumnarBatchReader extends RecordReader { - // TODO: make this configurable. - private static final int CAPACITY = 4 * 1024; + + // The capacity of vectorized batch. + private int capacity; // Vectorized ORC Row Batch private VectorizedRowBatch batch; @@ -81,9 +82,10 @@ public class OrcColumnarBatchReader extends RecordReader { // Whether or not to copy the ORC columnar batch to Spark columnar batch. private final boolean copyToSpark; - public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark) { + public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark, int capacity) { MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; this.copyToSpark = copyToSpark; + this.capacity = capacity; } @@ -148,7 +150,7 @@ public void initBatch( StructField[] requiredFields, StructType partitionSchema, InternalRow partitionValues) { - batch = orcSchema.createRowBatch(CAPACITY); + batch = orcSchema.createRowBatch(capacity); assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. this.requiredFields = requiredFields; @@ -162,15 +164,15 @@ public void initBatch( if (copyToSpark) { if (MEMORY_MODE == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, resultSchema); + columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, resultSchema); + columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); } // Initialize the missing columns once. for (int i = 0; i < requiredFields.length; i++) { if (requestedColIds[i] == -1) { - columnVectors[i].putNulls(0, CAPACITY); + columnVectors[i].putNulls(0, capacity); columnVectors[i].setIsConstant(); } } @@ -193,8 +195,8 @@ public void initBatch( int colId = requestedColIds[i]; // Initialize the missing columns once. if (colId == -1) { - OnHeapColumnVector missingCol = new OnHeapColumnVector(CAPACITY, dt); - missingCol.putNulls(0, CAPACITY); + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + missingCol.putNulls(0, capacity); missingCol.setIsConstant(); orcVectorWrappers[i] = missingCol; } else { @@ -206,7 +208,7 @@ public void initBatch( int partitionIdx = requiredFields.length; for (int i = 0; i < partitionValues.numFields(); i++) { DataType dt = partitionSchema.fields()[i].dataType(); - OnHeapColumnVector partitionCol = new OnHeapColumnVector(CAPACITY, dt); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); ColumnVectorUtils.populate(partitionCol, partitionValues, i); partitionCol.setIsConstant(); orcVectorWrappers[partitionIdx + i] = partitionCol; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index bb1b23611a7d7..5934a23db8af1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -50,8 +50,9 @@ * TODO: make this always return ColumnarBatches. */ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBase { - // TODO: make this configurable. - private static final int CAPACITY = 4 * 1024; + + // The capacity of vectorized batch. + private int capacity; /** * Batch of rows that we assemble and the current index we've returned. Every time this @@ -115,13 +116,10 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private final MemoryMode MEMORY_MODE; - public VectorizedParquetRecordReader(TimeZone convertTz, boolean useOffHeap) { + public VectorizedParquetRecordReader(TimeZone convertTz, boolean useOffHeap, int capacity) { this.convertTz = convertTz; MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; - } - - public VectorizedParquetRecordReader(boolean useOffHeap) { - this(null, useOffHeap); + this.capacity = capacity; } /** @@ -199,9 +197,9 @@ private void initBatch( } if (memMode == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, batchSchema); + columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, batchSchema); + columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema); } columnarBatch = new ColumnarBatch(columnVectors); if (partitionColumns != null) { @@ -215,7 +213,7 @@ private void initBatch( // Initialize missing columns with nulls. for (int i = 0; i < missingColumns.length; i++) { if (missingColumns[i]) { - columnVectors[i].putNulls(0, CAPACITY); + columnVectors[i].putNulls(0, capacity); columnVectors[i].setIsConstant(); } } @@ -257,7 +255,7 @@ public boolean nextBatch() throws IOException { if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); - int num = (int) Math.min((long) CAPACITY, totalCountLoadedSoFar - rowsReturned); + int num = (int) Math.min((long) capacity, totalCountLoadedSoFar - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { if (columnReaders[i] == null) continue; columnReaders[i].readBatch(num, columnVectors[i]); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index c2e595455549c..9d447cdc79063 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -98,8 +98,11 @@ public void reserve(int requiredCapacity) { private void throwUnsupportedException(int requiredCapacity, Throwable cause) { String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + - "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + - " to false."; + "vectorized reader, or increase the vectorized reader batch size. For parquet file " + + "format, refer to " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + " and " + + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + "; for orc file format, refer to " + + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + " and " + + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + "."; throw new RuntimeException(message, cause); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 2dd314d165348..dbf3bc6f0ee6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -151,6 +151,7 @@ class OrcFileFormat val sqlConf = sparkSession.sessionState.conf val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled val enableVectorizedReader = supportBatch(sparkSession, resultSchema) + val capacity = sqlConf.orcVectorizedReaderBatchSize val copyToSpark = sparkSession.sessionState.conf.getConf(SQLConf.ORC_COPY_BATCH_TO_SPARK) val broadcastedConf = @@ -186,7 +187,7 @@ class OrcFileFormat val taskContext = Option(TaskContext.get()) if (enableVectorizedReader) { val batchReader = new OrcColumnarBatchReader( - enableOffHeapColumnVector && taskContext.isDefined, copyToSpark) + enableOffHeapColumnVector && taskContext.isDefined, copyToSpark, capacity) batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( reader.getSchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index f53a97ba45a26..ba69f9a26c968 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -350,6 +350,7 @@ class ParquetFileFormat sparkSession.sessionState.conf.parquetRecordFilterEnabled val timestampConversion: Boolean = sparkSession.sessionState.conf.isParquetINT96TimestampConversion + val capacity = sqlConf.parquetVectorizedReaderBatchSize // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) @@ -396,7 +397,7 @@ class ParquetFileFormat val taskContext = Option(TaskContext.get()) val parquetReader = if (enableVectorizedReader) { val vectorizedReader = new VectorizedParquetRecordReader( - convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined) + convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index edb1290ee2eb0..db73bfa149aa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -40,7 +40,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -65,7 +67,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -94,7 +98,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.toDF("f").coalesce(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).asScala.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file, null /* set columns to null to project all columns */) val column = reader.resultBatch().column(0) assert(reader.nextBatch()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index f3ece5b15e26a..3af80930ec807 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -653,7 +653,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { spark.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0); { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, null) val result = mutable.ArrayBuffer.empty[(Int, String)] @@ -670,7 +672,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project just one column { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, ("_2" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String)] @@ -686,7 +690,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project columns in opposite order { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, ("_2" :: "_1" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String, Int)] @@ -703,7 +709,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Empty projection { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, List[String]().asJava) var result = 0 @@ -742,8 +750,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { dataTypes.zip(constantValues).foreach { case (dt, v) => val schema = StructType(StructField("pcol", dt) :: Nil) - val vectorizedReader = - new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val vectorizedReader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) val partitionValues = new GenericInternalRow(Array(v)) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index 86a3c71a3c4f6..e43336d947364 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -76,6 +76,7 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize spark.range(values).createOrReplaceTempView("t1") spark.sql("select cast(id as INT) as id from t1") .write.parquet(dir.getCanonicalPath) @@ -96,7 +97,8 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -119,7 +121,8 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -262,6 +265,7 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize spark.range(values).createOrReplaceTempView("t1") spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") @@ -279,7 +283,8 @@ object ParquetReadBenchmark { benchmark.addCase("PR Vectorized") { num => var sum = 0 files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) val batch = reader.resultBatch() From b6b50efc854f298d5b3e11c05dca995a85bec962 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 31 Jan 2018 20:59:19 -0800 Subject: [PATCH 260/774] [SQL][MINOR] Inline SpecifiedWindowFrame.defaultWindowFrame(). ## What changes were proposed in this pull request? SpecifiedWindowFrame.defaultWindowFrame(hasOrderSpecification, acceptWindowFrame) was designed to handle the cases when some Window functions don't support setting a window frame (e.g. rank). However this param is never used. We may inline the whole of this function to simplify the code. ## How was this patch tested? Existing tests. Author: Xingbo Jiang Closes #20463 from jiangxb1987/defaultWindowFrame. --- .../sql/catalyst/analysis/Analyzer.scala | 6 +++++- .../expressions/windowExpressions.scala | 21 ------------------- .../catalyst/ExpressionSQLBuilderSuite.scala | 5 +---- 3 files changed, 6 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 251099f750cf6..7848f88bda1c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2038,7 +2038,11 @@ class Analyzer( WindowExpression(wf, s.copy(frameSpecification = wf.frame)) case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) if e.resolved => - val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true) + val frame = if (o.nonEmpty) { + SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) + } else { + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) + } we.copy(windowSpec = s.copy(frameSpecification = frame)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index dd13d9a3bba51..78895f1c2f6f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -265,27 +265,6 @@ case class SpecifiedWindowFrame( } } -object SpecifiedWindowFrame { - /** - * @param hasOrderSpecification If the window spec has order by expressions. - * @param acceptWindowFrame If the window function accepts user-specified frame. - * @return the default window frame. - */ - def defaultWindowFrame( - hasOrderSpecification: Boolean, - acceptWindowFrame: Boolean): SpecifiedWindowFrame = { - if (hasOrderSpecification && acceptWindowFrame) { - // If order spec is defined and the window function supports user specified window frames, - // the default frame is RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW. - SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) - } else { - // Otherwise, the default frame is - // ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING. - SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) - } - } -} - case class UnresolvedWindowExpression( child: Expression, windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index d9cf1f361c1d6..61f9179042fe4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -108,10 +108,7 @@ class ExpressionSQLBuilderSuite extends QueryTest with TestHiveSingleton { } test("window specification") { - val frame = SpecifiedWindowFrame.defaultWindowFrame( - hasOrderSpecification = true, - acceptWindowFrame = true - ) + val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) checkSQL( WindowSpecDefinition('a.int :: Nil, Nil, frame), From 4b7cd479a28b274f5a0802c9b017b3eb15002c21 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 1 Feb 2018 13:58:13 +0800 Subject: [PATCH 261/774] Revert "[SPARK-23200] Reset Kubernetes-specific config on Checkpoint restore" This reverts commit d1721816d26bedee3c72eeb75db49da500568376. The patch is not fully tested and out-of-date. So revert it. --- .../org/apache/spark/streaming/Checkpoint.scala | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index ed2a896033749..aed67a5027433 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -53,21 +53,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.driver.host", "spark.driver.bindAddress", "spark.driver.port", - "spark.kubernetes.driver.pod.name", - "spark.kubernetes.executor.podNamePrefix", - "spark.kubernetes.initcontainer.executor.configmapname", - "spark.kubernetes.initcontainer.executor.configmapkey", - "spark.kubernetes.initcontainer.downloadJarsResourceIdentifier", - "spark.kubernetes.initcontainer.downloadJarsSecretLocation", - "spark.kubernetes.initcontainer.downloadFilesResourceIdentifier", - "spark.kubernetes.initcontainer.downloadFilesSecretLocation", - "spark.kubernetes.initcontainer.remoteJars", - "spark.kubernetes.initcontainer.remoteFiles", - "spark.kubernetes.mountdependencies.jarsDownloadDir", - "spark.kubernetes.mountdependencies.filesDownloadDir", - "spark.kubernetes.initcontainer.executor.stagingServerSecret.name", - "spark.kubernetes.initcontainer.executor.stagingServerSecret.mountDir", - "spark.kubernetes.executor.limit.cores", "spark.master", "spark.yarn.jars", "spark.yarn.keytab", @@ -81,7 +66,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") .remove("spark.driver.bindAddress") - .remove("spark.kubernetes.driver.pod.name") .remove("spark.driver.port") val newReloadConf = new SparkConf(loadDefaults = true) propertiesToReload.foreach { prop => From 07cee33736aabf9e9a4a89344eda2b8ea29b27ea Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 31 Jan 2018 22:26:27 -0800 Subject: [PATCH 262/774] [SPARK-22274][PYTHON][SQL][FOLLOWUP] Use `assertRaisesRegexp` instead of `assertRaisesRegex`. ## What changes were proposed in this pull request? This is a follow-up pr of #19872 which uses `assertRaisesRegex` but it doesn't exist in Python 2, so some tests fail when running tests in Python 2 environment. Unfortunately, we missed it because currently Python 2 environment of the pr builder doesn't have proper versions of pandas or pyarrow, so the tests were skipped. This pr modifies to use `assertRaisesRegexp` instead of `assertRaisesRegex`. ## How was this patch tested? Tested manually in my local environment. Author: Takuya UESHIN Closes #20467 from ueshin/issues/SPARK-22274/fup1. --- python/pyspark/sql/tests.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index dc26b96334c7a..b27363023ae77 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4530,19 +4530,19 @@ def test_unsupported_types(self): from pyspark.sql.functions import pandas_udf, PandasUDFType with QuietTest(self.sc): - with self.assertRaisesRegex(NotImplementedError, 'not supported'): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return [v.mean(), v.std()] with QuietTest(self.sc): - with self.assertRaisesRegex(NotImplementedError, 'not supported'): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): @pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return v.mean(), v.std() with QuietTest(self.sc): - with self.assertRaisesRegex(NotImplementedError, 'not supported'): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return {v.mean(): v.std()} From e15da5b14c8d845028365a609c0c66731d024ee7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 1 Feb 2018 11:25:01 +0200 Subject: [PATCH 263/774] [SPARK-23107][ML] ML 2.3 QA: New Scala APIs, docs. ## What changes were proposed in this pull request? Audit new APIs and docs in 2.3.0. ## How was this patch tested? No test. Author: Yanbo Liang Closes #20459 from yanboliang/SPARK-23107. --- mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 2 +- .../scala/org/apache/spark/ml/regression/LinearRegression.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 1155ea5fdd85b..22e7b8bbf1ff5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -74,7 +74,7 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with * @group param */ @Since("2.3.0") - final override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle invalid data (unseen or NULL values) in features and label column of string " + "type. Options are 'skip' (filter out rows with invalid data), error (throw an error), " + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index a5873d03b4161..6d3fe7a6c748c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -645,7 +645,7 @@ class LinearRegressionModel private[ml] ( extends RegressionModel[Vector, LinearRegressionModel] with LinearRegressionParams with MLWritable { - def this(uid: String, coefficients: Vector, intercept: Double) = + private[ml] def this(uid: String, coefficients: Vector, intercept: Double) = this(uid, coefficients, intercept, 1.0) private var trainingSummary: Option[LinearRegressionTrainingSummary] = None From 8bb70b068ea782e799e45238fcb093a6acb0fc9f Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 1 Feb 2018 21:25:02 +0900 Subject: [PATCH 264/774] [SPARK-23280][SQL][FOLLOWUP] Fix Java style check issues. ## What changes were proposed in this pull request? This is a follow-up of #20450 which broke lint-java checks. This pr fixes the lint-java issues. ``` [ERROR] src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java:[20,8] (imports) UnusedImports: Unused import - org.apache.spark.sql.catalyst.util.MapData. [ERROR] src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java:[21,8] (imports) UnusedImports: Unused import - org.apache.spark.sql.catalyst.util.MapData. [ERROR] src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java:[22,8] (imports) UnusedImports: Unused import - org.apache.spark.sql.catalyst.util.MapData. ``` ## How was this patch tested? Checked manually in my local environment. Author: Takuya UESHIN Closes #20468 from ueshin/issues/SPARK-23280/fup1. --- .../main/java/org/apache/spark/sql/vectorized/ColumnVector.java | 1 - .../main/java/org/apache/spark/sql/vectorized/ColumnarArray.java | 1 - .../main/java/org/apache/spark/sql/vectorized/ColumnarRow.java | 1 - 3 files changed, 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 05271ec1f46ab..530d4d23d4eaf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.vectorized; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.types.CalendarInterval; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 7c7a1c806a2b7..72a192d089b9f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -18,7 +18,6 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.util.ArrayData; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 0c9e92ed11fbd..b400f7f93c1fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -19,7 +19,6 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; From 89e8d556b93d1bf1b28fe153fd284f154045b0ee Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 1 Feb 2018 21:28:53 +0900 Subject: [PATCH 265/774] [SPARK-23280][SQL][FOLLOWUP] Enable `MutableColumnarRow.getMap()`. ## What changes were proposed in this pull request? This is a followup pr of #20450. We should've enabled `MutableColumnarRow.getMap()` as well. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #20471 from ueshin/issues/SPARK-23280/fup2. --- .../spark/sql/execution/vectorized/MutableColumnarRow.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 66668f3753604..307c19032dee5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -21,10 +21,10 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.sql.vectorized.ColumnarRow; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; @@ -162,8 +162,9 @@ public ColumnarArray getArray(int ordinal) { } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getMap(rowId); } @Override From ffbca84519011a747e0552632e88f5e4956e493d Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 1 Feb 2018 20:39:15 +0800 Subject: [PATCH 266/774] [SPARK-23202][SQL] Add new API in DataSourceWriter: onDataWriterCommit ## What changes were proposed in this pull request? The current DataSourceWriter API makes it hard to implement `onTaskCommit(taskCommit: TaskCommitMessage)` in `FileCommitProtocol`. In general, on receiving commit message, driver can start processing messages(e.g. persist messages into files) before all the messages are collected. The proposal to add a new API: `add(WriterCommitMessage message)`: Handles a commit message on receiving from a successful data writer. This should make the whole API of DataSourceWriter compatible with `FileCommitProtocol`, and more flexible. There was another radical attempt in #20386. This one should be more reasonable. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20454 from gengliangwang/write_api. --- .../sources/v2/writer/DataSourceWriter.java | 14 +++++++++++-- .../datasources/v2/WriteToDataSourceV2.scala | 5 ++++- .../sql/sources/v2/DataSourceV2Suite.scala | 21 ++++++++++++++++++- .../sources/v2/SimpleWritableDataSource.scala | 21 +++++++++++++++++++ 4 files changed, 57 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 7096aec0d22c2..52324b3792b8a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -62,6 +62,14 @@ public interface DataSourceWriter { */ DataWriterFactory createWriterFactory(); + /** + * Handles a commit message on receiving from a successful data writer. + * + * If this method fails (by throwing an exception), this writing job is considered to to have been + * failed, and {@link #abort(WriterCommitMessage[])} would be called. + */ + default void onDataWriterCommit(WriterCommitMessage message) {} + /** * Commits this writing job with a list of commit messages. The commit messages are collected from * successful data writers and are produced by {@link DataWriter#commit()}. @@ -78,8 +86,10 @@ public interface DataSourceWriter { void commit(WriterCommitMessage[] messages); /** - * Aborts this writing job because some data writers are failed and keep failing when retry, or - * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * Aborts this writing job because some data writers are failed and keep failing when retry, + * or the Spark job fails with some unknown reasons, + * or {@link #onDataWriterCommit(WriterCommitMessage)} fails, + * or {@link #commit(WriterCommitMessage[])} fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 6592bd72fa338..eefbcf4c0e087 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -80,7 +80,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e rdd, runTask, rdd.partitions.indices, - (index, message: WriterCommitMessage) => messages(index) = message + (index, message: WriterCommitMessage) => { + messages(index) = message + writer.onDataWriterCommit(message) + } ) if (!writer.isInstanceOf[StreamWriter]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 2f49b07018aaf..1c3ba7826f7de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -21,7 +21,7 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ -import org.apache.spark.SparkException +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -198,6 +198,25 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("simple counter in writer with onDataWriterCommit") { + Seq(classOf[SimpleWritableDataSource]).foreach { cls => + withTempPath { file => + val path = file.getCanonicalPath + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + val numPartition = 6 + spark.range(0, 10, 1, numPartition).select('id, -'id).write.format(cls.getName) + .option("path", path).save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).select('id, -'id)) + + assert(SimpleCounter.getCounter == numPartition, + "method onDataWriterCommit should be called as many as the number of partitions") + } + } + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index a131b16953e3b..36dd2a350a055 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -66,9 +66,14 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { override def createWriterFactory(): DataWriterFactory[Row] = { + SimpleCounter.resetCounter new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } + override def onDataWriterCommit(message: WriterCommitMessage): Unit = { + SimpleCounter.increaseCounter + } + override def commit(messages: Array[WriterCommitMessage]): Unit = { val finalPath = new Path(path) val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) @@ -183,6 +188,22 @@ class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) } } +private[v2] object SimpleCounter { + private var count: Int = 0 + + def increaseCounter: Unit = { + count += 1 + } + + def getCounter: Int = { + count + } + + def resetCounter: Unit = { + count = 0 + } +} + class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[Row] { From ec63e2d0743a4f75e1cce21d0fe2b54407a86a4a Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 1 Feb 2018 21:00:47 +0800 Subject: [PATCH 267/774] [SPARK-23289][CORE] OneForOneBlockFetcher.DownloadCallback.onData should write the buffer fully ## What changes were proposed in this pull request? `channel.write(buf)` may not write the whole buffer since the underlying channel is a FileChannel, we should retry until the whole buffer is written. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #20461 from zsxwing/SPARK-23289. --- .../apache/spark/network/shuffle/OneForOneBlockFetcher.java | 4 +++- core/src/test/scala/org/apache/spark/FileSuite.scala | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 9cac7d00cc6b6..0bc571874f07c 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -171,7 +171,9 @@ private class DownloadCallback implements StreamCallback { @Override public void onData(String streamId, ByteBuffer buf) throws IOException { - channel.write(buf); + while (buf.hasRemaining()) { + channel.write(buf); + } } @Override diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index e9539dc73f6fa..55a9122cf9026 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -244,7 +244,10 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { for (i <- 0 until testOutputCopies) { // Shift values by i so that they're different in the output val alteredOutput = testOutput.map(b => (b + i).toByte) - channel.write(ByteBuffer.wrap(alteredOutput)) + val buffer = ByteBuffer.wrap(alteredOutput) + while (buffer.hasRemaining) { + channel.write(buffer) + } } channel.close() file.close() From f051f834036e63d5e480d86440ce39924f979e82 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 1 Feb 2018 10:36:31 -0800 Subject: [PATCH 268/774] [SPARK-13983][SQL] Fix HiveThriftServer2 can not get "--hiveconf" and ''--hivevar" variables since 2.0 ## What changes were proposed in this pull request? `--hiveconf` and `--hivevar` variables no longer work since Spark 2.0. The `spark-sql` client has fixed by [SPARK-15730](https://issues.apache.org/jira/browse/SPARK-15730) and [SPARK-18086](https://issues.apache.org/jira/browse/SPARK-18086). but `beeline`/[`Spark SQL HiveThriftServer2`](https://github.com/apache/spark/blob/v2.1.1/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala) is still broken. This pull request fix it. This pull request works for both `JDBC client` and `beeline`. ## How was this patch tested? unit tests for `JDBC client` manual tests for `beeline`: ``` git checkout origin/pr/17886 dev/make-distribution.sh --mvn mvn --tgz -Phive -Phive-thriftserver -Phadoop-2.6 -DskipTests tar -zxf spark-2.3.0-SNAPSHOT-bin-2.6.5.tgz && cd spark-2.3.0-SNAPSHOT-bin-2.6.5 sbin/start-thriftserver.sh ``` ``` cat < test.sql select '\${a}', '\${b}'; EOF beeline -u jdbc:hive2://localhost:10000 --hiveconf a=avalue --hivevar b=bvalue -f test.sql ``` Author: Yuming Wang Closes #17886 from wangyum/SPARK-13983-dev. --- .../service/cli/session/HiveSessionImpl.java | 74 ++++++++++++++++++- .../server/SparkSQLOperationManager.scala | 12 +++ .../HiveThriftServer2Suites.scala | 23 +++++- 3 files changed, 105 insertions(+), 4 deletions(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index 108074cce3d6d..fc818bc69c761 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -44,7 +44,7 @@ import org.apache.hadoop.hive.ql.history.HiveHistory; import org.apache.hadoop.hive.ql.metadata.Hive; import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.ql.processors.SetProcessor; +import org.apache.hadoop.hive.ql.parse.VariableSubstitution; import org.apache.hadoop.hive.ql.session.SessionState; import org.apache.hadoop.hive.shims.ShimLoader; import org.apache.hive.common.util.HiveVersionInfo; @@ -71,6 +71,12 @@ import org.apache.hive.service.cli.thrift.TProtocolVersion; import org.apache.hive.service.server.ThreadWithGarbageCleanup; +import static org.apache.hadoop.hive.conf.SystemVariables.ENV_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.HIVECONF_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.HIVEVAR_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.METACONF_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.SYSTEM_PREFIX; + /** * HiveSession * @@ -209,7 +215,7 @@ private void configureSession(Map sessionConfMap) throws HiveSQL String key = entry.getKey(); if (key.startsWith("set:")) { try { - SetProcessor.setVariable(key.substring(4), entry.getValue()); + setVariable(key.substring(4), entry.getValue()); } catch (Exception e) { throw new HiveSQLException(e); } @@ -221,6 +227,70 @@ private void configureSession(Map sessionConfMap) throws HiveSQL } } + // Copy from org.apache.hadoop.hive.ql.processors.SetProcessor, only change: + // setConf(varname, propName, varvalue, true) when varname.startsWith(HIVECONF_PREFIX) + public static int setVariable(String varname, String varvalue) throws Exception { + SessionState ss = SessionState.get(); + if (varvalue.contains("\n")){ + ss.err.println("Warning: Value had a \\n character in it."); + } + varname = varname.trim(); + if (varname.startsWith(ENV_PREFIX)){ + ss.err.println("env:* variables can not be set."); + return 1; + } else if (varname.startsWith(SYSTEM_PREFIX)){ + String propName = varname.substring(SYSTEM_PREFIX.length()); + System.getProperties().setProperty(propName, + new VariableSubstitution().substitute(ss.getConf(),varvalue)); + } else if (varname.startsWith(HIVECONF_PREFIX)){ + String propName = varname.substring(HIVECONF_PREFIX.length()); + setConf(varname, propName, varvalue, true); + } else if (varname.startsWith(HIVEVAR_PREFIX)) { + String propName = varname.substring(HIVEVAR_PREFIX.length()); + ss.getHiveVariables().put(propName, + new VariableSubstitution().substitute(ss.getConf(),varvalue)); + } else if (varname.startsWith(METACONF_PREFIX)) { + String propName = varname.substring(METACONF_PREFIX.length()); + Hive hive = Hive.get(ss.getConf()); + hive.setMetaConf(propName, new VariableSubstitution().substitute(ss.getConf(), varvalue)); + } else { + setConf(varname, varname, varvalue, true); + } + return 0; + } + + // returns non-null string for validation fail + private static void setConf(String varname, String key, String varvalue, boolean register) + throws IllegalArgumentException { + HiveConf conf = SessionState.get().getConf(); + String value = new VariableSubstitution().substitute(conf, varvalue); + if (conf.getBoolVar(HiveConf.ConfVars.HIVECONFVALIDATION)) { + HiveConf.ConfVars confVars = HiveConf.getConfVars(key); + if (confVars != null) { + if (!confVars.isType(value)) { + StringBuilder message = new StringBuilder(); + message.append("'SET ").append(varname).append('=').append(varvalue); + message.append("' FAILED because ").append(key).append(" expects "); + message.append(confVars.typeString()).append(" type value."); + throw new IllegalArgumentException(message.toString()); + } + String fail = confVars.validate(value); + if (fail != null) { + StringBuilder message = new StringBuilder(); + message.append("'SET ").append(varname).append('=').append(varvalue); + message.append("' FAILED in validation : ").append(fail).append('.'); + throw new IllegalArgumentException(message.toString()); + } + } else if (key.startsWith("hive.")) { + throw new IllegalArgumentException("hive configuration " + key + " does not exists."); + } + } + conf.verifyAndSet(key, value); + if (register) { + SessionState.get().getOverriddenConfigurations().put(key, value); + } + } + @Override public void setOperationLogSessionDir(File operationLogRootDir) { if (!operationLogRootDir.exists()) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index a0e5012633f5e..bf7c01f60fb5c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation} +import org.apache.spark.sql.internal.SQLConf /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -50,6 +51,9 @@ private[thriftserver] class SparkSQLOperationManager() require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" + s" initialized or had already closed.") val conf = sqlContext.sessionState.conf + val hiveSessionState = parentSession.getSessionState + setConfMap(conf, hiveSessionState.getOverriddenConfigurations) + setConfMap(conf, hiveSessionState.getHiveVariables) val runInBackground = async && conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)(sqlContext, sessionToActivePool) @@ -58,4 +62,12 @@ private[thriftserver] class SparkSQLOperationManager() s"runInBackground=$runInBackground") operation } + + def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = { + val iterator = confMap.entrySet().iterator() + while (iterator.hasNext) { + val kv = iterator.next() + conf.setConfString(kv.getKey, kv.getValue) + } + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 7289da71a3365..496f8c82a6c61 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -135,6 +135,22 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } + test("Support beeline --hiveconf and --hivevar") { + withJdbcStatement() { statement => + executeTest(hiveConfList) + executeTest(hiveVarList) + def executeTest(hiveList: String): Unit = { + hiveList.split(";").foreach{ m => + val kv = m.split("=") + // select "${a}"; ---> avalue + val resultSet = statement.executeQuery("select \"${" + kv(0) + "}\"") + resultSet.next() + assert(resultSet.getString(1) === kv(1)) + } + } + } + } + test("JDBC query execution") { withJdbcStatement("test") { statement => val queries = Seq( @@ -740,10 +756,11 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { s"""jdbc:hive2://localhost:$serverPort/ |default? |hive.server2.transport.mode=http; - |hive.server2.thrift.http.path=cliservice + |hive.server2.thrift.http.path=cliservice; + |${hiveConfList}#${hiveVarList} """.stripMargin.split("\n").mkString.trim } else { - s"jdbc:hive2://localhost:$serverPort/" + s"jdbc:hive2://localhost:$serverPort/?${hiveConfList}#${hiveVarList}" } def withMultipleConnectionJdbcStatement(tableNames: String*)(fs: (Statement => Unit)*) { @@ -779,6 +796,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl private var listeningPort: Int = _ protected def serverPort: Int = listeningPort + protected val hiveConfList = "a=avalue;b=bvalue" + protected val hiveVarList = "c=cvalue;d=dvalue" protected def user = System.getProperty("user.name") protected var warehousePath: File = _ From 73da3b6968630d9e2cafc742ccb6d4eb54957df4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 1 Feb 2018 10:48:34 -0800 Subject: [PATCH 269/774] [SPARK-23293][SQL] fix data source v2 self join ## What changes were proposed in this pull request? `DataSourceV2Relation` should extend `MultiInstanceRelation`, to take care of self-join. ## How was this patch tested? a new test Author: Wenchen Fan Closes #20466 from cloud-fan/dsv2-selfjoin. --- .../execution/datasources/v2/DataSourceV2Relation.scala | 8 +++++++- .../apache/spark/sql/sources/v2/DataSourceV2Suite.scala | 6 ++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 3d4c64981373d..eebfa29f91b99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution.datasources.v2 +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( fullOutput: Seq[AttributeReference], - reader: DataSourceReader) extends LeafNode with DataSourceReaderHolder { + reader: DataSourceReader) + extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] @@ -33,6 +35,10 @@ case class DataSourceV2Relation( case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } + + override def newInstance(): DataSourceV2Relation = { + copy(fullOutput = fullOutput.map(_.newInstance())) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 1c3ba7826f7de..23147fffe8a08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -217,6 +217,12 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-23293: data source v2 self join") { + val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load() + val df2 = df.select(($"i" + 1).as("k"), $"j") + checkAnswer(df.join(df2, "j"), (0 until 10).map(i => Row(-i, i, i + 1))) + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { From 4bcfdefb9f6d5ba88335953683a1dabbee83e9ea Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 1 Feb 2018 14:56:40 -0800 Subject: [PATCH 270/774] [INFRA] Close stale PRs. Closes #20334 Closes #20262 From 032c11b83f0d276bf8085992229b8c598f02798a Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Thu, 1 Feb 2018 15:26:59 -0800 Subject: [PATCH 271/774] [SPARK-23296][YARN] Include stacktrace in YARN-app diagnostic ## What changes were proposed in this pull request? Include stacktrace in the diagnostics message upon abnormal unregister from RM ## How was this patch tested? Tested with a failing job, and confirmed a stacktrace in the client output and YARN webUI. Author: Gera Shegalov Closes #20470 from gerashegalov/gera/stacktrace-diagnostics. --- .../scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 4d5e3bb043671..2f88feb0f1fdf 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -30,6 +30,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -718,7 +719,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logError("User class threw exception: " + cause, cause) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS, - "User class threw exception: " + cause) + "User class threw exception: " + StringUtils.stringifyException(cause)) } sparkContextPromise.tryFailure(e.getCause()) } finally { From 90848d507457d30abb36e3ba07618dfc87c34cd6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 2 Feb 2018 10:18:32 +0800 Subject: [PATCH 272/774] [SPARK-23284][SQL] Document the behavior of several ColumnVector's get APIs when accessing null slot ## What changes were proposed in this pull request? For some ColumnVector get APIs such as getDecimal, getBinary, getStruct, getArray, getInterval, getUTF8String, we should clearly document their behaviors when accessing null slot. They should return null in this case. Then we can remove null checks from the places using above APIs. For the APIs of primitive values like getInt, getInts, etc., this also documents their behaviors when accessing null slots. Their returning values are undefined and can be anything. ## How was this patch tested? Added tests into `ColumnarBatchSuite`. Author: Liang-Chi Hsieh Closes #20455 from viirya/SPARK-23272-followup. --- .../datasources/orc/OrcColumnVector.java | 3 + .../vectorized/MutableColumnarRow.java | 7 -- .../vectorized/WritableColumnVector.java | 5 ++ .../sql/vectorized/ArrowColumnVector.java | 4 + .../spark/sql/vectorized/ColumnVector.java | 63 ++++++++++------ .../spark/sql/vectorized/ColumnarRow.java | 7 -- .../vectorized/ColumnarBatchSuite.scala | 74 ++++++++++++++++++- 7 files changed, 124 insertions(+), 39 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index c8add4c9f486c..12f4d658b1868 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -154,12 +154,14 @@ public double getDouble(int rowId) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); return Decimal.apply(data, precision, scale); } @Override public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; int index = getRowIndex(rowId); BytesColumnVector col = bytesData; return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); @@ -167,6 +169,7 @@ public UTF8String getUTF8String(int rowId) { @Override public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; int index = getRowIndex(rowId); byte[] binary = new byte[bytesData.length[index]]; System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 307c19032dee5..4e4242fe8d9b9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -127,43 +127,36 @@ public boolean anyNull() { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getInterval(rowId); } @Override public ColumnarRow getStruct(int ordinal, int numFields) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getStruct(rowId); } @Override public ColumnarArray getArray(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getArray(rowId); } @Override public ColumnarMap getMap(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getMap(rowId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 9d447cdc79063..5275e4a91eac0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -341,6 +341,7 @@ public final int putByteArray(int rowId, byte[] value) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; if (precision <= Decimal.MAX_INT_DIGITS()) { return Decimal.createUnsafe(getInt(rowId), precision, scale); } else if (precision <= Decimal.MAX_LONG_DIGITS()) { @@ -367,6 +368,7 @@ public void putDecimal(int rowId, Decimal value, int precision) { @Override public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; if (dictionary == null) { return arrayData().getBytesAsUTF8String(getArrayOffset(rowId), getArrayLength(rowId)); } else { @@ -384,6 +386,7 @@ public UTF8String getUTF8String(int rowId) { @Override public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; if (dictionary == null) { return arrayData().getBytes(getArrayOffset(rowId), getArrayLength(rowId)); } else { @@ -613,6 +616,7 @@ public final int appendStruct(boolean isNull) { // array offsets and lengths in the current column vector. @Override public final ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) return null; return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); } @@ -620,6 +624,7 @@ public final ColumnarArray getArray(int rowId) { // second child column vector, and puts the offsets and lengths in the current column vector. @Override public final ColumnarMap getMap(int rowId) { + if (isNullAt(rowId)) return null; return new ColumnarMap(getChild(0), getChild(1), getArrayOffset(rowId), getArrayLength(rowId)); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index f3ece538c3b80..f8e37e995a17f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -101,21 +101,25 @@ public double getDouble(int rowId) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; return accessor.getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; return accessor.getUTF8String(rowId); } @Override public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; return accessor.getBinary(rowId); } @Override public ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) return null; return accessor.getArray(rowId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 530d4d23d4eaf..ad99b450a4809 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -80,12 +80,14 @@ public abstract class ColumnVector implements AutoCloseable { public abstract boolean isNullAt(int rowId); /** - * Returns the boolean type value for rowId. + * Returns the boolean type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract boolean getBoolean(int rowId); /** - * Gets boolean type values from [rowId, rowId + count) + * Gets boolean type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public boolean[] getBooleans(int rowId, int count) { boolean[] res = new boolean[count]; @@ -96,12 +98,14 @@ public boolean[] getBooleans(int rowId, int count) { } /** - * Returns the byte type value for rowId. + * Returns the byte type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract byte getByte(int rowId); /** - * Gets byte type values from [rowId, rowId + count) + * Gets byte type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public byte[] getBytes(int rowId, int count) { byte[] res = new byte[count]; @@ -112,12 +116,14 @@ public byte[] getBytes(int rowId, int count) { } /** - * Returns the short type value for rowId. + * Returns the short type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract short getShort(int rowId); /** - * Gets short type values from [rowId, rowId + count) + * Gets short type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public short[] getShorts(int rowId, int count) { short[] res = new short[count]; @@ -128,12 +134,14 @@ public short[] getShorts(int rowId, int count) { } /** - * Returns the int type value for rowId. + * Returns the int type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract int getInt(int rowId); /** - * Gets int type values from [rowId, rowId + count) + * Gets int type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public int[] getInts(int rowId, int count) { int[] res = new int[count]; @@ -144,12 +152,14 @@ public int[] getInts(int rowId, int count) { } /** - * Returns the long type value for rowId. + * Returns the long type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract long getLong(int rowId); /** - * Gets long type values from [rowId, rowId + count) + * Gets long type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public long[] getLongs(int rowId, int count) { long[] res = new long[count]; @@ -160,12 +170,14 @@ public long[] getLongs(int rowId, int count) { } /** - * Returns the float type value for rowId. + * Returns the float type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract float getFloat(int rowId); /** - * Gets float type values from [rowId, rowId + count) + * Gets float type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public float[] getFloats(int rowId, int count) { float[] res = new float[count]; @@ -176,12 +188,14 @@ public float[] getFloats(int rowId, int count) { } /** - * Returns the double type value for rowId. + * Returns the double type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract double getDouble(int rowId); /** - * Gets double type values from [rowId, rowId + count) + * Gets double type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public double[] getDoubles(int rowId, int count) { double[] res = new double[count]; @@ -192,7 +206,7 @@ public double[] getDoubles(int rowId, int count) { } /** - * Returns the struct type value for rowId. + * Returns the struct type value for rowId. If the slot for rowId is null, it should return null. * * To support struct type, implementations must implement {@link #getChild(int)} and make this * vector a tree structure. The number of child vectors must be same as the number of fields of @@ -205,7 +219,7 @@ public final ColumnarRow getStruct(int rowId) { } /** - * Returns the array type value for rowId. + * Returns the array type value for rowId. If the slot for rowId is null, it should return null. * * To support array type, implementations must construct an {@link ColumnarArray} and return it in * this method. {@link ColumnarArray} requires a {@link ColumnVector} that stores the data of all @@ -218,13 +232,13 @@ public final ColumnarRow getStruct(int rowId) { public abstract ColumnarArray getArray(int rowId); /** - * Returns the map type value for rowId. + * Returns the map type value for rowId. If the slot for rowId is null, it should return null. * * In Spark, map type value is basically a key data array and a value data array. A key from the * key array with a index and a value from the value array with the same index contribute to * an entry of this map type value. * - * To support map type, implementations must construct an {@link ColumnarMap} and return it in + * To support map type, implementations must construct a {@link ColumnarMap} and return it in * this method. {@link ColumnarMap} requires a {@link ColumnVector} that stores the data of all * the keys of all the maps in this vector, and another {@link ColumnVector} that stores the data * of all the values of all the maps in this vector, and a pair of offset and length which @@ -233,24 +247,25 @@ public final ColumnarRow getStruct(int rowId) { public abstract ColumnarMap getMap(int ordinal); /** - * Returns the decimal type value for rowId. + * Returns the decimal type value for rowId. If the slot for rowId is null, it should return null. */ public abstract Decimal getDecimal(int rowId, int precision, int scale); /** - * Returns the string type value for rowId. Note that the returned UTF8String may point to the - * data of this column vector, please copy it if you want to keep it after this column vector is - * freed. + * Returns the string type value for rowId. If the slot for rowId is null, it should return null. + * Note that the returned UTF8String may point to the data of this column vector, please copy it + * if you want to keep it after this column vector is freed. */ public abstract UTF8String getUTF8String(int rowId); /** - * Returns the binary type value for rowId. + * Returns the binary type value for rowId. If the slot for rowId is null, it should return null. */ public abstract byte[] getBinary(int rowId); /** - * Returns the calendar interval type value for rowId. + * Returns the calendar interval type value for rowId. If the slot for rowId is null, it should + * return null. * * In Spark, calendar interval type value is basically an integer value representing the number of * months in this interval, and a long value representing the number of microseconds in this diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index b400f7f93c1fe..f2f2279590023 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -119,43 +119,36 @@ public boolean anyNull() { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getInterval(rowId); } @Override public ColumnarRow getStruct(int ordinal, int numFields) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getStruct(rowId); } @Override public ColumnarArray getArray(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getArray(rowId); } @Override public ColumnarMap getMap(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getMap(rowId); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 8fe2985836f2e..772f687526008 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -572,7 +572,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } } - testVector("String APIs", 6, StringType) { + testVector("String APIs", 7, StringType) { column => val reference = mutable.ArrayBuffer.empty[String] @@ -619,6 +619,10 @@ class ColumnarBatchSuite extends SparkFunSuite { idx += 1 assert(column.arrayData().elementsAppended == 17 + (s + s).length) + column.putNull(idx) + assert(column.getUTF8String(idx) == null) + idx += 1 + reference.zipWithIndex.foreach { v => val errMsg = "VectorType=" + column.getClass.getSimpleName assert(v._1.length == column.getArrayLength(v._2), errMsg) @@ -647,6 +651,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += new CalendarInterval(0, 2000) column.putNull(2) + assert(column.getInterval(2) == null) reference += null months.putInt(3, 20) @@ -683,6 +688,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getArray(0).numElements == 1) assert(column.getArray(1).numElements == 2) assert(column.isNullAt(2)) + assert(column.getArray(2) == null) assert(column.getArray(3).numElements == 0) assert(column.getArray(4).numElements == 3) @@ -785,6 +791,7 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(0, 0, 1) column.putArray(1, 1, 2) column.putNull(2) + assert(column.getMap(2) == null) column.putArray(3, 3, 0) column.putArray(4, 3, 3) @@ -821,6 +828,7 @@ class ColumnarBatchSuite extends SparkFunSuite { c2.putDouble(0, 3.45) column.putNull(1) + assert(column.getStruct(1) == null) c1.putInt(2, 456) c2.putDouble(2, 5.67) @@ -1261,4 +1269,68 @@ class ColumnarBatchSuite extends SparkFunSuite { batch.close() allocator.close() } + + testVector("Decimal API", 4, DecimalType.IntDecimal) { + column => + + val reference = mutable.ArrayBuffer.empty[Decimal] + + var idx = 0 + column.putDecimal(idx, new Decimal().set(10), 10) + reference += new Decimal().set(10) + idx += 1 + + column.putDecimal(idx, new Decimal().set(20), 10) + reference += new Decimal().set(20) + idx += 1 + + column.putNull(idx) + assert(column.getDecimal(idx, 10, 0) == null) + reference += null + idx += 1 + + column.putDecimal(idx, new Decimal().set(30), 10) + reference += new Decimal().set(30) + + reference.zipWithIndex.foreach { case (v, i) => + val errMsg = "VectorType=" + column.getClass.getSimpleName + assert(v == column.getDecimal(i, 10, 0), errMsg) + if (v == null) assert(column.isNullAt(i), errMsg) + } + + column.close() + } + + testVector("Binary APIs", 4, BinaryType) { + column => + + val reference = mutable.ArrayBuffer.empty[String] + var idx = 0 + column.putByteArray(idx, "Hello".getBytes(StandardCharsets.UTF_8)) + reference += "Hello" + idx += 1 + + column.putByteArray(idx, "World".getBytes(StandardCharsets.UTF_8)) + reference += "World" + idx += 1 + + column.putNull(idx) + reference += null + idx += 1 + + column.putByteArray(idx, "abc".getBytes(StandardCharsets.UTF_8)) + reference += "abc" + + reference.zipWithIndex.foreach { case (v, i) => + val errMsg = "VectorType=" + column.getClass.getSimpleName + if (v != null) { + assert(v == new String(column.getBinary(i)), errMsg) + } else { + assert(column.isNullAt(i), errMsg) + assert(column.getBinary(i) == null, errMsg) + } + } + + column.close() + } } From 969eda4a02faa7ca6cf3aff5cd10e6d51026b845 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 2 Feb 2018 11:43:22 +0800 Subject: [PATCH 273/774] [SPARK-23020][CORE] Fix another race in the in-process launcher test. First the bad news: there's an unfixable race in the launcher code. (By unfixable I mean it would take a lot more effort than this change to fix it.) The good news is that it should only affect super short lived applications, such as the one run by the flaky test, so it's possible to work around it in our test. The fix also uncovered an issue with the recently added "closeAndWait()" method; closing the connection would still possibly cause data loss, so this change waits a while for the connection to finish itself, and closes the socket if that times out. The existing connection timeout is reused so that if desired it's possible to control how long to wait. As part of that I also restored the old behavior that disconnect() would force a disconnection from the child app; the "wait for data to arrive" approach is only taken when disposing of the handle. I tested this by inserting a bunch of sleeps in the test and the socket handling code in the launcher library; with those I was able to reproduce the error from the jenkins jobs. With the changes, even with all the sleeps still in place, all tests pass. Author: Marcelo Vanzin Closes #20462 from vanzin/SPARK-23020. --- .../spark/launcher/SparkLauncherSuite.java | 40 ++++++++++++++--- .../spark/launcher/AbstractAppHandle.java | 45 ++++++++++++------- .../spark/launcher/ChildProcAppHandle.java | 2 +- .../spark/launcher/InProcessAppHandle.java | 2 +- .../apache/spark/launcher/LauncherServer.java | 30 ++++++++----- .../spark/launcher/LauncherServerSuite.java | 2 +- 6 files changed, 87 insertions(+), 34 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 1543f4fdb0162..2225591a4ff75 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -157,12 +157,24 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle handle = null; try { - handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); + synchronized (InProcessTestApp.LOCK) { + handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + // SPARK-23020: see doc for InProcessTestApp.LOCK for a description of the race. Here + // we wait until we know that the connection between the app and the launcher has been + // established before allowing the app to finish. + final SparkAppHandle _handle = handle; + eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { + assertNotEquals(SparkAppHandle.State.UNKNOWN, _handle.getState()); + }); + + InProcessTestApp.LOCK.wait(5000); + } waitFor(handle); assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); @@ -193,10 +205,26 @@ public static void main(String[] args) throws Exception { public static class InProcessTestApp { + /** + * SPARK-23020: there's a race caused by a child app finishing too quickly. This would cause + * the InProcessAppHandle to dispose of itself even before the child connection was properly + * established, so no state changes would be detected for the application and its final + * state would be LOST. + * + * It's not really possible to fix that race safely in the handle code itself without changing + * the way in-process apps talk to the launcher library, so we work around that in the test by + * synchronizing on this object. + */ + public static final Object LOCK = new Object(); + public static void main(String[] args) throws Exception { assertNotEquals(0, args.length); assertEquals(args[0], "hello"); new SparkContext().stop(); + + synchronized (LOCK) { + LOCK.notifyAll(); + } } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index 84a25a5254151..9cbebdaeb33d3 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -18,22 +18,22 @@ package org.apache.spark.launcher; import java.io.IOException; -import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; abstract class AbstractAppHandle implements SparkAppHandle { - private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName()); + private static final Logger LOG = Logger.getLogger(AbstractAppHandle.class.getName()); private final LauncherServer server; private LauncherServer.ServerConnection connection; private List listeners; private AtomicReference state; - private String appId; + private volatile String appId; private volatile boolean disposed; protected AbstractAppHandle(LauncherServer server) { @@ -44,7 +44,7 @@ protected AbstractAppHandle(LauncherServer server) { @Override public synchronized void addListener(Listener l) { if (listeners == null) { - listeners = new ArrayList<>(); + listeners = new CopyOnWriteArrayList<>(); } listeners.add(l); } @@ -71,16 +71,14 @@ public void stop() { @Override public synchronized void disconnect() { - if (!isDisposed()) { - if (connection != null) { - try { - connection.closeAndWait(); - } catch (IOException ioe) { - // no-op. - } + if (connection != null && connection.isOpen()) { + try { + connection.close(); + } catch (IOException ioe) { + // no-op. } - dispose(); } + dispose(); } void setConnection(LauncherServer.ServerConnection connection) { @@ -97,10 +95,25 @@ boolean isDisposed() { /** * Mark the handle as disposed, and set it as LOST in case the current state is not final. + * + * This method should be called only when there's a reasonable expectation that the communication + * with the child application is not needed anymore, either because the code managing the handle + * has said so, or because the child application is finished. */ synchronized void dispose() { if (!isDisposed()) { + // First wait for all data from the connection to be read. Then unregister the handle. + // Otherwise, unregistering might cause the server to be stopped and all child connections + // to be closed. + if (connection != null) { + try { + connection.waitForClose(); + } catch (IOException ioe) { + // no-op. + } + } server.unregister(this); + // Set state to LOST if not yet final. setState(State.LOST, false); this.disposed = true; @@ -127,11 +140,13 @@ void setState(State s, boolean force) { current = state.get(); } - LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", - new Object[] { current, s }); + if (s != State.LOST) { + LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", + new Object[] { current, s }); + } } - synchronized void setAppId(String appId) { + void setAppId(String appId) { this.appId = appId; fireEvent(true); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 5e3c95676ecbe..5609f8492f4f4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -112,7 +112,7 @@ void monitorChild() { } } - disconnect(); + dispose(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index b8030e0063a37..4b740d3fad20e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -66,7 +66,7 @@ synchronized void start(String appName, Method main, String[] args) { setState(State.FAILED); } - disconnect(); + dispose(); }); app.setName(String.format(THREAD_NAME_FMT, THREAD_IDS.incrementAndGet(), appName)); diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index f4ecd52fdeab8..607879fd02ea9 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -238,6 +238,7 @@ public void run() { }; ServerConnection clientConnection = new ServerConnection(client, timeout); Thread clientThread = factory.newThread(clientConnection); + clientConnection.setConnectionThread(clientThread); synchronized (clients) { clients.add(clientConnection); } @@ -290,17 +291,15 @@ class ServerConnection extends LauncherConnection { private TimerTask timeout; private volatile Thread connectionThread; - volatile AbstractAppHandle handle; + private volatile AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); this.timeout = timeout; } - @Override - public void run() { - this.connectionThread = Thread.currentThread(); - super.run(); + void setConnectionThread(Thread t) { + this.connectionThread = t; } @Override @@ -361,19 +360,30 @@ public void close() throws IOException { } /** - * Close the connection and wait for any buffered data to be processed before returning. + * Wait for the remote side to close the connection so that any pending data is processed. * This ensures any changes reported by the child application take effect. + * + * This method allows a short period for the above to happen (same amount of time as the + * connection timeout, which is configurable). This should be fine for well-behaved + * applications, where they close the connection arond the same time the app handle detects the + * app has finished. + * + * In case the connection is not closed within the grace period, this method forcefully closes + * it and any subsequent data that may arrive will be ignored. */ - public void closeAndWait() throws IOException { - close(); - + public void waitForClose() throws IOException { Thread connThread = this.connectionThread; if (Thread.currentThread() != connThread) { try { - connThread.join(); + connThread.join(getConnectionTimeout()); } catch (InterruptedException ie) { // Ignore. } + + if (connThread.isAlive()) { + LOG.log(Level.WARNING, "Timed out waiting for child connection to close."); + close(); + } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 024efac33c391..d16337a319be3 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -94,8 +94,8 @@ public void infoChanged(SparkAppHandle handle) { Message stopMsg = client.inbound.poll(30, TimeUnit.SECONDS); assertTrue(stopMsg instanceof Stop); } finally { - handle.kill(); close(client); + handle.kill(); client.clientThread.join(); } } From b3a04283f490020c13b6750de021af734c449c3a Mon Sep 17 00:00:00 2001 From: Zhan Zhang Date: Fri, 2 Feb 2018 12:21:06 +0800 Subject: [PATCH 274/774] [SPARK-23306] Fix the oom caused by contention ## What changes were proposed in this pull request? here is race condition in TaskMemoryManger, which may cause OOM. The memory released may be taken by another task because there is a gap between releaseMemory and acquireMemory, e.g., UnifiedMemoryManager, causing the OOM. if the current is the only one that can perform spill. It can happen to BytesToBytesMap, as it only spill required bytes. Loop on current consumer if it still has memory to release. ## How was this patch tested? The race contention is hard to reproduce, but the current logic seems causing the issue. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Zhan Zhang Closes #20480 from zhzhan/oom. --- .../org/apache/spark/memory/TaskMemoryManager.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 632d718062212..d07faf1da1248 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -172,10 +172,7 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { currentEntry = sortedConsumers.lastEntry(); } List cList = currentEntry.getValue(); - MemoryConsumer c = cList.remove(cList.size() - 1); - if (cList.isEmpty()) { - sortedConsumers.remove(currentEntry.getKey()); - } + MemoryConsumer c = cList.get(cList.size() - 1); try { long released = c.spill(required - got, consumer); if (released > 0) { @@ -185,6 +182,11 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { if (got >= required) { break; } + } else { + cList.remove(cList.size() - 1); + if (cList.isEmpty()) { + sortedConsumers.remove(currentEntry.getKey()); + } } } catch (ClosedByInterruptException e) { // This called by user to kill a task (e.g: speculative task). From 19c7c7ebdef6c1c7a02ebac9af6a24f521b52c37 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 1 Feb 2018 20:44:46 -0800 Subject: [PATCH 275/774] [SPARK-23301][SQL] data source column pruning should work for arbitrary expressions ## What changes were proposed in this pull request? This PR fixes a mistake in the `PushDownOperatorsToDataSource` rule, the column pruning logic is incorrect about `Project`. ## How was this patch tested? a new test case for column pruning with arbitrary expressions, and improve the existing tests to make sure the `PushDownOperatorsToDataSource` really works. Author: Wenchen Fan Closes #20476 from cloud-fan/push-down. --- .../v2/PushDownOperatorsToDataSource.scala | 53 ++++---- .../sources/v2/JavaAdvancedDataSourceV2.java | 29 ++++- .../sql/sources/v2/DataSourceV2Suite.scala | 113 ++++++++++++++++-- 3 files changed, 155 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index df034adf1e7d6..566a48394f02e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, Expression, NamedExpression, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeSet, Expression, NamedExpression, PredicateHelper} import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule @@ -81,35 +81,34 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: add more push down rules. - // TODO: nested fields pruning - def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: Seq[Attribute]): Unit = { - plan match { - case Project(projectList, child) => - val required = projectList.filter(requiredByParent.contains).flatMap(_.references) - pushDownRequiredColumns(child, required) - - case Filter(condition, child) => - val required = requiredByParent ++ condition.references - pushDownRequiredColumns(child, required) - - case DataSourceV2Relation(fullOutput, reader) => reader match { - case r: SupportsPushDownRequiredColumns => - // Match original case of attributes. - val attrMap = AttributeMap(fullOutput.zip(fullOutput)) - val requiredColumns = requiredByParent.map(attrMap) - r.pruneColumns(requiredColumns.toStructType) - case _ => - } + pushDownRequiredColumns(filterPushed, filterPushed.outputSet) + // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. + RemoveRedundantProject(filterPushed) + } + + // TODO: nested fields pruning + private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = { + plan match { + case Project(projectList, child) => + val required = projectList.flatMap(_.references) + pushDownRequiredColumns(child, AttributeSet(required)) + + case Filter(condition, child) => + val required = requiredByParent ++ condition.references + pushDownRequiredColumns(child, required) - // TODO: there may be more operators can be used to calculate required columns, we can add - // more and more in the future. - case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.output)) + case relation: DataSourceV2Relation => relation.reader match { + case reader: SupportsPushDownRequiredColumns => + val requiredColumns = relation.output.filter(requiredByParent.contains) + reader.pruneColumns(requiredColumns.toStructType) + + case _ => } - } - pushDownRequiredColumns(filterPushed, filterPushed.output) - // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. - RemoveRedundantProject(filterPushed) + // TODO: there may be more operators that can be used to calculate the required columns. We + // can add more and more in the future. + case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet)) + } } /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index d421f7d19563f..172e5d5eebcbe 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -32,11 +32,12 @@ public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, + public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, SupportsPushDownFilters { - private StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); - private Filter[] filters = new Filter[0]; + // Exposed for testing. + public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); + public Filter[] filters = new Filter[0]; @Override public StructType readSchema() { @@ -50,8 +51,26 @@ public void pruneColumns(StructType requiredSchema) { @Override public Filter[] pushFilters(Filter[] filters) { - this.filters = filters; - return new Filter[0]; + Filter[] supported = Arrays.stream(filters).filter(f -> { + if (f instanceof GreaterThan) { + GreaterThan gt = (GreaterThan) f; + return gt.attribute().equals("i") && gt.value() instanceof Integer; + } else { + return false; + } + }).toArray(Filter[]::new); + + Filter[] unsupported = Arrays.stream(filters).filter(f -> { + if (f instanceof GreaterThan) { + GreaterThan gt = (GreaterThan) f; + return !gt.attribute().equals("i") || !(gt.value() instanceof Integer); + } else { + return true; + } + }).toArray(Filter[]::new); + + this.filters = supported; + return unsupported; } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 23147fffe8a08..eccd45442a3b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -21,11 +21,13 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.SparkException +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.functions._ import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning} @@ -48,14 +50,72 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("advanced implementation") { + def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] + }.head + } + + def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader] + }.head + } + Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 10).map(i => Row(i, -i))) - checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) - checkAnswer(df.filter('i > 3), (4 until 10).map(i => Row(i, -i))) - checkAnswer(df.select('j).filter('i > 6), (7 until 10).map(i => Row(-i))) - checkAnswer(df.select('i).filter('i > 10), Nil) + + val q1 = df.select('j) + checkAnswer(q1, (0 until 10).map(i => Row(-i))) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q1) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } else { + val reader = getJavaReader(q1) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } + + val q2 = df.filter('i > 3) + checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q2) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + } else { + val reader = getJavaReader(q2) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + } + + val q3 = df.select('i).filter('i > 6) + checkAnswer(q3, (7 until 10).map(i => Row(i))) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q3) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i")) + } else { + val reader = getJavaReader(q3) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i")) + } + + val q4 = df.select('j).filter('j < -10) + checkAnswer(q4, Nil) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q4) + // 'j < 10 is not supported by the testing data source. + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } else { + val reader = getJavaReader(q4) + // 'j < 10 is not supported by the testing data source. + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } } } } @@ -223,6 +283,39 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val df2 = df.select(($"i" + 1).as("k"), $"j") checkAnswer(df.join(df2, "j"), (0 until 10).map(i => Row(-i, i, i + 1))) } + + test("SPARK-23301: column pruning with arbitrary expressions") { + def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] + }.head + } + + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + + val q1 = df.select('i + 1) + checkAnswer(q1, (1 until 11).map(i => Row(i))) + val reader1 = getReader(q1) + assert(reader1.requiredSchema.fieldNames === Seq("i")) + + val q2 = df.select(lit(1)) + checkAnswer(q2, (0 until 10).map(i => Row(1))) + val reader2 = getReader(q2) + assert(reader2.requiredSchema.isEmpty) + + // 'j === 1 can't be pushed down, but we should still be able do column pruning + val q3 = df.filter('j === -1).select('j * 2) + checkAnswer(q3, Row(-2)) + val reader3 = getReader(q3) + assert(reader3.filters.isEmpty) + assert(reader3.requiredSchema.fieldNames === Seq("j")) + + // column pruning should work with other operators. + val q4 = df.sort('i).limit(1).select('i + 1) + checkAnswer(q4, Row(1)) + val reader4 = getReader(q4) + assert(reader4.requiredSchema.fieldNames === Seq("i")) + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { @@ -270,8 +363,12 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { } override def pushFilters(filters: Array[Filter]): Array[Filter] = { - this.filters = filters - Array.empty + val (supported, unsupported) = filters.partition { + case GreaterThan("i", _: Int) => true + case _ => false + } + this.filters = supported + unsupported } override def pushedFilters(): Array[Filter] = filters From b9503fcbb3f4a3ce263164d1f11a8e99b9ca5710 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 2 Feb 2018 22:43:28 +0800 Subject: [PATCH 276/774] [SPARK-23312][SQL] add a config to turn off vectorized cache reader ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-23309 reported a performance regression about cached table in Spark 2.3. While the investigating is still going on, this PR adds a conf to turn off the vectorized cache reader, to unblock the 2.3 release. ## How was this patch tested? a new test Author: Wenchen Fan Closes #20483 from cloud-fan/cache. --- .../org/apache/spark/sql/internal/SQLConf.scala | 8 ++++++++ .../columnar/InMemoryTableScanExec.scala | 2 +- .../org/apache/spark/sql/CachedTableSuite.scala | 15 +++++++++++++-- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 90654e67457e0..1e2501ee7757d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -141,6 +141,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val CACHE_VECTORIZED_READER_ENABLED = + buildConf("spark.sql.inMemoryColumnarStorage.enableVectorizedReader") + .doc("Enables vectorized reader for columnar caching.") + .booleanConf + .createWithDefault(true) + val COLUMN_VECTOR_OFFHEAP_ENABLED = buildConf("spark.sql.columnVector.offheap.enabled") .internal() @@ -1272,6 +1278,8 @@ class SQLConf extends Serializable with Logging { def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) + def cacheVectorizedReaderEnabled: Boolean = getConf(CACHE_VECTORIZED_READER_ENABLED) + def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) def targetPostShuffleInputSize: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index c167f1e7dc621..e972f8b30d87c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -54,7 +54,7 @@ case class InMemoryTableScanExec( override val supportsBatch: Boolean = { // In the initial implementation, for ease of review // support only primitive data types and # of fields is less than wholeStageMaxNumFields - relation.schema.fields.forall(f => f.dataType match { + conf.cacheVectorizedReaderEnabled && relation.schema.fields.forall(f => f.dataType match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true case _ => false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 72fe0f42801f1..9f27fa09127af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -21,8 +21,6 @@ import scala.collection.mutable.HashSet import scala.concurrent.duration._ import scala.language.postfixOps -import org.scalatest.concurrent.Eventually._ - import org.apache.spark.CleanerListener import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression @@ -30,6 +28,7 @@ import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{AccumulatorContext, Utils} @@ -782,4 +781,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assert(getNumInMemoryRelations(cachedDs2) == 1) } } + + test("SPARK-23312: vectorized cache reader can be disabled") { + Seq(true, false).foreach { vectorized => + withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { + val df = spark.range(10).cache() + df.queryExecution.executedPlan.foreach { + case i: InMemoryTableScanExec => assert(i.supportsBatch == vectorized) + case _ => + } + } + } + } } From dd52681bf542386711609cb037a55b3d264eddef Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 2 Feb 2018 09:10:50 -0600 Subject: [PATCH 277/774] [SPARK-23253][CORE][SHUFFLE] Only write shuffle temporary index file when there is not an existing one ## What changes were proposed in this pull request? Shuffle Index temporay file is used for atomic creating shuffle index file, it is not needed when the index file already exists after another attempts of same task had it done. ## How was this patch tested? exitsting ut cc squito Author: Kent Yao Closes #20422 from yaooqinn/SPARK-23253. --- .../shuffle/IndexShuffleBlockResolver.scala | 27 ++++----- .../sort/IndexShuffleBlockResolverSuite.scala | 59 ++++++++++++++----- 2 files changed, 56 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 266ee42e39cca..c5f3f6e2b42b6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -141,19 +141,6 @@ private[spark] class IndexShuffleBlockResolver( val indexFile = getIndexFile(shuffleId, mapId) val indexTmp = Utils.tempFileWith(indexFile) try { - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) - Utils.tryWithSafeFinally { - // We take in lengths of each block, need to convert it to offsets. - var offset = 0L - out.writeLong(offset) - for (length <- lengths) { - offset += length - out.writeLong(offset) - } - } { - out.close() - } - val dataFile = getDataFile(shuffleId, mapId) // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure // the following check and rename are atomic. @@ -166,10 +153,22 @@ private[spark] class IndexShuffleBlockResolver( if (dataTmp != null && dataTmp.exists()) { dataTmp.delete() } - indexTmp.delete() } else { // This is the first successful attempt in writing the map outputs for this task, // so override any existing index and data files with the ones we wrote. + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) + Utils.tryWithSafeFinally { + // We take in lengths of each block, need to convert it to offsets. + var offset = 0L + out.writeLong(offset) + for (length <- lengths) { + offset += length + out.writeLong(offset) + } + } { + out.close() + } + if (indexFile.exists()) { indexFile.delete() } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index d21ce73f4021e..4ce379b76b551 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.sort -import java.io.{File, FileInputStream, FileOutputStream} +import java.io.{DataInputStream, File, FileInputStream, FileOutputStream} import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS @@ -64,6 +64,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } test("commit shuffle files multiple times") { + val shuffleId = 1 + val mapId = 2 + val idxName = s"shuffle_${shuffleId}_${mapId}_0.index" val resolver = new IndexShuffleBlockResolver(conf, blockManager) val lengths = Array[Long](10, 0, 20) val dataTmp = File.createTempFile("shuffle", null, tempDir) @@ -73,9 +76,13 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out.close() } - resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp) + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) - val dataFile = resolver.getDataFile(1, 2) + val indexFile = new File(tempDir.getAbsolutePath, idxName) + val dataFile = resolver.getDataFile(shuffleId, mapId) + + assert(indexFile.exists()) + assert(indexFile.length() === (lengths.length + 1) * 8) assert(dataFile.exists()) assert(dataFile.length() === 30) assert(!dataTmp.exists()) @@ -89,7 +96,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out2.close() } - resolver.writeIndexFileAndCommit(1, 2, lengths2, dataTmp2) + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths2, dataTmp2) + + assert(indexFile.length() === (lengths.length + 1) * 8) assert(lengths2.toSeq === lengths.toSeq) assert(dataFile.exists()) assert(dataFile.length() === 30) @@ -97,18 +106,27 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa // The dataFile should be the previous one val firstByte = new Array[Byte](1) - val in = new FileInputStream(dataFile) + val dataIn = new FileInputStream(dataFile) Utils.tryWithSafeFinally { - in.read(firstByte) + dataIn.read(firstByte) } { - in.close() + dataIn.close() } assert(firstByte(0) === 0) + // The index file should not change + val indexIn = new DataInputStream(new FileInputStream(indexFile)) + Utils.tryWithSafeFinally { + indexIn.readLong() // the first offset is always 0 + assert(indexIn.readLong() === 10, "The index file should not change") + } { + indexIn.close() + } + // remove data file dataFile.delete() - val lengths3 = Array[Long](10, 10, 15) + val lengths3 = Array[Long](7, 10, 15, 3) val dataTmp3 = File.createTempFile("shuffle", null, tempDir) val out3 = new FileOutputStream(dataTmp3) Utils.tryWithSafeFinally { @@ -117,20 +135,29 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out3.close() } - resolver.writeIndexFileAndCommit(1, 2, lengths3, dataTmp3) + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths3, dataTmp3) + assert(indexFile.length() === (lengths3.length + 1) * 8) assert(lengths3.toSeq != lengths.toSeq) assert(dataFile.exists()) assert(dataFile.length() === 35) - assert(!dataTmp2.exists()) + assert(!dataTmp3.exists()) - // The dataFile should be the previous one - val firstByte2 = new Array[Byte](1) - val in2 = new FileInputStream(dataFile) + // The dataFile should be the new one, since we deleted the dataFile from the first attempt + val dataIn2 = new FileInputStream(dataFile) + Utils.tryWithSafeFinally { + dataIn2.read(firstByte) + } { + dataIn2.close() + } + assert(firstByte(0) === 2) + + // The index file should be updated, since we deleted the dataFile from the first attempt + val indexIn2 = new DataInputStream(new FileInputStream(indexFile)) Utils.tryWithSafeFinally { - in2.read(firstByte2) + indexIn2.readLong() // the first offset is always 0 + assert(indexIn2.readLong() === 7, "The index file should be updated") } { - in2.close() + indexIn2.close() } - assert(firstByte2(0) === 2) } } From eefec93d193d43d5b71b8f8a4b1060286da971dd Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 2 Feb 2018 10:17:51 -0600 Subject: [PATCH 278/774] [SPARK-23295][BUILD][MINOR] Exclude Waring message when generating versions in make-distribution.sh ## What changes were proposed in this pull request? When we specified a wrong profile to make a spark distribution, such as `-Phadoop1000`, we will get an odd package named like `spark-[WARNING] The requested profile "hadoop1000" could not be activated because it does not exist.-bin-hadoop-2.7.tgz`, which actually should be `"spark-$VERSION-bin-$NAME.tgz"` ## How was this patch tested? ### before ``` build/mvn help:evaluate -Dexpression=scala.binary.version -Phadoop1000 2>/dev/null | grep -v "INFO" | tail -n 1 [WARNING] The requested profile "hadoop1000" could not be activated because it does not exist. ``` ``` build/mvn help:evaluate -Dexpression=project.version -Phadoop1000 2>/dev/null | grep -v "INFO" | tail -n 1 [WARNING] The requested profile "hadoop1000" could not be activated because it does not exist. ``` ### after ``` build/mvn help:evaluate -Dexpression=project.version -Phadoop1000 2>/dev/null | grep -v "INFO" | grep -v "WARNING" | tail -n 1 2.4.0-SNAPSHOT ``` ``` build/mvn help:evaluate -Dexpression=scala.binary.version -Dscala.binary.version=2.11.1 2>/dev/null | grep -v "INFO" | grep -v "WARNING" | tail -n 1 2.11.1 ``` cloud-fan srowen Author: Kent Yao Closes #20469 from yaooqinn/dist-minor. --- dev/make-distribution.sh | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 7245163ea2a51..8b02446b2f15f 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -117,15 +117,21 @@ if [ ! "$(command -v "$MVN")" ] ; then exit -1; fi -VERSION=$("$MVN" help:evaluate -Dexpression=project.version $@ 2>/dev/null | grep -v "INFO" | tail -n 1) +VERSION=$("$MVN" help:evaluate -Dexpression=project.version $@ 2>/dev/null\ + | grep -v "INFO"\ + | grep -v "WARNING"\ + | tail -n 1) SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version $@ 2>/dev/null\ | grep -v "INFO"\ + | grep -v "WARNING"\ | tail -n 1) SPARK_HADOOP_VERSION=$("$MVN" help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ | grep -v "INFO"\ + | grep -v "WARNING"\ | tail -n 1) SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ | grep -v "INFO"\ + | grep -v "WARNING"\ | fgrep --count "hive";\ # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ # because we use "set -o pipefail" From eaf35de2471fac4337dd2920026836d52b1ec847 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 2 Feb 2018 17:37:51 -0800 Subject: [PATCH 279/774] [SPARK-23064][SS][DOCS] Stream-stream joins Documentation - follow up ## What changes were proposed in this pull request? Further clarification of caveats in using stream-stream outer joins. ## How was this patch tested? N/A Author: Tathagata Das Closes #20494 from tdas/SPARK-23064-2. --- docs/structured-streaming-programming-guide.md | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 62589a62ac4c4..48d6d0b542cc0 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1346,10 +1346,20 @@ joined <- join( -However, note that the outer NULL results will be generated with a delay (depends on the specified -watermark delay and the time range condition) because the engine has to wait for that long to ensure + +There are a few points to note regarding outer joins. + +- *The outer NULL results will be generated with a delay that depends on the specified watermark +delay and the time range condition.* This is because the engine has to wait for that long to ensure there were no matches and there will be no more matches in future. +- In the current implementation in the micro-batch engine, watermarks are advanced at the end of a +micro-batch, and the next micro-batch uses the updated watermark to clean up state and output +outer results. Since we trigger a micro-batch only when there is new data to be processed, the +generation of the outer result may get delayed if there no new data being received in the stream. +*In short, if any of the two input streams being joined does not receive data for a while, the +outer (both cases, left or right) output may get delayed.* + ##### Support matrix for joins in streaming queries
{executor.map(_.isBlacklisted).getOrElse(false)}for applicationfor stagefalse
From 3ff83ad43a704cc3354ef9783e711c065e2a1a22 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 2 Feb 2018 20:36:27 -0800 Subject: [PATCH 280/774] [SQL] Minor doc update: Add an example in DataFrameReader.schema ## What changes were proposed in this pull request? This patch adds a small example to the schema string definition of schema function. It isn't obvious how to use it, so an example would be useful. ## How was this patch tested? N/A - doc only. Author: Reynold Xin Closes #20491 from rxin/schema-doc. --- .../src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 46b5f54a33f74..fcaf8d618c168 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -74,6 +74,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * infer the input schema automatically from data. By specifying the schema here, the underlying * data source can skip the schema inference step, and thus speed up data loading. * + * {{{ + * spark.read.schema("a INT, b STRING, c DOUBLE").csv("test.csv") + * }}} + * * @since 2.3.0 */ def schema(schemaString: String): DataFrameReader = { From fe73cb4b439169f16cc24cd851a11fd398ce7edf Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 2 Feb 2018 20:49:08 -0800 Subject: [PATCH 281/774] [SPARK-23317][SQL] rename ContinuousReader.setOffset to setStartOffset ## What changes were proposed in this pull request? In the document of `ContinuousReader.setOffset`, we say this method is used to specify the start offset. We also have a `ContinuousReader.getStartOffset` to get the value back. I think it makes more sense to rename `ContinuousReader.setOffset` to `setStartOffset`. ## How was this patch tested? N/A Author: Wenchen Fan Closes #20486 from cloud-fan/rename. --- .../org/apache/spark/sql/kafka010/KafkaContinuousReader.scala | 2 +- .../sql/sources/v2/reader/streaming/ContinuousReader.java | 4 ++-- .../execution/streaming/continuous/ContinuousExecution.scala | 2 +- .../streaming/continuous/ContinuousRateStreamSource.scala | 2 +- .../spark/sql/execution/streaming/RateSourceV2Suite.scala | 2 +- .../sql/streaming/sources/StreamingDataSourceV2Suite.scala | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 41c443bc12120..b049a054cb40e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -71,7 +71,7 @@ class KafkaContinuousReader( override def readSchema: StructType = KafkaOffsetReader.kafkaSchema private var offset: Offset = _ - override def setOffset(start: ju.Optional[Offset]): Unit = { + override def setStartOffset(start: ju.Optional[Offset]): Unit = { offset = start.orElse { val offsets = initialOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index d1d1e7ffd1dd4..7fe7f00ac2fa8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -51,12 +51,12 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader * start from the first record after the provided offset, or from an implementation-defined * inferred starting point if no offset is provided. */ - void setOffset(Optional start); + void setStartOffset(Optional start); /** * Return the specified or inferred start offset for this reader. * - * @throws IllegalStateException if setOffset has not been called + * @throws IllegalStateException if setStartOffset has not been called */ Offset getStartOffset(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 08c81419a9d34..ed22b9100497a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -181,7 +181,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) - reader.setOffset(java.util.Optional.ofNullable(realOffset.orNull)) + reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) new StreamingDataSourceV2Relation(newOutput, reader) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 0eaaa4889ba9e..b63d8d3e20650 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -61,7 +61,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) private var offset: Offset = _ - override def setOffset(offset: java.util.Optional[Offset]): Unit = { + override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index 3158995ec62f1..0d68d9c3138aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -160,7 +160,7 @@ class RateSourceV2Suite extends StreamTest { test("continuous data") { val reader = new RateStreamContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setOffset(Optional.empty()) + reader.setStartOffset(Optional.empty()) val tasks = reader.createDataReaderFactories() assert(tasks.size == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index cb873ab688e96..51f44fa6285e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -43,7 +43,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader { def readSchema(): StructType = StructType(Seq()) def stop(): Unit = {} def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) - def setOffset(start: Optional[Offset]): Unit = {} + def setStartOffset(start: Optional[Offset]): Unit = {} def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = { throw new IllegalStateException("fake source - cannot actually read") From 63b49fa2e599080c2ba7d5189f9dde20a2e01fb4 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Sat, 3 Feb 2018 00:02:03 -0800 Subject: [PATCH 282/774] [SPARK-23311][SQL][TEST] add FilterFunction test case for test CombineTypedFilters ## What changes were proposed in this pull request? In the current test case for CombineTypedFilters, we lack the test of FilterFunction, so let's add it. In addition, in TypedFilterOptimizationSuite's existing test cases, Let's extract a common LocalRelation. ## How was this patch tested? add new test cases. Author: caoxuewen Closes #20482 from heary-cao/TypedFilterOptimizationSuite. --- .../spark/sql/catalyst/dsl/package.scala | 3 + .../TypedFilterOptimizationSuite.scala | 95 ++++++++++++++++--- 2 files changed, 84 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 59cb26d5e6c36..efb2eba655e15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions +import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -301,6 +302,8 @@ package object dsl { def filter[T : Encoder](func: T => Boolean): LogicalPlan = TypedFilter(func, logicalPlan) + def filter[T : Encoder](func: FilterFunction[T]): LogicalPlan = TypedFilter(func, logicalPlan) + def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan) def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala index 56f096f3ecf8c..5fc99a3a57c0f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -38,18 +39,19 @@ class TypedFilterOptimizationSuite extends PlanTest { implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() + val testRelation = LocalRelation('_1.int, '_2.int) + test("filter after serialize with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: (Int, Int)) => i._1 > 0 - val query = input + val query = testRelation .deserialize[(Int, Int)] .serialize[(Int, Int)] .filter(f).analyze val optimized = Optimize.execute(query) - val expected = input + val expected = testRelation .deserialize[(Int, Int)] .where(callFunction(f, BooleanType, 'obj)) .serialize[(Int, Int)].analyze @@ -58,10 +60,9 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter after serialize with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: OtherTuple) => i._1 > 0 - val query = input + val query = testRelation .deserialize[(Int, Int)] .serialize[(Int, Int)] .filter(f).analyze @@ -70,17 +71,16 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter before deserialize with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: (Int, Int)) => i._1 > 0 - val query = input + val query = testRelation .filter(f) .deserialize[(Int, Int)] .serialize[(Int, Int)].analyze val optimized = Optimize.execute(query) - val expected = input + val expected = testRelation .deserialize[(Int, Int)] .where(callFunction(f, BooleanType, 'obj)) .serialize[(Int, Int)].analyze @@ -89,10 +89,9 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter before deserialize with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: OtherTuple) => i._1 > 0 - val query = input + val query = testRelation .filter(f) .deserialize[(Int, Int)] .serialize[(Int, Int)].analyze @@ -101,21 +100,89 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("back to back filter with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f1 = (i: (Int, Int)) => i._1 > 0 val f2 = (i: (Int, Int)) => i._2 > 0 - val query = input.filter(f1).filter(f2).analyze + val query = testRelation.filter(f1).filter(f2).analyze val optimized = Optimize.execute(query) assert(optimized.collect { case t: TypedFilter => t }.length == 1) } test("back to back filter with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f1 = (i: (Int, Int)) => i._1 > 0 val f2 = (i: OtherTuple) => i._2 > 0 - val query = input.filter(f1).filter(f2).analyze + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("back to back FilterFunction with the same object type") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } + + test("back to back FilterFunction with different object types") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = new FilterFunction[OtherTuple] { + override def call(value: OtherTuple): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("FilterFunction and filter with the same object type") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = (i: (Int, Int)) => i._2 > 0 + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } + + test("FilterFunction and filter with different object types") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = (i: OtherTuple) => i._2 > 0 + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("filter and FilterFunction with the same object type") { + val f2 = (i: (Int, Int)) => i._1 > 0 + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } + + test("filter and FilterFunction with different object types") { + val f2 = (i: (Int, Int)) => i._1 > 0 + val f1 = new FilterFunction[OtherTuple] { + override def call(value: OtherTuple): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze val optimized = Optimize.execute(query) assert(optimized.collect { case t: TypedFilter => t }.length == 2) } From 522e0b1866a0298669c83de5a47ba380dc0b7c84 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 3 Feb 2018 00:04:00 -0800 Subject: [PATCH 283/774] [SPARK-23305][SQL][TEST] Test `spark.sql.files.ignoreMissingFiles` for all file-based data sources ## What changes were proposed in this pull request? Like Parquet, all file-based data source handles `spark.sql.files.ignoreMissingFiles` correctly. We had better have a test coverage for feature parity and in order to prevent future accidental regression for all data sources. ## How was this patch tested? Pass Jenkins with a newly added test case. Author: Dongjoon Hyun Closes #20479 from dongjoon-hyun/SPARK-23305. --- .../spark/sql/FileBasedDataSourceSuite.scala | 37 +++++++++++++++++++ .../parquet/ParquetQuerySuite.scala | 33 ----------------- 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index c272c99ae45a8..640d6b1583663 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { @@ -92,4 +96,37 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { } } } + + allFileBasedDataSources.foreach { format => + testQuietly(s"Enabling/disabling ignoreMissingFiles using $format") { + def testIgnoreMissingFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + Seq("0").toDF("a").write.format(format).save(new Path(basePath, "first").toString) + Seq("1").toDF("a").write.format(format).save(new Path(basePath, "second").toString) + val thirdPath = new Path(basePath, "third") + Seq("2").toDF("a").write.format(format).save(thirdPath.toString) + val df = spark.read.format(format).load( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + + val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) + assert(fs.delete(thirdPath, true)) + checkAnswer(df, Seq(Row("0"), Row("1"))) + } + } + + withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "true") { + testIgnoreMissingFiles() + } + + withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "false") { + val exception = intercept[SparkException] { + testIgnoreMissingFiles() + } + assert(exception.getMessage().contains("does not exist")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 6ad88ed997ce7..55b0f729be8ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -355,39 +355,6 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } - testQuietly("Enabling/disabling ignoreMissingFiles") { - def testIgnoreMissingFiles(): Unit = { - withTempDir { dir => - val basePath = dir.getCanonicalPath - spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) - spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) - val thirdPath = new Path(basePath, "third") - spark.range(2, 3).toDF("a").write.parquet(thirdPath.toString) - val df = spark.read.parquet( - new Path(basePath, "first").toString, - new Path(basePath, "second").toString, - new Path(basePath, "third").toString) - - val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) - fs.delete(thirdPath, true) - checkAnswer( - df, - Seq(Row(0), Row(1))) - } - } - - withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "true") { - testIgnoreMissingFiles() - } - - withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "false") { - val exception = intercept[SparkException] { - testIgnoreMissingFiles() - } - assert(exception.getMessage().contains("does not exist")) - } - } - /** * this is part of test 'Enabling/disabling ignoreCorruptFiles' but run in a loop * to increase the chance of failure From 4aaa7d40bf495317e740b6d6f9c2a55dfd03521b Mon Sep 17 00:00:00 2001 From: Shashwat Anand Date: Sat, 3 Feb 2018 10:31:04 -0800 Subject: [PATCH 284/774] [MINOR][DOC] Use raw triple double quotes around docstrings where there are occurrences of backslashes. From [PEP 257](https://www.python.org/dev/peps/pep-0257/): > For consistency, always use """triple double quotes""" around docstrings. Use r"""raw triple double quotes""" if you use any backslashes in your docstrings. For Unicode docstrings, use u"""Unicode triple-quoted strings""". For example, this is what help (kafka_wordcount) shows: ``` DESCRIPTION Counts words in UTF8 encoded, ' ' delimited text received from the network every second. Usage: kafka_wordcount.py To run this on your local machine, you need to setup Kafka and create a producer first, see http://kafka.apache.org/documentation.html#quickstart and then run the example `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py localhost:2181 test` ``` This is what it shows, after the fix: ``` DESCRIPTION Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: kafka_wordcount.py To run this on your local machine, you need to setup Kafka and create a producer first, see http://kafka.apache.org/documentation.html#quickstart and then run the example `$ bin/spark-submit --jars \ external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar \ examples/src/main/python/streaming/kafka_wordcount.py \ localhost:2181 test` ``` The thing worth noticing is no linebreak here in the help. ## What changes were proposed in this pull request? Change triple double quotes to raw triple double quotes when there are occurrences of backslashes in docstrings. ## How was this patch tested? Manually as this is a doc fix. Author: Shashwat Anand Closes #20497 from ashashwat/docstring-fixes. --- .../main/python/sql/streaming/structured_network_wordcount.py | 2 +- .../sql/streaming/structured_network_wordcount_windowed.py | 2 +- examples/src/main/python/streaming/direct_kafka_wordcount.py | 2 +- examples/src/main/python/streaming/flume_wordcount.py | 2 +- examples/src/main/python/streaming/kafka_wordcount.py | 2 +- examples/src/main/python/streaming/network_wordcount.py | 2 +- .../src/main/python/streaming/network_wordjoinsentiments.py | 2 +- examples/src/main/python/streaming/sql_network_wordcount.py | 2 +- .../src/main/python/streaming/stateful_network_wordcount.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount.py b/examples/src/main/python/sql/streaming/structured_network_wordcount.py index afde2550587ca..c3284c1d01017 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network. Usage: structured_network_wordcount.py and describe the TCP server that Structured Streaming diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py index 02a7d3363d780..db672551504b5 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network over a sliding window of configurable duration. Each line from the network is tagged with a timestamp that is used to determine the windows into which it falls. diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py index 7097f7f4502bd..425df309011a0 100644 --- a/examples/src/main/python/streaming/direct_kafka_wordcount.py +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text directly received from Kafka in every 2 seconds. Usage: direct_kafka_wordcount.py diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py index d75bc6daac138..5d6e6dc36d6f9 100644 --- a/examples/src/main/python/streaming/flume_wordcount.py +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: flume_wordcount.py diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index 8d697f620f467..704f6602e2297 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: kafka_wordcount.py diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py index 2b48bcfd55db0..9010fafb425e6 100644 --- a/examples/src/main/python/streaming/network_wordcount.py +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: network_wordcount.py and describe the TCP server that Spark Streaming would connect to receive data. diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py index b309d9fad33f5..d51a380a5d5f9 100644 --- a/examples/src/main/python/streaming/network_wordjoinsentiments.py +++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Shows the most positive words in UTF8 encoded, '\n' delimited text directly received the network every 5 seconds. The streaming data is joined with a static RDD of the AFINN word list (http://neuro.imm.dtu.dk/wiki/AFINN) diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py index 398ac8d2d8f5e..7f12281c0e3fe 100644 --- a/examples/src/main/python/streaming/sql_network_wordcount.py +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the network every second. diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index f8bbc659c2ea7..d7bb61e729f18 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. From 551dff2bccb65e9b3f77b986f167aec90d9a6016 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 3 Feb 2018 10:40:21 -0800 Subject: [PATCH 285/774] [SPARK-21658][SQL][PYSPARK] Revert "[] Add default None for value in na.replace in PySpark" This reverts commit 0fcde87aadc9a92e138f11583119465ca4b5c518. See the discussion in [SPARK-21658](https://issues.apache.org/jira/browse/SPARK-21658), [SPARK-19454](https://issues.apache.org/jira/browse/SPARK-19454) and https://github.com/apache/spark/pull/16793 Author: hyukjinkwon Closes #20496 from HyukjinKwon/revert-SPARK-21658. --- python/pyspark/sql/dataframe.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1496cba91b90e..2e55407b5397b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1577,16 +1577,6 @@ def replace(self, to_replace, value=None, subset=None): |null| null|null| +----+------+----+ - >>> df4.na.replace('Alice').show() - +----+------+----+ - | age|height|name| - +----+------+----+ - | 10| 80|null| - | 5| null| Bob| - |null| null| Tom| - |null| null|null| - +----+------+----+ - >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| @@ -2055,7 +2045,7 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ - def replace(self, to_replace, value=None, subset=None): + def replace(self, to_replace, value, subset=None): return self.df.replace(to_replace, value, subset) replace.__doc__ = DataFrame.replace.__doc__ From 715047b02df0ac9ec16ab2a73481ab7f36ffc6ca Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 4 Feb 2018 17:53:31 +0900 Subject: [PATCH 286/774] [SPARK-23256][ML][PYTHON] Add columnSchema method to PySpark image reader ## What changes were proposed in this pull request? This PR proposes to add `columnSchema` in Python side too. ```python >>> from pyspark.ml.image import ImageSchema >>> ImageSchema.columnSchema.simpleString() 'struct' ``` ## How was this patch tested? Manually tested and unittest was added in `python/pyspark/ml/tests.py`. Author: hyukjinkwon Closes #20475 from HyukjinKwon/SPARK-23256. --- python/pyspark/ml/image.py | 20 +++++++++++++++++++- python/pyspark/ml/tests.py | 1 + 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 2d86c7f03860c..45c936645f2a8 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -40,6 +40,7 @@ class _ImageSchema(object): def __init__(self): self._imageSchema = None self._ocvTypes = None + self._columnSchema = None self._imageFields = None self._undefinedImageType = None @@ -49,7 +50,7 @@ def imageSchema(self): Returns the image schema. :return: a :class:`StructType` with a single column of images - named "image" (nullable). + named "image" (nullable) and having the same type returned by :meth:`columnSchema`. .. versionadded:: 2.3.0 """ @@ -75,6 +76,23 @@ def ocvTypes(self): self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()) return self._ocvTypes + @property + def columnSchema(self): + """ + Returns the schema for the image column. + + :return: a :class:`StructType` for image column, + ``struct``. + + .. versionadded:: 2.4.0 + """ + + if self._columnSchema is None: + ctx = SparkContext._active_spark_context + jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.columnSchema() + self._columnSchema = _parse_datatype_json_string(jschema.json()) + return self._columnSchema + @property def imageFields(self): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 1af2b91da900d..75d04785a0710 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1852,6 +1852,7 @@ def test_read_images(self): self.assertEqual(len(array), first_row[1]) self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) self.assertEqual(df.schema, ImageSchema.imageSchema) + self.assertEqual(df.schema["image"].dataType, ImageSchema.columnSchema) expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24} self.assertEqual(ImageSchema.ocvTypes, expected) expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data'] From 6fb3fd15365d43733aefdb396db205d7ccf57f75 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 4 Feb 2018 09:15:48 -0800 Subject: [PATCH 287/774] [SPARK-22036][SQL][FOLLOWUP] Fix decimalArithmeticOperations.sql ## What changes were proposed in this pull request? Fix decimalArithmeticOperations.sql test ## How was this patch tested? N/A Author: Yuming Wang Author: wangyum Author: Yuming Wang Closes #20498 from wangyum/SPARK-22036. --- .../native/decimalArithmeticOperations.sql | 6 +- .../decimalArithmeticOperations.sql.out | 140 ++++++++++-------- 2 files changed, 80 insertions(+), 66 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql index c6d8a49d4b93a..9be7fcdadfea8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -48,8 +48,9 @@ select 12345678901234567890.0 * 12345678901234567890.0; select 1e35 / 0.1; -- arithmetic operations causing a precision loss are truncated +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; select 123456789123456789.1234567890 * 1.123456789123456789; -select 0.001 / 9876543210987654321098765432109876543.2 +select 12345678912345.123456789123 / 0.000000012345678; -- return NULL instead of rounding, according to old Spark versions' behavior set spark.sql.decimalOperations.allowPrecisionLoss=false; @@ -74,7 +75,8 @@ select 12345678901234567890.0 * 12345678901234567890.0; select 1e35 / 0.1; -- arithmetic operations causing a precision loss return NULL +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; select 123456789123456789.1234567890 * 1.123456789123456789; -select 0.001 / 9876543210987654321098765432109876543.2 +select 12345678912345.123456789123 / 0.000000012345678; drop table decimals_test; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out index 4d70fe19d539f..6bfdb84548d4d 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 36 -- !query 0 @@ -146,146 +146,158 @@ NULL -- !query 17 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 -- !query 17 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)> -- !query 17 output -138698367904130467.654320988515622621 +10012345678912345678912345678911.246907 -- !query 18 -select 0.001 / 9876543210987654321098765432109876543.2 - -set spark.sql.decimalOperations.allowPrecisionLoss=false +select 123456789123456789.1234567890 * 1.123456789123456789 -- !query 18 schema -struct<> +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> -- !query 18 output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input 'spark' expecting (line 3, pos 4) - -== SQL == -select 0.001 / 9876543210987654321098765432109876543.2 - -set spark.sql.decimalOperations.allowPrecisionLoss=false -----^^^ +138698367904130467.654320988515622621 -- !query 19 -select id, a+b, a-b, a*b, a/b from decimals_test order by id +select 12345678912345.123456789123 / 0.000000012345678 -- !query 19 schema -struct +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)> -- !query 19 output -1 1099 -899 99900 0.1001 -2 24690.246 0 152402061.885129 1 -3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 -4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 +1000000073899961059796.725866332 -- !query 20 -select id, a*10, b/10 from decimals_test order by id +set spark.sql.decimalOperations.allowPrecisionLoss=false -- !query 20 schema -struct +struct -- !query 20 output -1 1000 99.9 -2 123451.23 1234.5123 -3 1.234567891011 123.41 -4 1234567891234567890 0.112345678912345679 +spark.sql.decimalOperations.allowPrecisionLoss false -- !query 21 -select 10.3 * 3.0 +select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 21 schema -struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> +struct -- !query 21 output -30.9 +1 1099 -899 NULL 0.1001001001001001 +2 24690.246 0 NULL 1 +3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 +4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 -- !query 22 -select 10.3000 * 3.0 +select id, a*10, b/10 from decimals_test order by id -- !query 22 schema -struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> +struct -- !query 22 output -30.9 +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.1123456789123456789 -- !query 23 -select 10.30000 * 30.0 +select 10.3 * 3.0 -- !query 23 schema -struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> -- !query 23 output -309 +30.9 -- !query 24 -select 10.300000000000000000 * 3.000000000000000000 +select 10.3000 * 3.0 -- !query 24 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> -- !query 24 output 30.9 -- !query 25 -select 10.300000000000000000 * 3.0000000000000000000 +select 10.30000 * 30.0 -- !query 25 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> -- !query 25 output -30.9 +309 -- !query 26 -select (5e36 + 0.1) + 5e36 +select 10.300000000000000000 * 3.000000000000000000 -- !query 26 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)> -- !query 26 output -NULL +30.9 -- !query 27 -select (-4e36 - 0.1) - 7e36 +select 10.300000000000000000 * 3.0000000000000000000 -- !query 27 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)> -- !query 27 output NULL -- !query 28 -select 12345678901234567890.0 * 12345678901234567890.0 +select (5e36 + 0.1) + 5e36 -- !query 28 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 28 output NULL -- !query 29 -select 1e35 / 0.1 +select (-4e36 - 0.1) - 7e36 -- !query 29 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 29 output NULL -- !query 30 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 12345678901234567890.0 * 12345678901234567890.0 -- !query 30 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> -- !query 30 output -138698367904130467.654320988515622621 +NULL -- !query 31 -select 0.001 / 9876543210987654321098765432109876543.2 - -drop table decimals_test +select 1e35 / 0.1 -- !query 31 schema -struct<> +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> -- !query 31 output -org.apache.spark.sql.catalyst.parser.ParseException +NULL -mismatched input 'table' expecting (line 3, pos 5) -== SQL == -select 0.001 / 9876543210987654321098765432109876543.2 +-- !query 32 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +-- !query 32 schema +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)> +-- !query 32 output +NULL + + +-- !query 33 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 33 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +-- !query 33 output +NULL + +-- !query 34 +select 12345678912345.123456789123 / 0.000000012345678 +-- !query 34 schema +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)> +-- !query 34 output +NULL + + +-- !query 35 drop table decimals_test ------^^^ +-- !query 35 schema +struct<> +-- !query 35 output + From a6bf3db20773ba65cbc4f2775db7bd215e78829a Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 5 Feb 2018 18:41:49 +0800 Subject: [PATCH 288/774] [SPARK-23307][WEBUI] Sort jobs/stages/tasks/queries with the completed timestamp before cleaning up them ## What changes were proposed in this pull request? Sort jobs/stages/tasks/queries with the completed timestamp before cleaning up them to make the behavior consistent with 2.2. ## How was this patch tested? - Jenkins. - Manually ran the following codes and checked the UI for jobs/stages/tasks/queries. ``` spark.ui.retainedJobs 10 spark.ui.retainedStages 10 spark.sql.ui.retainedExecutions 10 spark.ui.retainedTasks 10 ``` ``` new Thread() { override def run() { spark.range(1, 2).foreach { i => Thread.sleep(10000) } } }.start() Thread.sleep(5000) for (_ <- 1 to 20) { new Thread() { override def run() { spark.range(1, 2).foreach { i => } } }.start() } Thread.sleep(15000) spark.range(1, 2).foreach { i => } sc.makeRDD(1 to 100, 100).foreach { i => } ``` Author: Shixiong Zhu Closes #20481 from zsxwing/SPARK-23307. --- .../spark/status/AppStatusListener.scala | 13 +-- .../org/apache/spark/status/storeTypes.scala | 7 ++ .../spark/status/AppStatusListenerSuite.scala | 90 +++++++++++++++++++ .../execution/ui/SQLAppStatusListener.scala | 4 +- .../sql/execution/ui/SQLAppStatusStore.scala | 9 +- .../ui/SQLAppStatusListenerSuite.scala | 45 ++++++++++ 6 files changed, 158 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 3e34bdc0c7b63..ab01cddfca5b0 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -875,8 +875,8 @@ private[spark] class AppStatusListener( return } - val toDelete = KVUtils.viewToSeq(kvstore.view(classOf[JobDataWrapper]), - countToDelete.toInt) { j => + val view = kvstore.view(classOf[JobDataWrapper]).index("completionTime").first(0L) + val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt) { j => j.info.status != JobExecutionStatus.RUNNING && j.info.status != JobExecutionStatus.UNKNOWN } toDelete.foreach { j => kvstore.delete(j.getClass(), j.info.jobId) } @@ -888,8 +888,8 @@ private[spark] class AppStatusListener( return } - val stages = KVUtils.viewToSeq(kvstore.view(classOf[StageDataWrapper]), - countToDelete.toInt) { s => + val view = kvstore.view(classOf[StageDataWrapper]).index("completionTime").first(0L) + val stages = KVUtils.viewToSeq(view, countToDelete.toInt) { s => s.info.status != v1.StageStatus.ACTIVE && s.info.status != v1.StageStatus.PENDING } @@ -945,8 +945,9 @@ private[spark] class AppStatusListener( val countToDelete = calculateNumberToRemove(stage.savedTasks.get(), maxTasksPerStage).toInt if (countToDelete > 0) { val stageKey = Array(stage.info.stageId, stage.info.attemptNumber) - val view = kvstore.view(classOf[TaskDataWrapper]).index("stage").first(stageKey) - .last(stageKey) + val view = kvstore.view(classOf[TaskDataWrapper]) + .index(TaskIndexNames.COMPLETION_TIME) + .parent(stageKey) // Try to delete finished tasks only. val toDelete = KVUtils.viewToSeq(view, countToDelete) { t => diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index c9cb996a55fcc..412644d3657b5 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -73,6 +73,8 @@ private[spark] class JobDataWrapper( @JsonIgnore @KVIndex private def id: Int = info.jobId + @JsonIgnore @KVIndex("completionTime") + private def completionTime: Long = info.completionTime.map(_.getTime).getOrElse(-1L) } private[spark] class StageDataWrapper( @@ -90,6 +92,8 @@ private[spark] class StageDataWrapper( @JsonIgnore @KVIndex("active") private def active: Boolean = info.status == StageStatus.ACTIVE + @JsonIgnore @KVIndex("completionTime") + private def completionTime: Long = info.completionTime.map(_.getTime).getOrElse(-1L) } /** @@ -134,6 +138,7 @@ private[spark] object TaskIndexNames { final val STAGE = "stage" final val STATUS = "sta" final val TASK_INDEX = "idx" + final val COMPLETION_TIME = "ct" } /** @@ -337,6 +342,8 @@ private[spark] class TaskDataWrapper( @JsonIgnore @KVIndex(value = TaskIndexNames.ERROR, parent = TaskIndexNames.STAGE) private def error: String = if (errorMessage.isDefined) errorMessage.get else "" + @JsonIgnore @KVIndex(value = TaskIndexNames.COMPLETION_TIME, parent = TaskIndexNames.STAGE) + private def completionTime: Long = launchTime + duration } private[spark] class RDDStorageInfoWrapper(val info: RDDStorageInfo) { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 042bba7f226fd..b74d6ee2ec836 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1010,6 +1010,96 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("eviction should respect job completion time") { + val testConf = conf.clone().set(MAX_RETAINED_JOBS, 2) + val listener = new AppStatusListener(store, testConf, true) + + // Start job 1 and job 2 + time += 1 + listener.onJobStart(SparkListenerJobStart(1, time, Nil, null)) + time += 1 + listener.onJobStart(SparkListenerJobStart(2, time, Nil, null)) + + // Stop job 2 before job 1 + time += 1 + listener.onJobEnd(SparkListenerJobEnd(2, time, JobSucceeded)) + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + // Start job 3 and job 2 should be evicted. + time += 1 + listener.onJobStart(SparkListenerJobStart(3, time, Nil, null)) + assert(store.count(classOf[JobDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[JobDataWrapper], 2) + } + } + + test("eviction should respect stage completion time") { + val testConf = conf.clone().set(MAX_RETAINED_STAGES, 2) + val listener = new AppStatusListener(store, testConf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + + // Start stage 1 and stage 2 + time += 1 + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + time += 1 + stage2.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties())) + + // Stop stage 2 before stage 1 + time += 1 + stage2.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage2)) + time += 1 + stage1.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Start stage 3 and stage 2 should be evicted. + stage3.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) + assert(store.count(classOf[StageDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[StageDataWrapper], Array(2, 0)) + } + } + + test("eviction should respect task completion time") { + val testConf = conf.clone().set(MAX_RETAINED_TASKS_PER_STAGE, 2) + val listener = new AppStatusListener(store, testConf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + + // Start task 1 and task 2 + val tasks = createTasks(3, Array("1")) + tasks.take(2).foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, task)) + } + + // Stop task 2 before task 1 + time += 1 + tasks(1).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd( + SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(1), null)) + time += 1 + tasks(0).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd( + SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null)) + + // Start task 3 and task 2 should be evicted. + listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, tasks(2))) + assert(store.count(classOf[TaskDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[TaskDataWrapper], tasks(1).id) + } + } + test("driver logs") { val listener = new AppStatusListener(store, conf, true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 73a105266e1c1..53fb9a0cc21cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -332,8 +332,8 @@ class SQLAppStatusListener( return } - val toDelete = KVUtils.viewToSeq(kvstore.view(classOf[SQLExecutionUIData]), - countToDelete.toInt) { e => e.completionTime.isDefined } + val view = kvstore.view(classOf[SQLExecutionUIData]).index("completionTime").first(0L) + val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt)(_.completionTime.isDefined) toDelete.foreach { e => kvstore.delete(e.getClass(), e.executionId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala index 910f2e52fdbb3..9a76584717f42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -23,11 +23,12 @@ import java.util.Date import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import com.fasterxml.jackson.annotation.JsonIgnore import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.JobExecutionStatus import org.apache.spark.status.KVUtils.KVIndexParam -import org.apache.spark.util.kvstore.KVStore +import org.apache.spark.util.kvstore.{KVIndex, KVStore} /** * Provides a view of a KVStore with methods that make it easy to query SQL-specific state. There's @@ -90,7 +91,11 @@ class SQLExecutionUIData( * from the SQL listener instance. */ @JsonDeserialize(keyAs = classOf[JLong]) - val metricValues: Map[Long, String]) + val metricValues: Map[Long, String]) { + + @JsonIgnore @KVIndex("completionTime") + private def completionTimeIndex: Long = completionTime.map(_.getTime).getOrElse(-1L) +} class SparkPlanGraphWrapper( @KVIndexParam val executionId: Long, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 7d84f45d36bee..85face3994fd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanInfo, SQLExecution} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.internal.StaticSQLConf.UI_RETAINED_EXECUTIONS import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.status.ElementTrackingStore import org.apache.spark.status.config._ @@ -510,6 +511,50 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with } } + test("eviction should respect execution completion time") { + val conf = sparkContext.conf.clone().set(UI_RETAINED_EXECUTIONS.key, "2") + val store = new ElementTrackingStore(new InMemoryStore, conf) + val listener = new SQLAppStatusListener(conf, store, live = true) + val statusStore = new SQLAppStatusStore(store, Some(listener)) + + var time = 0 + val df = createTestDataFrame + // Start execution 1 and execution 2 + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionStart( + 1, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + time)) + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionStart( + 2, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + time)) + + // Stop execution 2 before execution 1 + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionEnd(2, time)) + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionEnd(1, time)) + + // Start execution 3 and execution 2 should be evicted. + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionStart( + 3, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + time)) + assert(statusStore.executionsCount === 2) + assert(statusStore.execution(2) === None) + } } From 03b7e120dd7ff7848c936c7a23644da5bd7219ab Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Mon, 5 Feb 2018 10:19:18 -0800 Subject: [PATCH 289/774] [SPARK-23310][CORE] Turn off read ahead input stream for unshafe shuffle reader To fix regression for TPC-DS queries Author: Sital Kedia Closes #20492 from sitalkedia/turn_off_async_inputstream. --- .../util/collection/unsafe/sort/UnsafeSorterSpillReader.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index e2f48e5508af6..71e7c7a95ebdb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -76,8 +76,10 @@ public UnsafeSorterSpillReader( SparkEnv.get() == null ? 0.5 : SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5); + // SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf regression for + // TPC-DS queries. final boolean readAheadEnabled = SparkEnv.get() != null && - SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true); + SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", false); final InputStream bs = new NioBufferedFileInputStream(file, (int) bufferSizeBytes); From c2766b07b4b9ed976931966a79c65043e81cf694 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 5 Feb 2018 14:17:11 -0800 Subject: [PATCH 290/774] [SPARK-23330][WEBUI] Spark UI SQL executions page throws NPE ## What changes were proposed in this pull request? Spark SQL executions page throws the following error and the page crashes: ``` HTTP ERROR 500 Problem accessing /SQL/. Reason: Server Error Caused by: java.lang.NullPointerException at scala.collection.immutable.StringOps$.length$extension(StringOps.scala:47) at scala.collection.immutable.StringOps.length(StringOps.scala:47) at scala.collection.IndexedSeqOptimized$class.isEmpty(IndexedSeqOptimized.scala:27) at scala.collection.immutable.StringOps.isEmpty(StringOps.scala:29) at scala.collection.TraversableOnce$class.nonEmpty(TraversableOnce.scala:111) at scala.collection.immutable.StringOps.nonEmpty(StringOps.scala:29) at org.apache.spark.sql.execution.ui.ExecutionTable.descriptionCell(AllExecutionsPage.scala:182) at org.apache.spark.sql.execution.ui.ExecutionTable.row(AllExecutionsPage.scala:155) at org.apache.spark.sql.execution.ui.ExecutionTable$$anonfun$8.apply(AllExecutionsPage.scala:204) at org.apache.spark.sql.execution.ui.ExecutionTable$$anonfun$8.apply(AllExecutionsPage.scala:204) at org.apache.spark.ui.UIUtils$$anonfun$listingTable$2.apply(UIUtils.scala:339) at org.apache.spark.ui.UIUtils$$anonfun$listingTable$2.apply(UIUtils.scala:339) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.AbstractTraversable.map(Traversable.scala:104) at org.apache.spark.ui.UIUtils$.listingTable(UIUtils.scala:339) at org.apache.spark.sql.execution.ui.ExecutionTable.toNodeSeq(AllExecutionsPage.scala:203) at org.apache.spark.sql.execution.ui.AllExecutionsPage.render(AllExecutionsPage.scala:67) at org.apache.spark.ui.WebUI$$anonfun$2.apply(WebUI.scala:82) at org.apache.spark.ui.WebUI$$anonfun$2.apply(WebUI.scala:82) at org.apache.spark.ui.JettyUtils$$anon$3.doGet(JettyUtils.scala:90) at javax.servlet.http.HttpServlet.service(HttpServlet.java:687) at javax.servlet.http.HttpServlet.service(HttpServlet.java:790) at org.eclipse.jetty.servlet.ServletHolder.handle(ServletHolder.java:848) at org.eclipse.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:584) at org.eclipse.jetty.server.handler.ContextHandler.doHandle(ContextHandler.java:1180) at org.eclipse.jetty.servlet.ServletHandler.doScope(ServletHandler.java:512) at org.eclipse.jetty.server.handler.ContextHandler.doScope(ContextHandler.java:1112) at org.eclipse.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:141) at org.eclipse.jetty.server.handler.ContextHandlerCollection.handle(ContextHandlerCollection.java:213) at org.eclipse.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:134) at org.eclipse.jetty.server.Server.handle(Server.java:534) at org.eclipse.jetty.server.HttpChannel.handle(HttpChannel.java:320) at org.eclipse.jetty.server.HttpConnection.onFillable(HttpConnection.java:251) at org.eclipse.jetty.io.AbstractConnection$ReadCallback.succeeded(AbstractConnection.java:283) at org.eclipse.jetty.io.FillInterest.fillable(FillInterest.java:108) at org.eclipse.jetty.io.SelectChannelEndPoint$2.run(SelectChannelEndPoint.java:93) at org.eclipse.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:303) at org.eclipse.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148) at org.eclipse.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136) at org.eclipse.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671) at org.eclipse.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589) at java.lang.Thread.run(Thread.java:748) ``` One of the possible reason that this page fails may be the `SparkListenerSQLExecutionStart` event get dropped before processed, so the execution description and details don't get updated. This was not a issue in 2.2 because it would ignore any job start event that arrives before the corresponding execution start event, which doesn't sound like a good decision. We shall try to handle the null values in the front page side, that is, try to give a default value when `execution.details` or `execution.description` is null. Another possible approach is not to spill the `LiveExecutionData` in `SQLAppStatusListener.update(exec: LiveExecutionData)` if `exec.details` is null. This is not ideal because this way you will not see the execution if `SparkListenerSQLExecutionStart` event is lost, because `AllExecutionsPage` only read executions from KVStore. ## How was this patch tested? After the change, the page shows the following: ![image](https://user-images.githubusercontent.com/4784782/35775480-28cc5fde-093e-11e8-8ccc-f58c2ef4a514.png) Author: Xingbo Jiang Closes #20502 from jiangxb1987/executionPage. --- .../apache/spark/sql/execution/ui/AllExecutionsPage.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 7019d98e1619f..e751ce39cd5d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -179,7 +179,7 @@ private[ui] abstract class ExecutionTable( } private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { - val details = if (execution.details.nonEmpty) { + val details = if (execution.details != null && execution.details.nonEmpty) { +details ++ @@ -190,8 +190,10 @@ private[ui] abstract class ExecutionTable( Nil } - val desc = { + val desc = if (execution.description != null && execution.description.nonEmpty) { {execution.description} + } else { + {execution.executionId} }
{desc} {details}
From f3f1e14bb73dfdd2927d95b12d7d61d22de8a0ac Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 6 Feb 2018 14:42:42 +0800 Subject: [PATCH 291/774] [SPARK-23326][WEBUI] schedulerDelay should return 0 when the task is running ## What changes were proposed in this pull request? When a task is still running, metrics like executorRunTime are not available. Then `schedulerDelay` will be almost the same as `duration` and that's confusing. This PR makes `schedulerDelay` return 0 when the task is running which is the same behavior as 2.2. ## How was this patch tested? `AppStatusUtilsSuite.schedulerDelay` Author: Shixiong Zhu Closes #20493 from zsxwing/SPARK-23326. --- .../apache/spark/status/AppStatusUtils.scala | 11 ++- .../spark/status/AppStatusUtilsSuite.scala | 89 +++++++++++++++++++ 2 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala index 341bd4e0cd016..87f434daf4870 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala @@ -17,16 +17,23 @@ package org.apache.spark.status -import org.apache.spark.status.api.v1.{TaskData, TaskMetrics} +import org.apache.spark.status.api.v1.TaskData private[spark] object AppStatusUtils { + private val TASK_FINISHED_STATES = Set("FAILED", "KILLED", "SUCCESS") + + private def isTaskFinished(task: TaskData): Boolean = { + TASK_FINISHED_STATES.contains(task.status) + } + def schedulerDelay(task: TaskData): Long = { - if (task.taskMetrics.isDefined && task.duration.isDefined) { + if (isTaskFinished(task) && task.taskMetrics.isDefined && task.duration.isDefined) { val m = task.taskMetrics.get schedulerDelay(task.launchTime.getTime(), fetchStart(task), task.duration.get, m.executorDeserializeTime, m.resultSerializationTime, m.executorRunTime) } else { + // The task is still running and the metrics like executorRunTime are not available. 0L } } diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala new file mode 100644 index 0000000000000..9e74e86ad54b9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status + +import java.util.Date + +import org.apache.spark.SparkFunSuite +import org.apache.spark.status.api.v1.{TaskData, TaskMetrics} + +class AppStatusUtilsSuite extends SparkFunSuite { + + test("schedulerDelay") { + val runningTask = new TaskData( + taskId = 0, + index = 0, + attempt = 0, + launchTime = new Date(1L), + resultFetchStart = None, + duration = Some(100L), + executorId = "1", + host = "localhost", + status = "RUNNING", + taskLocality = "PROCESS_LOCAL", + speculative = false, + accumulatorUpdates = Nil, + errorMessage = None, + taskMetrics = Some(new TaskMetrics( + executorDeserializeTime = 0L, + executorDeserializeCpuTime = 0L, + executorRunTime = 0L, + executorCpuTime = 0L, + resultSize = 0L, + jvmGcTime = 0L, + resultSerializationTime = 0L, + memoryBytesSpilled = 0L, + diskBytesSpilled = 0L, + peakExecutionMemory = 0L, + inputMetrics = null, + outputMetrics = null, + shuffleReadMetrics = null, + shuffleWriteMetrics = null))) + assert(AppStatusUtils.schedulerDelay(runningTask) === 0L) + + val finishedTask = new TaskData( + taskId = 0, + index = 0, + attempt = 0, + launchTime = new Date(1L), + resultFetchStart = None, + duration = Some(100L), + executorId = "1", + host = "localhost", + status = "SUCCESS", + taskLocality = "PROCESS_LOCAL", + speculative = false, + accumulatorUpdates = Nil, + errorMessage = None, + taskMetrics = Some(new TaskMetrics( + executorDeserializeTime = 5L, + executorDeserializeCpuTime = 3L, + executorRunTime = 90L, + executorCpuTime = 10L, + resultSize = 100L, + jvmGcTime = 10L, + resultSerializationTime = 2L, + memoryBytesSpilled = 0L, + diskBytesSpilled = 0L, + peakExecutionMemory = 100L, + inputMetrics = null, + outputMetrics = null, + shuffleReadMetrics = null, + shuffleWriteMetrics = null))) + assert(AppStatusUtils.schedulerDelay(finishedTask) === 3L) + } +} From a24c03138a6935a442b983c8a4c721b26df3f9e2 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 6 Feb 2018 14:52:25 +0800 Subject: [PATCH 292/774] [SPARK-23290][SQL][PYTHON] Use datetime.date for date type when converting Spark DataFrame to Pandas DataFrame. ## What changes were proposed in this pull request? In #18664, there was a change in how `DateType` is being returned to users ([line 1968 in dataframe.py](https://github.com/apache/spark/pull/18664/files#diff-6fc344560230bf0ef711bb9b5573f1faR1968)). This can cause client code which works in Spark 2.2 to fail. See [SPARK-23290](https://issues.apache.org/jira/browse/SPARK-23290?focusedCommentId=16350917&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16350917) for an example. This pr modifies to use `datetime.date` for date type as Spark 2.2 does. ## How was this patch tested? Tests modified to fit the new behavior and existing tests. Author: Takuya UESHIN Closes #20506 from ueshin/issues/SPARK-23290. --- python/pyspark/serializers.py | 9 ++++-- python/pyspark/sql/dataframe.py | 7 ++-- python/pyspark/sql/tests.py | 57 ++++++++++++++++++++++++--------- python/pyspark/sql/types.py | 15 +++++++++ 4 files changed, 66 insertions(+), 22 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 88d6a191babca..e870325d202ca 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -267,12 +267,15 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ - from pyspark.sql.types import _check_dataframe_localize_timestamps + from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \ + _check_dataframe_localize_timestamps import pyarrow as pa reader = pa.open_stream(stream) + schema = from_arrow_schema(reader.schema) for batch in reader: - # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1 - pdf = _check_dataframe_localize_timestamps(batch.to_pandas(), self._timezone) + pdf = batch.to_pandas() + pdf = _check_dataframe_convert_date(pdf, schema) + pdf = _check_dataframe_localize_timestamps(pdf, self._timezone) yield [c for _, c in pdf.iteritems()] def __repr__(self): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2e55407b5397b..59a417015b949 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1923,7 +1923,8 @@ def toPandas(self): if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: - from pyspark.sql.types import _check_dataframe_localize_timestamps + from pyspark.sql.types import _check_dataframe_convert_date, \ + _check_dataframe_localize_timestamps from pyspark.sql.utils import require_minimum_pyarrow_version import pyarrow require_minimum_pyarrow_version() @@ -1931,6 +1932,7 @@ def toPandas(self): if tables: table = pyarrow.concat_tables(tables) pdf = table.to_pandas() + pdf = _check_dataframe_convert_date(pdf, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) @@ -2009,7 +2011,6 @@ def _to_corrected_pandas_type(dt): """ When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong. This method gets the corrected data type for Pandas if that type may be inferred uncorrectly. - NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns] """ import numpy as np if type(dt) == ByteType: @@ -2020,8 +2021,6 @@ def _to_corrected_pandas_type(dt): return np.int32 elif type(dt) == FloatType: return np.float32 - elif type(dt) == DateType: - return 'datetime64[ns]' else: return None diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b27363023ae77..545ec5aee08ff 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2816,7 +2816,7 @@ def test_to_pandas(self): self.assertEquals(types[1], np.object) self.assertEquals(types[2], np.bool) self.assertEquals(types[3], np.float32) - self.assertEquals(types[4], 'datetime64[ns]') + self.assertEquals(types[4], np.object) # datetime.date self.assertEquals(types[5], 'datetime64[ns]') @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") @@ -3388,7 +3388,7 @@ class ArrowTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): - from datetime import datetime + from datetime import date, datetime from decimal import Decimal ReusedSQLTestCase.setUpClass() @@ -3410,11 +3410,11 @@ def setUpClass(cls): StructField("7_date_t", DateType(), True), StructField("8_timestamp_t", TimestampType(), True)]) cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"), - datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), - datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), - datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] @classmethod def tearDownClass(cls): @@ -3461,7 +3461,9 @@ def _toPandas_arrow_toggle(self, df): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - self.assertPandasEqual(pdf_arrow, pdf) + expected = self.create_pandas_data_frame() + self.assertPandasEqual(expected, pdf) + self.assertPandasEqual(expected, pdf_arrow) def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) @@ -4062,18 +4064,42 @@ def test_vectorized_udf_unsupported_types(self): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.select(f(col('map'))).collect() - def test_vectorized_udf_null_date(self): + def test_vectorized_udf_dates(self): from pyspark.sql.functions import pandas_udf, col from datetime import date - schema = StructType().add("date", DateType()) - data = [(date(1969, 1, 1),), - (date(2012, 2, 2),), - (None,), - (date(2100, 4, 4),)] + schema = StructType().add("idx", LongType()).add("date", DateType()) + data = [(0, date(1969, 1, 1),), + (1, date(2012, 2, 2),), + (2, None,), + (3, date(2100, 4, 4),)] df = self.spark.createDataFrame(data, schema=schema) - date_f = pandas_udf(lambda t: t, returnType=DateType()) - res = df.select(date_f(col("date"))) - self.assertEquals(df.collect(), res.collect()) + + date_copy = pandas_udf(lambda t: t, returnType=DateType()) + df = df.withColumn("date_copy", date_copy(col("date"))) + + @pandas_udf(returnType=StringType()) + def check_data(idx, date, date_copy): + import pandas as pd + msgs = [] + is_equal = date.isnull() + for i in range(len(idx)): + if (is_equal[i] and data[idx[i]][1] is None) or \ + date[i] == data[idx[i]][1]: + msgs.append(None) + else: + msgs.append( + "date values are not equal (date='%s': data[%d][1]='%s')" + % (date[i], idx[i], data[idx[i]][1])) + return pd.Series(msgs) + + result = df.withColumn("check_data", + check_data(col("idx"), col("date"), col("date_copy"))).collect() + + self.assertEquals(len(data), len(result)) + for i in range(len(result)): + self.assertEquals(data[i][1], result[i][1]) # "date" col + self.assertEquals(data[i][1], result[i][2]) # "date_copy" col + self.assertIsNone(result[i][3]) # "check_data" col def test_vectorized_udf_timestamps(self): from pyspark.sql.functions import pandas_udf, col @@ -4114,6 +4140,7 @@ def check_data(idx, timestamp, timestamp_copy): self.assertEquals(len(data), len(result)) for i in range(len(result)): self.assertEquals(data[i][1], result[i][1]) # "timestamp" col + self.assertEquals(data[i][1], result[i][2]) # "timestamp_copy" col self.assertIsNone(result[i][3]) # "check_data" col def test_vectorized_udf_return_timestamp_tz(self): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 0dc5823f72a3c..093dae5a22e1f 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1694,6 +1694,21 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) +def _check_dataframe_convert_date(pdf, schema): + """ Correct date type value to use datetime.date. + + Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should + use datetime.date to match the behavior with when Arrow optimization is disabled. + + :param pdf: pandas.DataFrame + :param schema: a Spark schema of the pandas.DataFrame + """ + for field in schema: + if type(field.dataType) == DateType: + pdf[field.name] = pdf[field.name].dt.date + return pdf + + def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone From 8141c3e3ddb55586906b9bc79ef515142c2b551a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 6 Feb 2018 16:08:15 +0900 Subject: [PATCH 293/774] [SPARK-23300][TESTS] Prints out if Pandas and PyArrow are installed or not in PySpark SQL tests ## What changes were proposed in this pull request? This PR proposes to log if PyArrow and Pandas are installed or not so we can check if related tests are going to be skipped or not. ## How was this patch tested? Manually tested: I don't have PyArrow installed in PyPy. ```bash $ ./run-tests --python-executables=python3 ``` ``` ... Will test against the following Python executables: ['python3'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Will test PyArrow related features against Python executable 'python3' in 'pyspark-sql' module. Will test Pandas related features against Python executable 'python3' in 'pyspark-sql' module. Starting test(python3): pyspark.mllib.tests Starting test(python3): pyspark.sql.tests Starting test(python3): pyspark.streaming.tests Starting test(python3): pyspark.tests ``` ```bash $ ./run-tests --modules=pyspark-streaming ``` ``` ... Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python modules: ['pyspark-streaming'] Starting test(pypy): pyspark.streaming.tests Starting test(pypy): pyspark.streaming.util Starting test(python2.7): pyspark.streaming.tests Starting test(python2.7): pyspark.streaming.util ``` ```bash $ ./run-tests ``` ``` ... Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Will test PyArrow related features against Python executable 'python2.7' in 'pyspark-sql' module. Will test Pandas related features against Python executable 'python2.7' in 'pyspark-sql' module. Will skip PyArrow related features against Python executable 'pypy' in 'pyspark-sql' module. PyArrow >= 0.8.0 is required; however, PyArrow was not found. Will test Pandas related features against Python executable 'pypy' in 'pyspark-sql' module. Starting test(pypy): pyspark.streaming.tests Starting test(pypy): pyspark.sql.tests Starting test(pypy): pyspark.tests Starting test(python2.7): pyspark.mllib.tests ``` ```bash $ ./run-tests --modules=pyspark-sql --python-executables=pypy ``` ``` ... Will test against the following Python executables: ['pypy'] Will test the following Python modules: ['pyspark-sql'] Will skip PyArrow related features against Python executable 'pypy' in 'pyspark-sql' module. PyArrow >= 0.8.0 is required; however, PyArrow was not found. Will test Pandas related features against Python executable 'pypy' in 'pyspark-sql' module. Starting test(pypy): pyspark.sql.tests Starting test(pypy): pyspark.sql.catalog Starting test(pypy): pyspark.sql.column Starting test(pypy): pyspark.sql.conf ``` After some modification to produce other cases: ```bash $ ./run-tests ``` ``` ... Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Will skip PyArrow related features against Python executable 'python2.7' in 'pyspark-sql' module. PyArrow >= 20.0.0 is required; however, PyArrow 0.8.0 was found. Will skip Pandas related features against Python executable 'python2.7' in 'pyspark-sql' module. Pandas >= 20.0.0 is required; however, Pandas 0.20.2 was found. Will skip PyArrow related features against Python executable 'pypy' in 'pyspark-sql' module. PyArrow >= 20.0.0 is required; however, PyArrow was not found. Will skip Pandas related features against Python executable 'pypy' in 'pyspark-sql' module. Pandas >= 20.0.0 is required; however, Pandas 0.22.0 was found. Starting test(pypy): pyspark.sql.tests Starting test(pypy): pyspark.streaming.tests Starting test(pypy): pyspark.tests Starting test(python2.7): pyspark.mllib.tests ``` ```bash ./run-tests-with-coverage ``` ``` ... Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Will test PyArrow related features against Python executable 'python2.7' in 'pyspark-sql' module. Will test Pandas related features against Python executable 'python2.7' in 'pyspark-sql' module. Coverage is not installed in Python executable 'pypy' but 'COVERAGE_PROCESS_START' environment variable is set, exiting. ``` Author: hyukjinkwon Closes #20473 from HyukjinKwon/SPARK-23300. --- python/run-tests.py | 73 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 68 insertions(+), 5 deletions(-) diff --git a/python/run-tests.py b/python/run-tests.py index f03284c334285..6b41b5ee22814 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -31,6 +31,7 @@ import Queue else: import queue as Queue +from distutils.version import LooseVersion # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -38,8 +39,8 @@ from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings) -from sparktestsupport.shellutils import which, subprocess_check_output, run_cmd # noqa -from sparktestsupport.modules import all_modules # noqa +from sparktestsupport.shellutils import which, subprocess_check_output # noqa +from sparktestsupport.modules import all_modules, pyspark_sql # noqa python_modules = dict((m.name, m) for m in all_modules if m.python_test_goals if m.name != 'root') @@ -151,6 +152,67 @@ def parse_opts(): return opts +def _check_dependencies(python_exec, modules_to_test): + if "COVERAGE_PROCESS_START" in os.environ: + # Make sure if coverage is installed. + try: + subprocess_check_output( + [python_exec, "-c", "import coverage"], + stderr=open(os.devnull, 'w')) + except: + print_red("Coverage is not installed in Python executable '%s' " + "but 'COVERAGE_PROCESS_START' environment variable is set, " + "exiting." % python_exec) + sys.exit(-1) + + # If we should test 'pyspark-sql', it checks if PyArrow and Pandas are installed and + # explicitly prints out. See SPARK-23300. + if pyspark_sql in modules_to_test: + # TODO(HyukjinKwon): Relocate and deduplicate these version specifications. + minimum_pyarrow_version = '0.8.0' + minimum_pandas_version = '0.19.2' + + try: + pyarrow_version = subprocess_check_output( + [python_exec, "-c", "import pyarrow; print(pyarrow.__version__)"], + universal_newlines=True, + stderr=open(os.devnull, 'w')).strip() + if LooseVersion(pyarrow_version) >= LooseVersion(minimum_pyarrow_version): + LOGGER.info("Will test PyArrow related features against Python executable " + "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) + else: + LOGGER.warning( + "Will skip PyArrow related features against Python executable " + "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " + "%s was found." % ( + python_exec, pyspark_sql.name, minimum_pyarrow_version, pyarrow_version)) + except: + LOGGER.warning( + "Will skip PyArrow related features against Python executable " + "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " + "was not found." % (python_exec, pyspark_sql.name, minimum_pyarrow_version)) + + try: + pandas_version = subprocess_check_output( + [python_exec, "-c", "import pandas; print(pandas.__version__)"], + universal_newlines=True, + stderr=open(os.devnull, 'w')).strip() + if LooseVersion(pandas_version) >= LooseVersion(minimum_pandas_version): + LOGGER.info("Will test Pandas related features against Python executable " + "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) + else: + LOGGER.warning( + "Will skip Pandas related features against Python executable " + "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " + "%s was found." % ( + python_exec, pyspark_sql.name, minimum_pandas_version, pandas_version)) + except: + LOGGER.warning( + "Will skip Pandas related features against Python executable " + "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " + "was not found." % (python_exec, pyspark_sql.name, minimum_pandas_version)) + + def main(): opts = parse_opts() if (opts.verbose): @@ -175,9 +237,10 @@ def main(): task_queue = Queue.PriorityQueue() for python_exec in python_execs: - if "COVERAGE_PROCESS_START" in os.environ: - # Make sure if coverage is installed. - run_cmd([python_exec, "-c", "import coverage"]) + # Check if the python executable has proper dependencies installed to run tests + # for given modules properly. + _check_dependencies(python_exec, modules_to_test) + python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], universal_newlines=True).strip() From 63c5bf13ce5cd3b8d7e7fb88de881ed207fde720 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 6 Feb 2018 18:30:50 +0900 Subject: [PATCH 294/774] [SPARK-23334][SQL][PYTHON] Fix pandas_udf with return type StringType() to handle str type properly in Python 2. ## What changes were proposed in this pull request? In Python 2, when `pandas_udf` tries to return string type value created in the udf with `".."`, the execution fails. E.g., ```python from pyspark.sql.functions import pandas_udf, col import pandas as pd df = spark.range(10) str_f = pandas_udf(lambda x: pd.Series(["%s" % i for i in x]), "string") df.select(str_f(col('id'))).show() ``` raises the following exception: ``` ... java.lang.AssertionError: assertion failed: Invalid schema from pandas_udf: expected StringType, got BinaryType at scala.Predef$.assert(Predef.scala:170) at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.(ArrowEvalPythonExec.scala:93) ... ``` Seems like pyarrow ignores `type` parameter for `pa.Array.from_pandas()` and consider it as binary type when the type is string type and the string values are `str` instead of `unicode` in Python 2. This pr adds a workaround for the case. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20507 from ueshin/issues/SPARK-23334. --- python/pyspark/serializers.py | 4 ++++ python/pyspark/sql/tests.py | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index e870325d202ca..91a7f093cec19 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -230,6 +230,10 @@ def create_array(s, t): s = _check_series_convert_timestamps_internal(s.fillna(0), timezone) # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) + elif t is not None and pa.types.is_string(t) and sys.version < '3': + # TODO: need decode before converting to Arrow in Python 2 + return pa.Array.from_pandas(s.apply( + lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t) return pa.Array.from_pandas(s, mask=mask, type=t) arrs = [create_array(s, t) for s, t in series] diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 545ec5aee08ff..89b7c2182d2d1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3922,6 +3922,15 @@ def test_vectorized_udf_null_string(self): res = df.select(str_f(col('str'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_string_in_udf(self): + from pyspark.sql.functions import pandas_udf, col + import pandas as pd + df = self.spark.range(10) + str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType()) + actual = df.select(str_f(col('id'))) + expected = df.select(col('id').cast('string')) + self.assertEquals(expected.collect(), actual.collect()) + def test_vectorized_udf_datatype_string(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( From 7db9979babe52d15828967c86eb77e3fb2791579 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 6 Feb 2018 10:46:48 -0800 Subject: [PATCH 295/774] [SPARK-23310][CORE][FOLLOWUP] Fix Java style check issues. ## What changes were proposed in this pull request? This is a follow-up of #20492 which broke lint-java checks. This pr fixes the lint-java issues. ``` [ERROR] src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java:[79] (sizes) LineLength: Line is longer than 100 characters (found 114). ``` ## How was this patch tested? Checked manually in my local environment. Author: Takuya UESHIN Closes #20514 from ueshin/issues/SPARK-23310/fup1. --- .../util/collection/unsafe/sort/UnsafeSorterSpillReader.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 71e7c7a95ebdb..2c53c8d809d2e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -76,8 +76,8 @@ public UnsafeSorterSpillReader( SparkEnv.get() == null ? 0.5 : SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5); - // SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf regression for - // TPC-DS queries. + // SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf + // regression for TPC-DS queries. final boolean readAheadEnabled = SparkEnv.get() != null && SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", false); From ac7454cac04a1d9252b3856360eda5c3e8bcb8da Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 6 Feb 2018 12:27:37 -0800 Subject: [PATCH 296/774] [SPARK-23312][SQL][FOLLOWUP] add a config to turn off vectorized cache reader ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/20483 tried to provide a way to turn off the new columnar cache reader, to restore the behavior in 2.2. However even we turn off that config, the behavior is still different than 2.2. If the output data are rows, we still enable whole stage codegen for the scan node, which is different with 2.2, we should also fix it. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20513 from cloud-fan/cache. --- .../spark/sql/execution/columnar/InMemoryTableScanExec.scala | 3 +++ .../src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 3 ++- .../apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index e972f8b30d87c..a93e8a1ad954d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -61,6 +61,9 @@ case class InMemoryTableScanExec( }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) } + // TODO: revisit this. Shall we always turn off whole stage codegen if the output data are rows? + override def supportCodegen: Boolean = supportsBatch + override protected def needsUnsafeRowConversion: Boolean = false private val columnIndices = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 9f27fa09127af..669e5f2bf4e65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -787,7 +787,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { val df = spark.range(10).cache() df.queryExecution.executedPlan.foreach { - case i: InMemoryTableScanExec => assert(i.supportsBatch == vectorized) + case i: InMemoryTableScanExec => + assert(i.supportsBatch == vectorized && i.supportCodegen == vectorized) case _ => } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 6e8d5a70d5a8f..ef16292a8e75c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -137,7 +137,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val dsStringFilter = dsString.filter(_ == "1") val planString = dsStringFilter.queryExecution.executedPlan assert(planString.collect { - case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if !i.supportsBatch => () + case i: InMemoryTableScanExec if !i.supportsBatch => () }.length == 1) assert(dsStringFilter.collect() === Array("1")) } From caf30445632de6aec810309293499199e7a20892 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 6 Feb 2018 12:30:04 -0800 Subject: [PATCH 297/774] [MINOR][TEST] Fix class name for Pandas UDF tests ## What changes were proposed in this pull request? In https://github.com/apache/spark/commit/b2ce17b4c9fea58140a57ca1846b2689b15c0d61, I mistakenly renamed `VectorizedUDFTests` to `ScalarPandasUDF`. This PR fixes the mistake. ## How was this patch tested? Existing tests. Author: Li Jin Closes #20489 from icexelloss/fix-scalar-udf-tests. --- python/pyspark/sql/tests.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 89b7c2182d2d1..53da7dd45c2f2 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3766,7 +3766,7 @@ def foo(k, v): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class ScalarPandasUDF(ReusedSQLTestCase): +class ScalarPandasUDFTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): @@ -4279,7 +4279,7 @@ def test_register_vectorized_udf_basic(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyApplyPandasUDFTests(ReusedSQLTestCase): +class GroupedMapPandasUDFTests(ReusedSQLTestCase): @property def data(self): @@ -4448,7 +4448,7 @@ def test_unsupported_types(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyAggPandasUDFTests(ReusedSQLTestCase): +class GroupedAggPandasUDFTests(ReusedSQLTestCase): @property def data(self): From b96a083b1c6ff0d2c588be9499b456e1adce97dc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 6 Feb 2018 12:43:45 -0800 Subject: [PATCH 298/774] [SPARK-23315][SQL] failed to get output from canonicalized data source v2 related plans ## What changes were proposed in this pull request? `DataSourceV2Relation` keeps a `fullOutput` and resolves the real output on demand by column name lookup. i.e. ``` lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => fullOutput.find(_.name == name).get } ``` This will be broken after we canonicalize the plan, because all attribute names become "None", see https://github.com/apache/spark/blob/v2.3.0-rc1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala#L42 To fix this, `DataSourceV2Relation` should just keep `output`, and update the `output` when doing column pruning. ## How was this patch tested? a new test case Author: Wenchen Fan Closes #20485 from cloud-fan/canonicalize. --- .../v2/DataSourceReaderHolder.scala | 12 +++----- .../datasources/v2/DataSourceV2Relation.scala | 8 ++--- .../datasources/v2/DataSourceV2ScanExec.scala | 4 +-- .../v2/PushDownOperatorsToDataSource.scala | 29 +++++++++++++------ .../sql/sources/v2/DataSourceV2Suite.scala | 20 ++++++++++++- 5 files changed, 48 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala index 6460c97abe344..81219e9771bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import java.util.Objects -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.sources.v2.reader._ /** @@ -28,9 +28,9 @@ import org.apache.spark.sql.sources.v2.reader._ trait DataSourceReaderHolder { /** - * The full output of the data source reader, without column pruning. + * The output of the data source reader, w.r.t. column pruning. */ - def fullOutput: Seq[AttributeReference] + def output: Seq[Attribute] /** * The held data source reader. @@ -46,7 +46,7 @@ trait DataSourceReaderHolder { case s: SupportsPushDownFilters => s.pushedFilters().toSet case _ => Nil } - Seq(fullOutput, reader.getClass, reader.readSchema(), filters) + Seq(output, reader.getClass, filters) } def canEqual(other: Any): Boolean @@ -61,8 +61,4 @@ trait DataSourceReaderHolder { override def hashCode(): Int = { metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) } - - lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => - fullOutput.find(_.name == name).get - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index eebfa29f91b99..38f6b15224788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( - fullOutput: Seq[AttributeReference], + output: Seq[AttributeReference], reader: DataSourceReader) extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { @@ -37,7 +37,7 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - copy(fullOutput = fullOutput.map(_.newInstance())) + copy(output = output.map(_.newInstance())) } } @@ -46,8 +46,8 @@ case class DataSourceV2Relation( * to the non-streaming relation. */ class StreamingDataSourceV2Relation( - fullOutput: Seq[AttributeReference], - reader: DataSourceReader) extends DataSourceV2Relation(fullOutput, reader) { + output: Seq[AttributeReference], + reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { override def isStreaming: Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index df469af2c262a..7d9581be4db89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -35,14 +35,12 @@ import org.apache.spark.sql.types.StructType * Physical plan node for scanning data from a data source. */ case class DataSourceV2ScanExec( - fullOutput: Seq[AttributeReference], + output: Seq[AttributeReference], @transient reader: DataSourceReader) extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] - override def producedAttributes: AttributeSet = AttributeSet(fullOutput) - override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => new DataSourcePartitioning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 566a48394f02e..1ca6cbf061b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -81,33 +81,44 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: add more push down rules. - pushDownRequiredColumns(filterPushed, filterPushed.outputSet) + val columnPruned = pushDownRequiredColumns(filterPushed, filterPushed.outputSet) // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. - RemoveRedundantProject(filterPushed) + RemoveRedundantProject(columnPruned) } // TODO: nested fields pruning - private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = { + private def pushDownRequiredColumns( + plan: LogicalPlan, requiredByParent: AttributeSet): LogicalPlan = { plan match { - case Project(projectList, child) => + case p @ Project(projectList, child) => val required = projectList.flatMap(_.references) - pushDownRequiredColumns(child, AttributeSet(required)) + p.copy(child = pushDownRequiredColumns(child, AttributeSet(required))) - case Filter(condition, child) => + case f @ Filter(condition, child) => val required = requiredByParent ++ condition.references - pushDownRequiredColumns(child, required) + f.copy(child = pushDownRequiredColumns(child, required)) case relation: DataSourceV2Relation => relation.reader match { case reader: SupportsPushDownRequiredColumns => + // TODO: Enable the below assert after we make `DataSourceV2Relation` immutable. Fow now + // it's possible that the mutable reader being updated by someone else, and we need to + // always call `reader.pruneColumns` here to correct it. + // assert(relation.output.toStructType == reader.readSchema(), + // "Schema of data source reader does not match the relation plan.") + val requiredColumns = relation.output.filter(requiredByParent.contains) reader.pruneColumns(requiredColumns.toStructType) - case _ => + val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap + val newOutput = reader.readSchema().map(_.name).map(nameToAttr) + relation.copy(output = newOutput) + + case _ => relation } // TODO: there may be more operators that can be used to calculate the required columns. We // can add more and more in the future. - case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet)) + case _ => plan.mapChildren(c => pushDownRequiredColumns(c, c.outputSet)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index eccd45442a3b2..a1c87fb15542c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,7 +24,7 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.functions._ @@ -316,6 +316,24 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val reader4 = getReader(q4) assert(reader4.requiredSchema.fieldNames === Seq("i")) } + + test("SPARK-23315: get output from canonicalized data source v2 related plans") { + def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = { + val logical = df.queryExecution.optimizedPlan.collect { + case d: DataSourceV2Relation => d + }.head + assert(logical.canonicalized.output.length == numOutput) + + val physical = df.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d + }.head + assert(physical.canonicalized.output.length == numOutput) + } + + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + checkCanonicalizedOutput(df, 2) + checkCanonicalizedOutput(df.select('i), 1) + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { From c36fecc3b416c38002779c3cf40b6a665ac4bf13 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 6 Feb 2018 16:46:43 -0800 Subject: [PATCH 299/774] [SPARK-23327][SQL] Update the description and tests of three external API or functions ## What changes were proposed in this pull request? Update the description and tests of three external API or functions `createFunction `, `length` and `repartitionByRange ` ## How was this patch tested? N/A Author: gatorsmile Closes #20495 from gatorsmile/updateFunc. --- R/pkg/R/functions.R | 4 +++- python/pyspark/sql/functions.py | 8 ++++--- .../sql/catalyst/catalog/SessionCatalog.scala | 7 ++++-- .../expressions/stringExpressions.scala | 23 ++++++++++--------- .../scala/org/apache/spark/sql/Dataset.scala | 2 ++ .../sql/execution/command/functions.scala | 14 +++++++---- .../org/apache/spark/sql/functions.scala | 4 +++- .../execution/command/DDLParserSuite.scala | 10 ++++---- 8 files changed, 44 insertions(+), 28 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 55365a41d774b..9f7c6317cd924 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1026,7 +1026,9 @@ setMethod("last_day", }) #' @details -#' \code{length}: Computes the length of a given string or binary column. +#' \code{length}: Computes the character length of a string data or number of bytes +#' of a binary data. The length of string data includes the trailing spaces. +#' The length of binary data includes binary zeros. #' #' @rdname column_string_functions #' @aliases length length,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3c8fb4c4d19e7..05031f5ec87d7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1705,10 +1705,12 @@ def unhex(col): @ignore_unicode_prefix @since(1.5) def length(col): - """Calculates the length of a string or binary expression. + """Computes the character length of string data or number of bytes of binary data. + The length of character data includes the trailing spaces. The length of binary data + includes binary zeros. - >>> spark.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() - [Row(length=3)] + >>> spark.createDataFrame([('ABC ',)], ['a']).select(length('a').alias('length')).collect() + [Row(length=4)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.length(_to_java_column(col))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index a129896230775..4b119c75260a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -988,8 +988,11 @@ class SessionCatalog( // ------------------------------------------------------- /** - * Create a metastore function in the database specified in `funcDefinition`. + * Create a function in the database specified in `funcDefinition`. * If no such database is specified, create it in the current database. + * + * @param ignoreIfExists: When true, ignore if the function with the specified name exists + * in the specified database. */ def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = { val db = formatDatabaseName(funcDefinition.identifier.database.getOrElse(getCurrentDatabase)) @@ -1061,7 +1064,7 @@ class SessionCatalog( } /** - * Check if the specified function exists. + * Check if the function with the specified name exists */ def functionExists(name: FunctionIdentifier): Boolean = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5cf783f1a5979..d7612e30b4c57 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1653,19 +1653,19 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run * A function that returns the char length of the given string expression or * number of bytes of the given binary expression. */ -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the character length of `expr` or number of bytes in binary data.", + usage = "_FUNC_(expr) - Returns the character length of string data or number of bytes of " + + "binary data. The length of string data includes the trailing spaces. The length of binary " + + "data includes binary zeros.", examples = """ Examples: - > SELECT _FUNC_('Spark SQL'); - 9 - > SELECT CHAR_LENGTH('Spark SQL'); - 9 - > SELECT CHARACTER_LENGTH('Spark SQL'); - 9 + > SELECT _FUNC_('Spark SQL '); + 10 + > SELECT CHAR_LENGTH('Spark SQL '); + 10 + > SELECT CHARACTER_LENGTH('Spark SQL '); + 10 """) -// scalastyle:on line.size.limit case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1687,7 +1687,7 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastIn * A function that returns the bit length of the given string or binary expression. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the bit length of `expr` or number of bits in binary data.", + usage = "_FUNC_(expr) - Returns the bit length of string data or number of bits of binary data.", examples = """ Examples: > SELECT _FUNC_('Spark SQL'); @@ -1716,7 +1716,8 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas * A function that returns the byte length of the given string or binary expression. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the byte length of `expr` or number of bytes in binary data.", + usage = "_FUNC_(expr) - Returns the byte length of string data or number of bytes of binary " + + "data.", examples = """ Examples: > SELECT _FUNC_('Spark SQL'); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d47cd0aecf56a..0aee1d7be5788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2825,6 +2825,7 @@ class Dataset[T] private[sql]( * * At least one partition-by expression must be specified. * When no explicit sort order is specified, "ascending nulls first" is assumed. + * Note, the rows are not sorted in each partition of the resulting Dataset. * * @group typedrel * @since 2.3.0 @@ -2848,6 +2849,7 @@ class Dataset[T] private[sql]( * * At least one partition-by expression must be specified. * When no explicit sort order is specified, "ascending nulls first" is assumed. + * Note, the rows are not sorted in each partition of the resulting Dataset. * * @group typedrel * @since 2.3.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index 4f92ffee687aa..1f7808c2f8e80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -40,6 +40,10 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType} * CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] [databaseName.]functionName * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] * }}} + * + * @param ignoreIfExists: When true, ignore if the function with the specified name exists + * in the specified database. + * @param replace: When true, alter the function with the specified name */ case class CreateFunctionCommand( databaseName: Option[String], @@ -47,17 +51,17 @@ case class CreateFunctionCommand( className: String, resources: Seq[FunctionResource], isTemp: Boolean, - ifNotExists: Boolean, + ignoreIfExists: Boolean, replace: Boolean) extends RunnableCommand { - if (ifNotExists && replace) { + if (ignoreIfExists && replace) { throw new AnalysisException("CREATE FUNCTION with both IF NOT EXISTS and REPLACE" + " is not allowed.") } // Disallow to define a temporary function with `IF NOT EXISTS` - if (ifNotExists && isTemp) { + if (ignoreIfExists && isTemp) { throw new AnalysisException( "It is not allowed to define a TEMPORARY function with IF NOT EXISTS.") } @@ -79,12 +83,12 @@ case class CreateFunctionCommand( // Handles `CREATE OR REPLACE FUNCTION AS ... USING ...` if (replace && catalog.functionExists(func.identifier)) { // alter the function in the metastore - catalog.alterFunction(CatalogFunction(func.identifier, className, resources)) + catalog.alterFunction(func) } else { // For a permanent, we will store the metadata into underlying external catalog. // This function will be loaded into the FunctionRegistry when a query uses it. // We do not load it into FunctionRegistry right now. - catalog.createFunction(CatalogFunction(func.identifier, className, resources), ifNotExists) + catalog.createFunction(func, ignoreIfExists) } } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d11682d80a3c..0d54c02c3d06f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2267,7 +2267,9 @@ object functions { } /** - * Computes the length of a given string or binary column. + * Computes the character length of a given string or number of bytes of a binary string. + * The length of character strings include the trailing spaces. The length of binary strings + * includes binary zeros. * * @group string_funcs * @since 1.5.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 2b1aea08b1223..e0ccae15f1d05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -236,7 +236,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), - isTemp = true, ifNotExists = false, replace = false) + isTemp = true, ignoreIfExists = false, replace = false) val expected2 = CreateFunctionCommand( Some("hello"), "world", @@ -244,7 +244,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false, ifNotExists = false, replace = false) + isTemp = false, ignoreIfExists = false, replace = false) val expected3 = CreateFunctionCommand( None, "helloworld3", @@ -252,7 +252,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), - isTemp = true, ifNotExists = false, replace = true) + isTemp = true, ignoreIfExists = false, replace = true) val expected4 = CreateFunctionCommand( Some("hello"), "world1", @@ -260,7 +260,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false, ifNotExists = false, replace = true) + isTemp = false, ignoreIfExists = false, replace = true) val expected5 = CreateFunctionCommand( Some("hello"), "world2", @@ -268,7 +268,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false, ifNotExists = true, replace = false) + isTemp = false, ignoreIfExists = true, replace = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) comparePlans(parsed3, expected3) From 9775df67f924663598d51723a878557ddafb8cfd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 7 Feb 2018 23:24:16 +0900 Subject: [PATCH 300/774] [SPARK-23122][PYSPARK][FOLLOWUP] Replace registerTempTable by createOrReplaceTempView ## What changes were proposed in this pull request? Replace `registerTempTable` by `createOrReplaceTempView`. ## How was this patch tested? N/A Author: gatorsmile Closes #20523 from gatorsmile/updateExamples. --- python/pyspark/sql/udf.py | 2 +- .../src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 0f759c448b8a7..08c6b9e521e82 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -356,7 +356,7 @@ def registerJavaUDAF(self, name, javaClassName): >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg") >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) - >>> df.registerTempTable("df") + >>> df.createOrReplaceTempView("df") >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] """ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java index ddbaa45a483cb..08dc129f27a0c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java @@ -46,7 +46,7 @@ public void tearDown() { @SuppressWarnings("unchecked") @Test public void udf1Test() { - spark.range(1, 10).toDF("value").registerTempTable("df"); + spark.range(1, 10).toDF("value").createOrReplaceTempView("df"); spark.udf().registerJavaUDAF("myDoubleAvg", MyDoubleAvg.class.getName()); Row result = spark.sql("SELECT myDoubleAvg(value) as my_avg from df").head(); Assert.assertEquals(105.0, result.getDouble(0), 1.0e-6); From 71cfba04aeec5ae9b85a507b13996e80f8750edc Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 7 Feb 2018 23:28:10 +0900 Subject: [PATCH 301/774] [SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test) ## What changes were proposed in this pull request? This PR proposes to explicitly specify Pandas and PyArrow versions in PySpark tests to skip or test. We declared the extra dependencies: https://github.com/apache/spark/blob/b8bfce51abf28c66ba1fc67b0f25fe1617c81025/python/setup.py#L204 In case of PyArrow: Currently we only check if pyarrow is installed or not without checking the version. It already fails to run tests. For example, if PyArrow 0.7.0 is installed: ``` ====================================================================== ERROR: test_vectorized_udf_wrong_return_type (pyspark.sql.tests.ScalarPandasUDF) ---------------------------------------------------------------------- Traceback (most recent call last): File "/.../spark/python/pyspark/sql/tests.py", line 4019, in test_vectorized_udf_wrong_return_type f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) File "/.../spark/python/pyspark/sql/functions.py", line 2309, in pandas_udf return _create_udf(f=f, returnType=return_type, evalType=eval_type) File "/.../spark/python/pyspark/sql/udf.py", line 47, in _create_udf require_minimum_pyarrow_version() File "/.../spark/python/pyspark/sql/utils.py", line 132, in require_minimum_pyarrow_version "however, your version was %s." % pyarrow.__version__) ImportError: pyarrow >= 0.8.0 must be installed on calling Python process; however, your version was 0.7.0. ---------------------------------------------------------------------- Ran 33 tests in 8.098s FAILED (errors=33) ``` In case of Pandas: There are few tests for old Pandas which were tested only when Pandas version was lower, and I rewrote them to be tested when both Pandas version is lower and missing. ## How was this patch tested? Manually tested by modifying the condition: ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.' ``` ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' ``` ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.' ``` ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' ``` Author: hyukjinkwon Closes #20487 from HyukjinKwon/pyarrow-pandas-skip. --- pom.xml | 4 ++ python/pyspark/sql/dataframe.py | 3 ++ python/pyspark/sql/session.py | 3 ++ python/pyspark/sql/tests.py | 87 ++++++++++++++++++--------------- python/pyspark/sql/utils.py | 30 +++++++++--- python/setup.py | 10 +++- 6 files changed, 89 insertions(+), 48 deletions(-) diff --git a/pom.xml b/pom.xml index 666d5d7169a15..d18831df1db6d 100644 --- a/pom.xml +++ b/pom.xml @@ -185,6 +185,10 @@ 2.8 1.8 1.0.0 + 0.8.0 ${java.home} diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 59a417015b949..8ec24db8717b2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1913,6 +1913,9 @@ def toPandas(self): 0 2 Alice 1 5 Bob """ + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + import pandas as pd if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 1ed04298bc899..b3af9b82953f3 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -646,6 +646,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ == "true": timezone = self.conf.get("spark.sql.session.timeZone") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 53da7dd45c2f2..58359b61dc83a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -48,19 +48,26 @@ else: import unittest -_have_pandas = False -_have_old_pandas = False +_pandas_requirement_message = None try: - import pandas - try: - from pyspark.sql.utils import require_minimum_pandas_version - require_minimum_pandas_version() - _have_pandas = True - except: - _have_old_pandas = True -except: - # No Pandas, but that's okay, we'll skip those tests - pass + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() +except ImportError as e: + from pyspark.util import _exception_message + # If Pandas version requirement is not satisfied, skip related tests. + _pandas_requirement_message = _exception_message(e) + +_pyarrow_requirement_message = None +try: + from pyspark.sql.utils import require_minimum_pyarrow_version + require_minimum_pyarrow_version() +except ImportError as e: + from pyspark.util import _exception_message + # If Arrow version requirement is not satisfied, skip related tests. + _pyarrow_requirement_message = _exception_message(e) + +_have_pandas = _pandas_requirement_message is None +_have_pyarrow = _pyarrow_requirement_message is None from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row @@ -75,15 +82,6 @@ from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException -_have_arrow = False -try: - import pyarrow - _have_arrow = True -except: - # No Arrow, but that's okay, we'll skip those tests - pass - - class UTCOffsetTimezone(datetime.tzinfo): """ Specifies timezone in UTC offset @@ -2794,7 +2792,6 @@ def count_bucketed_cols(names, table="pyspark_bucket"): def _to_pandas(self): from datetime import datetime, date - import numpy as np schema = StructType().add("a", IntegerType()).add("b", StringType())\ .add("c", BooleanType()).add("d", FloatType())\ .add("dt", DateType()).add("ts", TimestampType()) @@ -2807,7 +2804,7 @@ def _to_pandas(self): df = self.spark.createDataFrame(data, schema) return df.toPandas() - @unittest.skipIf(not _have_pandas, "Pandas not installed") + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) def test_to_pandas(self): import numpy as np pdf = self._to_pandas() @@ -2819,13 +2816,13 @@ def test_to_pandas(self): self.assertEquals(types[4], np.object) # datetime.date self.assertEquals(types[5], 'datetime64[ns]') - @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") - def test_to_pandas_old(self): + @unittest.skipIf(_have_pandas, "Required Pandas was found.") + def test_to_pandas_required_pandas_not_found(self): with QuietTest(self.sc): with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): self._to_pandas() - @unittest.skipIf(not _have_pandas, "Pandas not installed") + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) def test_to_pandas_avoid_astype(self): import numpy as np schema = StructType().add("a", IntegerType()).add("b", StringType())\ @@ -2843,7 +2840,7 @@ def test_create_dataframe_from_array_of_long(self): df = self.spark.createDataFrame(data) self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807])) - @unittest.skipIf(not _have_pandas, "Pandas not installed") + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) def test_create_dataframe_from_pandas_with_timestamp(self): import pandas as pd from datetime import datetime @@ -2858,14 +2855,16 @@ def test_create_dataframe_from_pandas_with_timestamp(self): self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) - @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") - def test_create_dataframe_from_old_pandas(self): - import pandas as pd - from datetime import datetime - pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], - "d": [pd.Timestamp.now().date()]}) + @unittest.skipIf(_have_pandas, "Required Pandas was found.") + def test_create_dataframe_required_pandas_not_found(self): with QuietTest(self.sc): - with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): + with self.assertRaisesRegexp( + ImportError, + '(Pandas >= .* must be installed|No module named pandas)'): + import pandas as pd + from datetime import datetime + pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], + "d": [pd.Timestamp.now().date()]}) self.spark.createDataFrame(pdf) @@ -3383,7 +3382,9 @@ def __init__(self, **kwargs): _make_type_verifier(data_type, nullable=False)(obj) -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class ArrowTests(ReusedSQLTestCase): @classmethod @@ -3641,7 +3642,9 @@ def test_createDataFrame_with_int_col_names(self): self.assertEqual(pdf_col_names, df_arrow.columns) -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class PandasUDFTests(ReusedSQLTestCase): def test_pandas_udf_basic(self): from pyspark.rdd import PythonEvalType @@ -3765,7 +3768,9 @@ def foo(k, v): return k -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class ScalarPandasUDFTests(ReusedSQLTestCase): @classmethod @@ -4278,7 +4283,9 @@ def test_register_vectorized_udf_basic(self): self.assertEquals(expected.collect(), res2.collect()) -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class GroupedMapPandasUDFTests(ReusedSQLTestCase): @property @@ -4447,7 +4454,9 @@ def test_unsupported_types(self): df.groupby('id').apply(f).collect() -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class GroupedAggPandasUDFTests(ReusedSQLTestCase): @property diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 08c34c6dccc5e..578298632dd4c 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -115,18 +115,32 @@ def toJArray(gateway, jtype, arr): def require_minimum_pandas_version(): """ Raise ImportError if minimum version of Pandas is not installed """ + # TODO(HyukjinKwon): Relocate and deduplicate the version specification. + minimum_pandas_version = "0.19.2" + from distutils.version import LooseVersion - import pandas - if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'): - raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process; " - "however, your version was %s." % pandas.__version__) + try: + import pandas + except ImportError: + raise ImportError("Pandas >= %s must be installed; however, " + "it was not found." % minimum_pandas_version) + if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version): + raise ImportError("Pandas >= %s must be installed; however, " + "your version was %s." % (minimum_pandas_version, pandas.__version__)) def require_minimum_pyarrow_version(): """ Raise ImportError if minimum version of pyarrow is not installed """ + # TODO(HyukjinKwon): Relocate and deduplicate the version specification. + minimum_pyarrow_version = "0.8.0" + from distutils.version import LooseVersion - import pyarrow - if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'): - raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process; " - "however, your version was %s." % pyarrow.__version__) + try: + import pyarrow + except ImportError: + raise ImportError("PyArrow >= %s must be installed; however, " + "it was not found." % minimum_pyarrow_version) + if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): + raise ImportError("PyArrow >= %s must be installed; however, " + "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) diff --git a/python/setup.py b/python/setup.py index 251d4526d4dd0..6a98401941d8d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -100,6 +100,11 @@ def _supports_symlinks(): file=sys.stderr) exit(-1) +# If you are changing the versions here, please also change ./python/pyspark/sql/utils.py and +# ./python/run-tests.py. In case of Arrow, you should also check ./pom.xml. +_minimum_pandas_version = "0.19.2" +_minimum_pyarrow_version = "0.8.0" + try: # We copy the shell script to be under pyspark/python/pyspark so that the launcher scripts # find it where expected. The rest of the files aren't copied because they are accessed @@ -201,7 +206,10 @@ def _supports_symlinks(): extras_require={ 'ml': ['numpy>=1.7'], 'mllib': ['numpy>=1.7'], - 'sql': ['pandas>=0.19.2', 'pyarrow>=0.8.0'] + 'sql': [ + 'pandas>=%s' % _minimum_pandas_version, + 'pyarrow>=%s' % _minimum_pyarrow_version, + ] }, classifiers=[ 'Development Status :: 5 - Production/Stable', From 9841ae0313cbee1f083f131f9446808c90ed5a7b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 7 Feb 2018 09:48:49 -0800 Subject: [PATCH 302/774] [SPARK-23345][SQL] Remove open stream record even closing it fails ## What changes were proposed in this pull request? When `DebugFilesystem` closes opened stream, if any exception occurs, we still need to remove the open stream record from `DebugFilesystem`. Otherwise, it goes to report leaked filesystem connection. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #20524 from viirya/SPARK-23345. --- core/src/test/scala/org/apache/spark/DebugFilesystem.scala | 7 +++++-- .../org/apache/spark/sql/test/SharedSparkSession.scala | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala index 91355f7362900..a5bdc95790722 100644 --- a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala +++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala @@ -103,8 +103,11 @@ class DebugFilesystem extends LocalFileSystem { override def markSupported(): Boolean = wrapped.markSupported() override def close(): Unit = { - wrapped.close() - removeOpenStream(wrapped) + try { + wrapped.close() + } finally { + removeOpenStream(wrapped) + } } override def read(): Int = wrapped.read() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index 0b4629a51b425..e758c865b908f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -111,7 +111,7 @@ trait SharedSparkSession spark.sharedState.cacheManager.clearCache() // files can be closed from other threads, so wait a bit // normally this doesn't take more than 1s - eventually(timeout(10.seconds)) { + eventually(timeout(10.seconds), interval(2.seconds)) { DebugFilesystem.assertNoOpenStreams() } } From 30295bf5a6754d0ae43334f7bf00e7a29ed0f1af Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 7 Feb 2018 15:22:53 -0800 Subject: [PATCH 303/774] [SPARK-23092][SQL] Migrate MemoryStream to DataSourceV2 APIs ## What changes were proposed in this pull request? This PR migrates the MemoryStream to DataSourceV2 APIs. One additional change is in the reported keys in StreamingQueryProgress.durationMs. "getOffset" and "getBatch" replaced with "setOffsetRange" and "getEndOffset" as tracking these make more sense. Unit tests changed accordingly. ## How was this patch tested? Existing unit tests, few updated unit tests. Author: Tathagata Das Author: Burak Yavuz Closes #20445 from tdas/SPARK-23092. --- .../sql/execution/streaming/LongOffset.scala | 4 +- .../streaming/MicroBatchExecution.scala | 27 ++-- .../sql/execution/streaming/memory.scala | 132 +++++++++++------- .../sources/RateStreamSourceV2.scala | 2 +- .../streaming/ForeachSinkSuite.scala | 55 +++----- .../spark/sql/streaming/StreamSuite.scala | 8 +- .../spark/sql/streaming/StreamTest.scala | 2 +- .../StreamingQueryListenerSuite.scala | 5 +- .../sql/streaming/StreamingQuerySuite.scala | 70 ++++++---- 9 files changed, 171 insertions(+), 134 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala index 5f0b195fcfcb8..3ff5b86ac45d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.execution.streaming +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} + /** * A simple offset for sources that produce a single linear stream of data. */ -case class LongOffset(offset: Long) extends Offset { +case class LongOffset(offset: Long) extends OffsetV2 { override val json = offset.toString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index d9aa8573ba930..045d2b4b9569c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -270,16 +270,17 @@ class MicroBatchExecution( } case s: MicroBatchReader => updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("getOffset") { - // Once v1 streaming source execution is gone, we can refactor this away. - // For now, we set the range here to get the source to infer the available end offset, - // get that offset, and then set the range again when we later execute. - s.setOffsetRange( - toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), - Optional.empty()) - - (s, Some(s.getEndOffset)) + reportTimeTaken("setOffsetRange") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. + s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) } + + val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } + (s, Option(currentOffset)) }.toMap availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) @@ -401,10 +402,14 @@ class MicroBatchExecution( case (reader: MicroBatchReader, available) if committedOffsets.get(reader).map(_ != available).getOrElse(true) => val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) + val availableV2: OffsetV2 = available match { + case v1: SerializedOffset => reader.deserializeOffset(v1.json) + case v2: OffsetV2 => v2 + } reader.setOffsetRange( toJava(current), - Optional.of(available.asInstanceOf[OffsetV2])) - logDebug(s"Retrieving data from $reader: $current -> $available") + Optional.of(availableV2)) + logDebug(s"Retrieving data from $reader: $current -> $availableV2") Some(reader -> new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 509a69dd922fb..352d4ce9fbcaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,21 +17,23 @@ package org.apache.spark.sql.execution.streaming +import java.{util => ju} +import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.encoderFor -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics} +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -51,9 +53,10 @@ object MemoryStream { * available. */ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends Source with Logging { + extends MicroBatchReader with SupportsScanUnsafeRow with Logging { protected val encoder = encoderFor[A] - protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession) + private val attributes = encoder.schema.toAttributes + protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) protected val output = logicalPlan.output /** @@ -61,11 +64,17 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) * Stored in a ListBuffer to facilitate removing committed batches. */ @GuardedBy("this") - protected val batches = new ListBuffer[Dataset[A]] + protected val batches = new ListBuffer[Array[UnsafeRow]] @GuardedBy("this") protected var currentOffset: LongOffset = new LongOffset(-1) + @GuardedBy("this") + private var startOffset = new LongOffset(-1) + + @GuardedBy("this") + private var endOffset = new LongOffset(-1) + /** * Last offset that was discarded, or -1 if no commits have occurred. Note that the value * -1 is used in calculations below and isn't just an arbitrary constant. @@ -73,8 +82,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) @GuardedBy("this") protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) - def schema: StructType = encoder.schema - def toDS(): Dataset[A] = { Dataset(sqlContext.sparkSession, logicalPlan) } @@ -88,72 +95,69 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } def addData(data: TraversableOnce[A]): Offset = { - val encoded = data.toVector.map(d => encoder.toRow(d).copy()) - val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true) - val ds = Dataset[A](sqlContext.sparkSession, plan) - logDebug(s"Adding ds: $ds") + val objects = data.toSeq + val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray + logDebug(s"Adding: $objects") this.synchronized { currentOffset = currentOffset + 1 - batches += ds + batches += rows currentOffset } } override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" - override def getOffset: Option[Offset] = synchronized { - if (currentOffset.offset == -1) { - None - } else { - Some(currentOffset) + override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { + synchronized { + startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset] + endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset] } } - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) - val startOrdinal = - start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1 - val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1 - - // Internal buffer only holds the batches after lastCommittedOffset. - val newBlocks = synchronized { - val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 - val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 - assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd") - batches.slice(sliceStart, sliceEnd) - } + override def readSchema(): StructType = encoder.schema - if (newBlocks.isEmpty) { - return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) - } + override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) + + override def getStartOffset: OffsetV2 = synchronized { + if (startOffset.offset == -1) null else startOffset + } - logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal)) + override def getEndOffset: OffsetV2 = synchronized { + if (endOffset.offset == -1) null else endOffset + } - newBlocks - .map(_.toDF()) - .reduceOption(_ union _) - .getOrElse { - sys.error("No data selected!") + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + synchronized { + // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) + val startOrdinal = startOffset.offset.toInt + 1 + val endOrdinal = endOffset.offset.toInt + 1 + + // Internal buffer only holds the batches after lastCommittedOffset. + val newBlocks = synchronized { + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 + val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd") + batches.slice(sliceStart, sliceEnd) } + + logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) + + newBlocks.map { block => + new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]] + }.asJava + } } private def generateDebugString( - blocks: TraversableOnce[Dataset[A]], + rows: Seq[UnsafeRow], startOrdinal: Int, endOrdinal: Int): String = { - val originalUnsupportedCheck = - sqlContext.getConf("spark.sql.streaming.unsupportedOperationCheck") - try { - sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", "false") - s"MemoryBatch [$startOrdinal, $endOrdinal]: " + - s"${blocks.flatMap(_.collect()).mkString(", ")}" - } finally { - sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", originalUnsupportedCheck) - } + val fromRow = encoder.resolveAndBind().fromRow _ + s"MemoryBatch [$startOrdinal, $endOrdinal]: " + + s"${rows.map(row => fromRow(row)).mkString(", ")}" } - override def commit(end: Offset): Unit = synchronized { + override def commit(end: OffsetV2): Unit = synchronized { def check(newOffset: LongOffset): Unit = { val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt @@ -176,11 +180,33 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def reset(): Unit = synchronized { batches.clear() + startOffset = LongOffset(-1) + endOffset = LongOffset(-1) currentOffset = new LongOffset(-1) lastOffsetCommitted = new LongOffset(-1) } } + +class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) + extends DataReaderFactory[UnsafeRow] { + override def createDataReader(): DataReader[UnsafeRow] = { + new DataReader[UnsafeRow] { + private var currentIndex = -1 + + override def next(): Boolean = { + // Return true as long as the new index is in the array. + currentIndex += 1 + currentIndex < records.length + } + + override def get(): UnsafeRow = records(currentIndex) + + override def close(): Unit = {} + } + } +} + /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 1315885da8a6f..077a255946a6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -151,7 +151,7 @@ case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactor } class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { - var currentIndex = -1 + private var currentIndex = -1 override def next(): Boolean = { // Return true as long as the new index is in the seq. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 41434e6d8b974..b249dd41a84a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -46,49 +46,34 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .foreach(new TestForeachWriter()) .start() - // -- batch 0 --------------------------------------- - input.addData(1, 2, 3, 4) - query.processAllAvailable() + def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = { + import ForeachSinkSuite._ - var expectedEventsForPartition0 = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 2), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) - ) - var expectedEventsForPartition1 = Seq( - ForeachSinkSuite.Open(partition = 1, version = 0), - ForeachSinkSuite.Process(value = 1), - ForeachSinkSuite.Process(value = 4), - ForeachSinkSuite.Close(None) - ) + val events = ForeachSinkSuite.allEvents() + assert(events.size === 2) // one seq of events for each of the 2 partitions - var allEvents = ForeachSinkSuite.allEvents() - assert(allEvents.size === 2) - assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1)) + // Verify both seq of events have an Open event as the first event + assert(events.map(_.head).toSet === Set(0, 1).map(p => Open(p, expectedVersion))) + + // Verify all the Process event correspond to the expected data + val allProcessEvents = events.flatMap(_.filter(_.isInstanceOf[Process[_]])) + assert(allProcessEvents.toSet === expectedData.map { data => Process(data) }.toSet) + + // Verify both seq of events have a Close event as the last event + assert(events.map(_.last).toSet === Set(Close(None), Close(None))) + } + // -- batch 0 --------------------------------------- ForeachSinkSuite.clear() + input.addData(1, 2, 3, 4) + query.processAllAvailable() + verifyOutput(expectedVersion = 0, expectedData = 1 to 4) // -- batch 1 --------------------------------------- + ForeachSinkSuite.clear() input.addData(5, 6, 7, 8) query.processAllAvailable() - - expectedEventsForPartition0 = Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Process(value = 5), - ForeachSinkSuite.Process(value = 7), - ForeachSinkSuite.Close(None) - ) - expectedEventsForPartition1 = Seq( - ForeachSinkSuite.Open(partition = 1, version = 1), - ForeachSinkSuite.Process(value = 6), - ForeachSinkSuite.Process(value = 8), - ForeachSinkSuite.Close(None) - ) - - allEvents = ForeachSinkSuite.allEvents() - assert(allEvents.size === 2) - assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1)) + verifyOutput(expectedVersion = 1, expectedData = 5 to 8) query.stop() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index c65e5d3dd75c2..d1a04833390f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,16 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("LocalRelation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("LocalTableScan".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("LocalRelation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("LocalTableScan".r.findAllMatchIn(explainWithExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d6433562fb29b..37fe595529baf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -120,7 +120,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" - override def addData(query: Option[StreamExecution]): (Source, Offset) = { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { (source, source.addData(data)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 79d65192a14aa..b96f2bcbdd644 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.sql.{Encoder, SparkSession} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.util.JsonProtocol @@ -298,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { try { val input = new MemoryStream[Int](0, sqlContext) { @volatile var numTriggers = 0 - override def getOffset: Option[Offset] = { + override def getEndOffset: OffsetV2 = { numTriggers += 1 - super.getOffset + super.getEndOffset } } val clock = new StreamManualClock() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 76201c63a2701..3f9aa0d1fa5be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,25 +17,27 @@ package org.apache.spark.sql.streaming +import java.{util => ju} +import java.util.Optional import java.util.concurrent.CountDownLatch import org.apache.commons.lang3.RandomStringUtils -import org.mockito.Mockito._ import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ManualClock class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar { @@ -206,19 +208,29 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi /** Custom MemoryStream that waits for manual clock to reach a time */ val inputData = new MemoryStream[Int](0, sqlContext) { - // getOffset should take 50 ms the first time it is called - override def getOffset: Option[Offset] = { - val offset = super.getOffset - if (offset.nonEmpty) { - clock.waitTillTime(1050) + + private def dataAdded: Boolean = currentOffset.offset != -1 + + // setOffsetRange should take 50 ms the first time it is called after data is added + override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { + synchronized { + if (dataAdded) clock.waitTillTime(1050) + super.setOffsetRange(start, end) } - offset + } + + // getEndOffset should take 100 ms the first time it is called after data is added + override def getEndOffset(): OffsetV2 = synchronized { + if (dataAdded) clock.waitTillTime(1150) + super.getEndOffset() } // getBatch should take 100 ms the first time it is called - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - if (start.isEmpty) clock.waitTillTime(1150) - super.getBatch(start, end) + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + synchronized { + clock.waitTillTime(1350) + super.createUnsafeRowReaderFactories() + } } } @@ -258,39 +270,44 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Waiting for next trigger"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress while offset is being fetched + // Test status and progress when setOffsetRange is being called AddData(inputData, 1, 2), - AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on getOffset + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on setOffsetRange AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress while batch is being fetched - AdvanceManualClock(50), // time = 1050 to unblock getOffset + AdvanceManualClock(50), // time = 1050 to unblock setOffsetRange AssertClockTime(1050), - AssertStreamExecThreadIsWaitingForTime(1150), // will block on getBatch that needs 1150 + AssertStreamExecThreadIsWaitingForTime(1150), // will block on getEndOffset that needs 1150 + AssertOnQuery(_.status.isDataAvailable === false), + AssertOnQuery(_.status.isTriggerActive === true), + AssertOnQuery(_.status.message.startsWith("Getting offsets from")), + AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), + + AdvanceManualClock(100), // time = 1150 to unblock getEndOffset + AssertClockTime(1150), + AssertStreamExecThreadIsWaitingForTime(1350), // will block on createReadTasks that needs 1350 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress while batch is being processed - AdvanceManualClock(100), // time = 1150 to unblock getBatch - AssertClockTime(1150), - AssertStreamExecThreadIsWaitingForTime(1500), // will block in Spark job that needs 1500 + AdvanceManualClock(200), // time = 1350 to unblock createReadTasks + AssertClockTime(1350), + AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - AssertOnQuery { _ => clock.getTimeMillis() === 1150 }, - AdvanceManualClock(350), // time = 1500 to unblock job + AdvanceManualClock(150), // time = 1500 to unblock map task AssertClockTime(1500), CheckAnswer(2), - AssertStreamExecThreadIsWaitingForTime(2000), + AssertStreamExecThreadIsWaitingForTime(2000), // will block until the next trigger AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -307,10 +324,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.numInputRows === 2) assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("getOffset") === 50) - assert(progress.durationMs.get("getBatch") === 100) + assert(progress.durationMs.get("setOffsetRange") === 50) + assert(progress.durationMs.get("getEndOffset") === 100) assert(progress.durationMs.get("queryPlanning") === 0) assert(progress.durationMs.get("walCommit") === 0) + assert(progress.durationMs.get("addBatch") === 350) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) From a62f30d3fa032ff75bc2b7bebbd0813e67ea5fd5 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 8 Feb 2018 12:46:10 +0900 Subject: [PATCH 304/774] [SPARK-23319][TESTS][FOLLOWUP] Fix a test for Python 3 without pandas. ## What changes were proposed in this pull request? This is a followup pr of #20487. When importing module but it doesn't exists, the error message is slightly different between Python 2 and 3. E.g., in Python 2: ``` No module named pandas ``` in Python 3: ``` No module named 'pandas' ``` So, one test to check an import error fails in Python 3 without pandas. This pr fixes it. ## How was this patch tested? Tested manually in my local environment. Author: Takuya UESHIN Closes #20538 from ueshin/issues/SPARK-23319/fup1. --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 58359b61dc83a..90ff084fed55e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2860,7 +2860,7 @@ def test_create_dataframe_required_pandas_not_found(self): with QuietTest(self.sc): with self.assertRaisesRegexp( ImportError, - '(Pandas >= .* must be installed|No module named pandas)'): + "(Pandas >= .* must be installed|No module named '?pandas'?)"): import pandas as pd from datetime import datetime pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], From 3473fda6dc77bdfd84b3de95d2082856ad4f8626 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 8 Feb 2018 12:21:18 +0800 Subject: [PATCH 305/774] Revert [SPARK-22279][SQL] Turn on spark.sql.hive.convertMetastoreOrc by default ## What changes were proposed in this pull request? This is to revert the changes made in https://github.com/apache/spark/pull/19499 , because this causes a regression. We should not ignore the table-specific compression conf when the Hive serde tables are converted to the data source tables. ## How was this patch tested? The existing tests. Author: gatorsmile Closes #20536 from gatorsmile/revert22279. --- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index d9627eb9790eb..93f3f38e52aa9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -109,7 +109,7 @@ private[spark] object HiveUtils extends Logging { .doc("When set to true, the built-in ORC reader and writer are used to process " + "ORC tables created by using the HiveQL syntax, instead of Hive serde.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val HIVE_METASTORE_SHARED_PREFIXES = buildConf("spark.sql.hive.metastore.sharedPrefixes") .doc("A comma separated list of class prefixes that should be loaded using the classloader " + From 7f5f5fb1296275a38da0adfa05125dd8ebf729ff Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 8 Feb 2018 00:08:54 -0800 Subject: [PATCH 306/774] [SPARK-23348][SQL] append data using saveAsTable should adjust the data types ## What changes were proposed in this pull request? For inserting/appending data to an existing table, Spark should adjust the data types of the input query according to the table schema, or fail fast if it's uncastable. There are several ways to insert/append data: SQL API, `DataFrameWriter.insertInto`, `DataFrameWriter.saveAsTable`. The first 2 ways create `InsertIntoTable` plan, and the last way creates `CreateTable` plan. However, we only adjust input query data types for `InsertIntoTable`, and users may hit weird errors when appending data using `saveAsTable`. See the JIRA for the error case. This PR fixes this bug by adjusting data types for `CreateTable` too. ## How was this patch tested? new test. Author: Wenchen Fan Closes #20527 from cloud-fan/saveAsTable. --- .../sql/execution/datasources/rules.scala | 72 +++++++++++-------- .../sql/execution/command/DDLSuite.scala | 28 ++++++++ 2 files changed, 69 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 5dbcf4a915cbf..5cc21eeaeaa94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.DDLUtils @@ -178,7 +178,8 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi c.copy( tableDesc = existingTable, - query = Some(newQuery)) + query = Some(DDLPreprocessingUtils.castAndRenameQueryOutput( + newQuery, existingTable.schema.toAttributes, conf))) // Here we normalize partition, bucket and sort column names, w.r.t. the case sensitivity // config, and do various checks: @@ -316,7 +317,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. */ -case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { +case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { private def preprocess( insert: InsertIntoTable, tblName: String, @@ -336,6 +337,8 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit s"including ${staticPartCols.size} partition column(s) having constant value(s).") } + val newQuery = DDLPreprocessingUtils.castAndRenameQueryOutput( + insert.query, expectedColumns, conf) if (normalizedPartSpec.nonEmpty) { if (normalizedPartSpec.size != partColNames.length) { throw new AnalysisException( @@ -346,37 +349,11 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit """.stripMargin) } - castAndRenameChildOutput(insert.copy(partition = normalizedPartSpec), expectedColumns) + insert.copy(query = newQuery, partition = normalizedPartSpec) } else { // All partition columns are dynamic because the InsertIntoTable command does // not explicitly specify partitioning columns. - castAndRenameChildOutput(insert, expectedColumns) - .copy(partition = partColNames.map(_ -> None).toMap) - } - } - - private def castAndRenameChildOutput( - insert: InsertIntoTable, - expectedOutput: Seq[Attribute]): InsertIntoTable = { - val newChildOutput = expectedOutput.zip(insert.query.output).map { - case (expected, actual) => - if (expected.dataType.sameType(actual.dataType) && - expected.name == actual.name && - expected.metadata == actual.metadata) { - actual - } else { - // Renaming is needed for handling the following cases like - // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 - // 2) Target tables have column metadata - Alias(cast(actual, expected.dataType), expected.name)( - explicitMetadata = Option(expected.metadata)) - } - } - - if (newChildOutput == insert.query.output) { - insert - } else { - insert.copy(query = Project(newChildOutput, insert.query)) + insert.copy(query = newQuery, partition = partColNames.map(_ -> None).toMap) } } @@ -491,3 +468,36 @@ object PreWriteCheck extends (LogicalPlan => Unit) { } } } + +object DDLPreprocessingUtils { + + /** + * Adjusts the name and data type of the input query output columns, to match the expectation. + */ + def castAndRenameQueryOutput( + query: LogicalPlan, + expectedOutput: Seq[Attribute], + conf: SQLConf): LogicalPlan = { + val newChildOutput = expectedOutput.zip(query.output).map { + case (expected, actual) => + if (expected.dataType.sameType(actual.dataType) && + expected.name == actual.name && + expected.metadata == actual.metadata) { + actual + } else { + // Renaming is needed for handling the following cases like + // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 + // 2) Target tables have column metadata + Alias( + Cast(actual, expected.dataType, Option(conf.sessionLocalTimeZone)), + expected.name)(explicitMetadata = Option(expected.metadata)) + } + } + + if (newChildOutput == query.output) { + query + } else { + Project(newChildOutput, query) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index ee3674ba17821..f76bfd2fda2b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -37,6 +37,8 @@ import org.apache.spark.util.Utils class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with BeforeAndAfterEach { + import testImplicits._ + override def afterEach(): Unit = { try { // drop all databases, tables and functions after each test @@ -132,6 +134,32 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo checkAnswer(spark.table("t"), Row(Row("a", 1)) :: Nil) } } + + // TODO: This test is copied from HiveDDLSuite, unify it later. + test("SPARK-23348: append data to data source table with saveAsTable") { + withTable("t", "t1") { + Seq(1 -> "a").toDF("i", "j").write.saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a")) + + sql("INSERT INTO t SELECT 2, 'b'") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Nil) + + Seq(3 -> "c").toDF("i", "j").write.mode("append").saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Nil) + + Seq("c" -> 3).toDF("i", "j").write.mode("append").saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") + :: Row(null, "3") :: Nil) + + Seq(4 -> "d").toDF("i", "j").write.saveAsTable("t1") + + val e = intercept[AnalysisException] { + Seq(5 -> "e").toDF("i", "j").write.mode("append").format("json").saveAsTable("t1") + } + assert(e.message.contains("The format of the existing table default.t1 is " + + "`ParquetFileFormat`. It doesn't match the specified format `JsonFileFormat`.")) + } + } } abstract class DDLSuite extends QueryTest with SQLTestUtils { From a75f927173632eee1316879447cb62c8cf30ae37 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 8 Feb 2018 19:20:11 +0800 Subject: [PATCH 307/774] [SPARK-23268][SQL][FOLLOWUP] Reorganize packages in data source V2 ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/20435. While reorganizing the packages for streaming data source v2, the top level stream read/write support interfaces should not be in the reader/writer package, but should be in the `sources.v2` package, to follow the `ReadSupport`, `WriteSupport`, etc. ## How was this patch tested? N/A Author: Wenchen Fan Closes #20509 from cloud-fan/followup. --- .../org/apache/spark/sql/kafka010/KafkaSourceProvider.scala | 4 +--- .../sql/sources/v2/{reader => }/ContinuousReadSupport.java | 4 +--- .../sql/sources/v2/{reader => }/MicroBatchReadSupport.java | 4 +--- .../sql/sources/v2/{writer => }/StreamWriteSupport.java | 5 ++--- .../apache/spark/sql/sources/v2/writer/DataSourceWriter.java | 1 + .../spark/sql/execution/streaming/MicroBatchExecution.scala | 5 ++--- .../spark/sql/execution/streaming/RateSourceProvider.scala | 1 - .../spark/sql/execution/streaming/StreamingRelation.scala | 3 +-- .../org/apache/spark/sql/execution/streaming/console.scala | 3 +-- .../execution/streaming/continuous/ContinuousExecution.scala | 4 +--- .../sql/execution/streaming/sources/RateStreamSourceV2.scala | 2 +- .../spark/sql/execution/streaming/sources/memoryV2.scala | 2 +- .../org/apache/spark/sql/streaming/DataStreamReader.scala | 3 +-- .../org/apache/spark/sql/streaming/DataStreamWriter.scala | 2 +- .../apache/spark/sql/streaming/StreamingQueryManager.scala | 2 +- .../spark/sql/execution/streaming/RateSourceV2Suite.scala | 2 +- .../sql/streaming/sources/StreamingDataSourceV2Suite.scala | 5 ++--- 17 files changed, 19 insertions(+), 33 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{reader => }/ContinuousReadSupport.java (92%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{reader => }/MicroBatchReadSupport.java (93%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{writer => }/StreamWriteSupport.java (93%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 694ca76e24964..d4fa0359c12d6 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,9 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java similarity index 92% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java index 0c1d5d1a9577a..7df5a451ae5f3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java @@ -15,13 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java index 5e8f0c0dafdcf..209ffa7a0b9fa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java @@ -15,13 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java index 1c0e2e12f8d51..a77b01497269e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java @@ -15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.writer; +package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 52324b3792b8a..e3f682bf96a66 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -21,6 +21,7 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.StreamWriteSupport; import org.apache.spark.sql.sources.v2.WriteSupport; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 045d2b4b9569c..812533313332e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -29,10 +29,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.MicroBatchReadSupport +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.sources.v2.writer.{StreamWriteSupport, SupportsWriteInternalRow} +import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index ce5e63f5bde85..649fbbfa184ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 845c8d2c14e43..7146190645b37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,8 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index db600866067bc..cfba1001c6de0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index ed22b9100497a..c3294d64b10cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -31,10 +31,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 077a255946a6b..4e2459bb05bd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 3411edbc53412..f960208155e3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 116ac3da07b75..f23851655350a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -28,8 +28,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 9aac360fd4bbc..2fc903168cfa0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index ddb1edc433d5a..7cefd03e43bc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index 0d68d9c3138aa..983ba1668f58f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 51f44fa6285e4..af4618bed5456 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -25,10 +25,9 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, Streami import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, DataReaderFactory, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType From 76e019d9bdcdca176c79c1cd71ddbf496333bf93 Mon Sep 17 00:00:00 2001 From: liuxian Date: Thu, 8 Feb 2018 23:41:30 +0800 Subject: [PATCH 308/774] [SPARK-21860][CORE] Improve memory reuse for heap memory in `HeapMemoryAllocator` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In `HeapMemoryAllocator`, when allocating memory from pool, and the key of pool is memory size. Actually some size of memory ,such as 1025bytes,1026bytes,......1032bytes, we can think they are the same,because we allocate memory in multiples of 8 bytes. In this case, we can improve memory reuse. ## How was this patch tested? Existing tests and added unit tests Author: liuxian Closes #19077 from 10110346/headmemoptimize. --- .../unsafe/memory/HeapMemoryAllocator.java | 18 +++++++++------ .../spark/unsafe/PlatformUtilSuite.java | 22 +++++++++++++++++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index a9603c1aba051..2733760dd19ef 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -46,9 +46,12 @@ private boolean shouldPool(long size) { @Override public MemoryBlock allocate(long size) throws OutOfMemoryError { - if (shouldPool(size)) { + int numWords = (int) ((size + 7) / 8); + long alignedSize = numWords * 8L; + assert (alignedSize >= size); + if (shouldPool(alignedSize)) { synchronized (this) { - final LinkedList> pool = bufferPoolsBySize.get(size); + final LinkedList> pool = bufferPoolsBySize.get(alignedSize); if (pool != null) { while (!pool.isEmpty()) { final WeakReference arrayReference = pool.pop(); @@ -62,11 +65,11 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { return memory; } } - bufferPoolsBySize.remove(size); + bufferPoolsBySize.remove(alignedSize); } } } - long[] array = new long[(int) ((size + 7) / 8)]; + long[] array = new long[numWords]; MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); @@ -98,12 +101,13 @@ public void free(MemoryBlock memory) { long[] array = (long[]) memory.obj; memory.setObjAndOffset(null, 0); - if (shouldPool(size)) { + long alignedSize = ((size + 7) / 8) * 8; + if (shouldPool(alignedSize)) { synchronized (this) { - LinkedList> pool = bufferPoolsBySize.get(size); + LinkedList> pool = bufferPoolsBySize.get(alignedSize); if (pool == null) { pool = new LinkedList<>(); - bufferPoolsBySize.put(size, pool); + bufferPoolsBySize.put(alignedSize, pool); } pool.add(new WeakReference<>(array)); } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 62854837b05ed..71c53d35dcab8 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.unsafe; +import org.apache.spark.unsafe.memory.HeapMemoryAllocator; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -134,4 +135,25 @@ public void memoryDebugFillEnabledInTest() { MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); MemoryAllocator.UNSAFE.free(offheap); } + + @Test + public void heapMemoryReuse() { + MemoryAllocator heapMem = new HeapMemoryAllocator(); + // The size is less than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`,allocate new memory every time. + MemoryBlock onheap1 = heapMem.allocate(513); + Object obj1 = onheap1.getBaseObject(); + heapMem.free(onheap1); + MemoryBlock onheap2 = heapMem.allocate(514); + Assert.assertNotEquals(obj1, onheap2.getBaseObject()); + + // The size is greater than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`, + // reuse the previous memory which has released. + MemoryBlock onheap3 = heapMem.allocate(1024 * 1024 + 1); + Assert.assertEquals(onheap3.size(), 1024 * 1024 + 1); + Object obj3 = onheap3.getBaseObject(); + heapMem.free(onheap3); + MemoryBlock onheap4 = heapMem.allocate(1024 * 1024 + 7); + Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); + Assert.assertEquals(obj3, onheap4.getBaseObject()); + } } From 4df84c3f818aa536515729b442601e08c253ed35 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 8 Feb 2018 12:52:08 -0600 Subject: [PATCH 309/774] [SPARK-23336][BUILD] Upgrade snappy-java to 1.1.7.1 ## What changes were proposed in this pull request? This PR upgrade snappy-java from 1.1.2.6 to 1.1.7.1. 1.1.7.1 release notes: - Improved performance for big-endian architecture - The other performance improvement in [snappy-1.1.5](https://github.com/google/snappy/releases/tag/1.1.5) 1.1.4 release notes: - Fix a 1% performance regression when snappy is used in PIE executables. - Improve compression performance by 5%. - Improve decompression performance by 20%. More details: https://github.com/xerial/snappy-java/blob/master/Milestone.md ## How was this patch tested? manual tests Author: Yuming Wang Closes #20510 from wangyum/SPARK-23336. --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 48e54568e6fc6..99031384aa22e 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -182,7 +182,7 @@ slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snakeyaml-1.15.jar snappy-0.2.jar -snappy-java-1.1.2.6.jar +snappy-java-1.1.7.1.jar spire-macros_2.11-0.13.0.jar spire_2.11-0.13.0.jar stax-api-1.0-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 1807a77900e52..cf8d2789b7ee9 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -183,7 +183,7 @@ slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snakeyaml-1.15.jar snappy-0.2.jar -snappy-java-1.1.2.6.jar +snappy-java-1.1.7.1.jar spire-macros_2.11-0.13.0.jar spire_2.11-0.13.0.jar stax-api-1.0-2.jar diff --git a/pom.xml b/pom.xml index d18831df1db6d..de949b94d676c 100644 --- a/pom.xml +++ b/pom.xml @@ -160,7 +160,7 @@ 1.9.13 2.6.7 2.6.7.1 - 1.1.2.6 + 1.1.7.1 1.1.2 1.2.0-incubating 1.10 From 8cbcc33876c773722163b2259644037bbb259bd1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 9 Feb 2018 12:54:57 +0800 Subject: [PATCH 310/774] [SPARK-23186][SQL] Initialize DriverManager first before loading JDBC Drivers ## What changes were proposed in this pull request? Since some JDBC Drivers have class initialization code to call `DriverManager`, we need to initialize `DriverManager` first in order to avoid potential executor-side **deadlock** situations like the following (or [STORM-2527](https://issues.apache.org/jira/browse/STORM-2527)). ``` Thread 9587: (state = BLOCKED) - sun.reflect.NativeConstructorAccessorImpl.newInstance0(java.lang.reflect.Constructor, java.lang.Object[]) bci=0 (Compiled frame; information may be imprecise) - sun.reflect.NativeConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=85, line=62 (Compiled frame) - sun.reflect.DelegatingConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=5, line=45 (Compiled frame) - java.lang.reflect.Constructor.newInstance(java.lang.Object[]) bci=79, line=423 (Compiled frame) - java.lang.Class.newInstance() bci=138, line=442 (Compiled frame) - java.util.ServiceLoader$LazyIterator.nextService() bci=119, line=380 (Interpreted frame) - java.util.ServiceLoader$LazyIterator.next() bci=11, line=404 (Interpreted frame) - java.util.ServiceLoader$1.next() bci=37, line=480 (Interpreted frame) - java.sql.DriverManager$2.run() bci=21, line=603 (Interpreted frame) - java.sql.DriverManager$2.run() bci=1, line=583 (Interpreted frame) - java.security.AccessController.doPrivileged(java.security.PrivilegedAction) bci=0 (Compiled frame) - java.sql.DriverManager.loadInitialDrivers() bci=27, line=583 (Interpreted frame) - java.sql.DriverManager.() bci=32, line=101 (Interpreted frame) - org.apache.phoenix.mapreduce.util.ConnectionUtil.getConnection(java.lang.String, java.lang.Integer, java.lang.String, java.util.Properties) bci=12, line=98 (Interpreted frame) - org.apache.phoenix.mapreduce.util.ConnectionUtil.getInputConnection(org.apache.hadoop.conf.Configuration, java.util.Properties) bci=22, line=57 (Interpreted frame) - org.apache.phoenix.mapreduce.PhoenixInputFormat.getQueryPlan(org.apache.hadoop.mapreduce.JobContext, org.apache.hadoop.conf.Configuration) bci=61, line=116 (Interpreted frame) - org.apache.phoenix.mapreduce.PhoenixInputFormat.createRecordReader(org.apache.hadoop.mapreduce.InputSplit, org.apache.hadoop.mapreduce.TaskAttemptContext) bci=10, line=71 (Interpreted frame) - org.apache.spark.rdd.NewHadoopRDD$$anon$1.(org.apache.spark.rdd.NewHadoopRDD, org.apache.spark.Partition, org.apache.spark.TaskContext) bci=233, line=156 (Interpreted frame) Thread 9170: (state = BLOCKED) - org.apache.phoenix.jdbc.PhoenixDriver.() bci=35, line=125 (Interpreted frame) - sun.reflect.NativeConstructorAccessorImpl.newInstance0(java.lang.reflect.Constructor, java.lang.Object[]) bci=0 (Compiled frame) - sun.reflect.NativeConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=85, line=62 (Compiled frame) - sun.reflect.DelegatingConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=5, line=45 (Compiled frame) - java.lang.reflect.Constructor.newInstance(java.lang.Object[]) bci=79, line=423 (Compiled frame) - java.lang.Class.newInstance() bci=138, line=442 (Compiled frame) - org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry$.register(java.lang.String) bci=89, line=46 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$createConnectionFactory$2.apply() bci=7, line=53 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$createConnectionFactory$2.apply() bci=1, line=52 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD$$anon$1.(org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD, org.apache.spark.Partition, org.apache.spark.TaskContext) bci=81, line=347 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD.compute(org.apache.spark.Partition, org.apache.spark.TaskContext) bci=7, line=339 (Interpreted frame) ``` ## How was this patch tested? N/A Author: Dongjoon Hyun Closes #20359 from dongjoon-hyun/SPARK-23186. --- .../sql/execution/datasources/jdbc/DriverRegistry.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala index 7a6c0f9fed2f9..1723596de1db2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -32,6 +32,13 @@ import org.apache.spark.util.Utils */ object DriverRegistry extends Logging { + /** + * Load DriverManager first to avoid any race condition between + * DriverManager static initialization block and specific driver class's + * static initialization block. e.g. PhoenixDriver + */ + DriverManager.getDrivers + private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty def register(className: String): Unit = { From 4b4ee2601079f12f8f410a38d2081793cbdedc14 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 9 Feb 2018 14:21:10 +0800 Subject: [PATCH 311/774] [SPARK-23328][PYTHON] Disallow default value None in na.replace/replace when 'to_replace' is not a dictionary ## What changes were proposed in this pull request? This PR proposes to disallow default value None when 'to_replace' is not a dictionary. It seems weird we set the default value of `value` to `None` and we ended up allowing the case as below: ```python >>> df.show() ``` ``` +----+------+-----+ | age|height| name| +----+------+-----+ | 10| 80|Alice| ... ``` ```python >>> df.na.replace('Alice').show() ``` ``` +----+------+----+ | age|height|name| +----+------+----+ | 10| 80|null| ... ``` **After** This PR targets to disallow the case above: ```python >>> df.na.replace('Alice').show() ``` ``` ... TypeError: value is required when to_replace is not a dictionary. ``` while we still allow when `to_replace` is a dictionary: ```python >>> df.na.replace({'Alice': None}).show() ``` ``` +----+------+----+ | age|height|name| +----+------+----+ | 10| 80|null| ... ``` ## How was this patch tested? Manually tested, tests were added in `python/pyspark/sql/tests.py` and doctests were fixed. Author: hyukjinkwon Closes #20499 from HyukjinKwon/SPARK-19454-followup. --- docs/sql-programming-guide.md | 1 + python/pyspark/__init__.py | 1 + python/pyspark/_globals.py | 70 +++++++++++++++++++++++++++++++++ python/pyspark/sql/dataframe.py | 26 +++++++++--- python/pyspark/sql/tests.py | 11 +++--- 5 files changed, 99 insertions(+), 10 deletions(-) create mode 100644 python/pyspark/_globals.py diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a0e221b39cc34..eab4030ee25d2 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1929,6 +1929,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. + - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error prone. ## Upgrading From Spark SQL 2.1 to 2.2 diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 4d142c91629cc..58218918693ca 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -54,6 +54,7 @@ from pyspark.taskcontext import TaskContext from pyspark.profiler import Profiler, BasicProfiler from pyspark.version import __version__ +from pyspark._globals import _NoValue def since(version): diff --git a/python/pyspark/_globals.py b/python/pyspark/_globals.py new file mode 100644 index 0000000000000..8e6099db09963 --- /dev/null +++ b/python/pyspark/_globals.py @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Module defining global singleton classes. + +This module raises a RuntimeError if an attempt to reload it is made. In that +way the identities of the classes defined here are fixed and will remain so +even if pyspark itself is reloaded. In particular, a function like the following +will still work correctly after pyspark is reloaded: + + def foo(arg=pyspark._NoValue): + if arg is pyspark._NoValue: + ... + +See gh-7844 for a discussion of the reload problem that motivated this module. + +Note that this approach is taken after from NumPy. +""" + +__ALL__ = ['_NoValue'] + + +# Disallow reloading this module so as to preserve the identities of the +# classes defined here. +if '_is_loaded' in globals(): + raise RuntimeError('Reloading pyspark._globals is not allowed') +_is_loaded = True + + +class _NoValueType(object): + """Special keyword value. + + The instance of this class may be used as the default value assigned to a + deprecated keyword in order to check if it has been given a user defined + value. + + This class was copied from NumPy. + """ + __instance = None + + def __new__(cls): + # ensure that only one instance exists + if not cls.__instance: + cls.__instance = super(_NoValueType, cls).__new__(cls) + return cls.__instance + + # needed for python 2 to preserve identity through a pickle + def __reduce__(self): + return (self.__class__, ()) + + def __repr__(self): + return "" + + +_NoValue = _NoValueType() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8ec24db8717b2..faee870a2d2e2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -27,7 +27,7 @@ import warnings -from pyspark import copy_func, since +from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer @@ -1532,7 +1532,7 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) @since(1.4) - def replace(self, to_replace, value=None, subset=None): + def replace(self, to_replace, value=_NoValue, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. @@ -1545,8 +1545,8 @@ def replace(self, to_replace, value=None, subset=None): :param to_replace: bool, int, long, float, string, list or dict. Value to be replaced. - If the value is a dict, then `value` is ignored and `to_replace` must be a - mapping between a value and a replacement. + If the value is a dict, then `value` is ignored or can be omitted, and `to_replace` + must be a mapping between a value and a replacement. :param value: bool, int, long, float, string, list or None. The replacement value must be a bool, int, long, float, string or None. If `value` is a list, `value` should be of the same length and type as `to_replace`. @@ -1577,6 +1577,16 @@ def replace(self, to_replace, value=None, subset=None): |null| null|null| +----+------+----+ + >>> df4.na.replace({'Alice': None}).show() + +----+------+----+ + | age|height|name| + +----+------+----+ + | 10| 80|null| + | 5| null| Bob| + |null| null| Tom| + |null| null|null| + +----+------+----+ + >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| @@ -1587,6 +1597,12 @@ def replace(self, to_replace, value=None, subset=None): |null| null|null| +----+------+----+ """ + if value is _NoValue: + if isinstance(to_replace, dict): + value = None + else: + raise TypeError("value argument is required when to_replace is not a dictionary.") + # Helper functions def all_of(types): """Given a type or tuple of types and a sequence of xs @@ -2047,7 +2063,7 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ - def replace(self, to_replace, value, subset=None): + def replace(self, to_replace, value=_NoValue, subset=None): return self.df.replace(to_replace, value, subset) replace.__doc__ = DataFrame.replace.__doc__ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 90ff084fed55e..6ace16955000d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2243,11 +2243,6 @@ def test_replace(self): .replace(False, True).first()) self.assertTupleEqual(row, (True, True)) - # replace list while value is not given (default to None) - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() - self.assertTupleEqual(row, (None, 10, 80.0)) - # replace string with None and then drop None rows row = self.spark.createDataFrame( [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna() @@ -2283,6 +2278,12 @@ def test_replace(self): self.spark.createDataFrame( [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first() + with self.assertRaisesRegexp( + TypeError, + 'value argument is required when to_replace is not a dictionary.'): + self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() + def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) From f77270b8811bbd8956d0c08fa556265d2c5ee20e Mon Sep 17 00:00:00 2001 From: liuxian Date: Fri, 9 Feb 2018 08:45:06 -0600 Subject: [PATCH 312/774] [SPARK-23358][CORE] When the number of partitions is greater than 2^28, it will result in an error result ## What changes were proposed in this pull request? In the `checkIndexAndDataFile`,the `blocks` is the ` Int` type, when it is greater than 2^28, `blocks*8` will overflow, and this will result in an error result. In fact, `blocks` is actually the number of partitions. ## How was this patch tested? Manual test Author: liuxian Closes #20544 from 10110346/overflow. --- .../org/apache/spark/shuffle/IndexShuffleBlockResolver.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index c5f3f6e2b42b6..d88b25cc7e258 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -84,7 +84,7 @@ private[spark] class IndexShuffleBlockResolver( */ private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = { // the index file should have `block + 1` longs as offset. - if (index.length() != (blocks + 1) * 8) { + if (index.length() != (blocks + 1) * 8L) { return null } val lengths = new Array[Long](blocks) From 0fc26313f8071cdcb4ccd67bb1d6942983199d36 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 9 Feb 2018 08:46:27 -0600 Subject: [PATCH 313/774] [SPARK-21860][CORE][FOLLOWUP] fix java style error ## What changes were proposed in this pull request? #19077 introduced a Java style error (too long line). Quick fix. ## How was this patch tested? running `./dev/lint-java` Author: Marco Gaido Closes #20558 from mgaido91/SPARK-21860. --- .../test/java/org/apache/spark/unsafe/PlatformUtilSuite.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 71c53d35dcab8..3ad9ac7b4de9c 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -139,7 +139,8 @@ public void memoryDebugFillEnabledInTest() { @Test public void heapMemoryReuse() { MemoryAllocator heapMem = new HeapMemoryAllocator(); - // The size is less than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`,allocate new memory every time. + // The size is less than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`, + // allocate new memory every time. MemoryBlock onheap1 = heapMem.allocate(513); Object obj1 = onheap1.getBaseObject(); heapMem.free(onheap1); From 7f10cf83f311526737fc96d5bb8281d12e41932f Mon Sep 17 00:00:00 2001 From: Rob Vesse Date: Fri, 9 Feb 2018 11:21:20 -0800 Subject: [PATCH 314/774] [SPARK-16501][MESOS] Allow providing Mesos principal & secret via files This commit modifies the Mesos submission client to allow the principal and secret to be provided indirectly via files. The path to these files can be specified either via Spark configuration or via environment variable. Assuming these files are appropriately protected by FS/OS permissions this means we don't ever leak the actual values in process info like ps Environment variable specification is useful because it allows you to interpolate the location of this file when using per-user Mesos credentials. For some background as to why we have taken this approach I will briefly describe our set up. On our systems we provide each authorised user account with their own Mesos credentials to provide certain security and audit guarantees to our customers. These credentials are managed by a central Secret management service. In our `spark-env.sh` we determine the appropriate secret and principal files to use depending on the user who is invoking Spark hence the need to inject these via environment variables as well as by configuration properties. So we set these environment variables appropriately and our Spark read in the contents of those files to authenticate itself with Mesos. This is functionality we have been using it in production across multiple customer sites for some time. This has been in the field for around 18 months with no reported issues. These changes have been sufficient to meet our customer security and audit requirements. We have been building and deploying custom builds of Apache Spark with various minor tweaks like this which we are now looking to contribute back into the community in order that we can rely upon stock Apache Spark builds and stop maintaining our own internal fork. Author: Rob Vesse Closes #20167 from rvesse/SPARK-16501. --- docs/running-on-mesos.md | 40 ++++- .../cluster/mesos/MesosSchedulerUtils.scala | 55 ++++-- .../mesos/MesosSchedulerUtilsSuite.scala | 161 +++++++++++++++++- 3 files changed, 238 insertions(+), 18 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 2bb5ecf1b8509..8e58892e2689f 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -82,6 +82,27 @@ a Spark driver program configured to connect to Mesos. Alternatively, you can also install Spark in the same location in all the Mesos slaves, and configure `spark.mesos.executor.home` (defaults to SPARK_HOME) to point to that location. +## Authenticating to Mesos + +When Mesos Framework authentication is enabled it is necessary to provide a principal and secret by which to authenticate Spark to Mesos. Each Spark job will register with Mesos as a separate framework. + +Depending on your deployment environment you may wish to create a single set of framework credentials that are shared across all users or create framework credentials for each user. Creating and managing framework credentials should be done following the Mesos [Authentication documentation](http://mesos.apache.org/documentation/latest/authentication/). + +Framework credentials may be specified in a variety of ways depending on your deployment environment and security requirements. The most simple way is to specify the `spark.mesos.principal` and `spark.mesos.secret` values directly in your Spark configuration. Alternatively you may specify these values indirectly by instead specifying `spark.mesos.principal.file` and `spark.mesos.secret.file`, these settings point to files containing the principal and secret. These files must be plaintext files in UTF-8 encoding. Combined with appropriate file ownership and mode/ACLs this provides a more secure way to specify these credentials. + +Additionally if you prefer to use environment variables you can specify all of the above via environment variables instead, the environment variable names are simply the configuration settings uppercased with `.` replaced with `_` e.g. `SPARK_MESOS_PRINCIPAL`. + +### Credential Specification Preference Order + +Please note that if you specify multiple ways to obtain the credentials then the following preference order applies. Spark will use the first valid value found and any subsequent values are ignored: + +- `spark.mesos.principal` configuration setting +- `SPARK_MESOS_PRINCIPAL` environment variable +- `spark.mesos.principal.file` configuration setting +- `SPARK_MESOS_PRINCIPAL_FILE` environment variable + +An equivalent order applies for the secret. Essentially we prefer the configuration to be specified directly rather than indirectly by files, and we prefer that configuration settings are used over environment variables. + ## Uploading Spark Package When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary @@ -427,7 +448,14 @@ See the [configuration page](configuration.html) for information on Spark config
+ + + + + @@ -435,7 +463,15 @@ See the [configuration page](configuration.html) for information on Spark config + + + + + diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index e75450369ad85..ecbcc960fc5a0 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler.cluster.mesos +import java.io.File +import java.nio.charset.StandardCharsets import java.util.{List => JList} import java.util.concurrent.CountDownLatch @@ -25,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import com.google.common.base.Splitter +import com.google.common.io.Files import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.FrameworkInfo.Capability @@ -71,26 +74,15 @@ trait MesosSchedulerUtils extends Logging { failoverTimeout: Option[Double] = None, frameworkId: Option[String] = None): SchedulerDriver = { val fwInfoBuilder = FrameworkInfo.newBuilder().setUser(sparkUser).setName(appName) - val credBuilder = Credential.newBuilder() + fwInfoBuilder.setHostname(Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse( + conf.get(DRIVER_HOST_ADDRESS))) webuiUrl.foreach { url => fwInfoBuilder.setWebuiUrl(url) } checkpoint.foreach { checkpoint => fwInfoBuilder.setCheckpoint(checkpoint) } failoverTimeout.foreach { timeout => fwInfoBuilder.setFailoverTimeout(timeout) } frameworkId.foreach { id => fwInfoBuilder.setId(FrameworkID.newBuilder().setValue(id).build()) } - fwInfoBuilder.setHostname(Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse( - conf.get(DRIVER_HOST_ADDRESS))) - conf.getOption("spark.mesos.principal").foreach { principal => - fwInfoBuilder.setPrincipal(principal) - credBuilder.setPrincipal(principal) - } - conf.getOption("spark.mesos.secret").foreach { secret => - credBuilder.setSecret(secret) - } - if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) { - throw new SparkException( - "spark.mesos.principal must be configured when spark.mesos.secret is set") - } + conf.getOption("spark.mesos.role").foreach { role => fwInfoBuilder.setRole(role) } @@ -98,6 +90,7 @@ trait MesosSchedulerUtils extends Logging { if (maxGpus > 0) { fwInfoBuilder.addCapabilities(Capability.newBuilder().setType(Capability.Type.GPU_RESOURCES)) } + val credBuilder = buildCredentials(conf, fwInfoBuilder) if (credBuilder.hasPrincipal) { new MesosSchedulerDriver( scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build()) @@ -106,6 +99,40 @@ trait MesosSchedulerUtils extends Logging { } } + def buildCredentials( + conf: SparkConf, + fwInfoBuilder: Protos.FrameworkInfo.Builder): Protos.Credential.Builder = { + val credBuilder = Credential.newBuilder() + conf.getOption("spark.mesos.principal") + .orElse(Option(conf.getenv("SPARK_MESOS_PRINCIPAL"))) + .orElse( + conf.getOption("spark.mesos.principal.file") + .orElse(Option(conf.getenv("SPARK_MESOS_PRINCIPAL_FILE"))) + .map { principalFile => + Files.toString(new File(principalFile), StandardCharsets.UTF_8) + } + ).foreach { principal => + fwInfoBuilder.setPrincipal(principal) + credBuilder.setPrincipal(principal) + } + conf.getOption("spark.mesos.secret") + .orElse(Option(conf.getenv("SPARK_MESOS_SECRET"))) + .orElse( + conf.getOption("spark.mesos.secret.file") + .orElse(Option(conf.getenv("SPARK_MESOS_SECRET_FILE"))) + .map { secretFile => + Files.toString(new File(secretFile), StandardCharsets.UTF_8) + } + ).foreach { secret => + credBuilder.setSecret(secret) + } + if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) { + throw new SparkException( + "spark.mesos.principal must be configured when spark.mesos.secret is set") + } + credBuilder + } + /** * Starts the MesosSchedulerDriver and stores the current running driver to this new instance. * This driver is expected to not be running. diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala index 7df738958f85c..8d90e1a8591ad 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -17,16 +17,20 @@ package org.apache.spark.scheduler.cluster.mesos +import java.io.{File, FileNotFoundException} + import scala.collection.JavaConverters._ import scala.language.reflectiveCalls -import org.apache.mesos.Protos.{Resource, Value} +import com.google.common.io.Files +import org.apache.mesos.Protos.{FrameworkInfo, Resource, Value} import org.mockito.Mockito._ import org.scalatest._ import org.scalatest.mockito.MockitoSugar -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.internal.config._ +import org.apache.spark.util.SparkConfWithEnv class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoSugar { @@ -237,4 +241,157 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS val portsToUse = getRangesFromResources(resourcesToBeUsed).map{r => r._1} portsToUse.isEmpty shouldBe true } + + test("Principal specified via spark.mesos.principal") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via spark.mesos.principal.file") { + val pFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + pFile.deleteOnExit() + Files.write("test-principal".getBytes("UTF-8"), pFile); + val conf = new SparkConf() + conf.set("spark.mesos.principal.file", pFile.getAbsolutePath()) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via spark.mesos.principal.file that does not exist") { + val conf = new SparkConf() + conf.set("spark.mesos.principal.file", "/tmp/does-not-exist") + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL" -> "test-principal")) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL_FILE") { + val pFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + pFile.deleteOnExit() + Files.write("test-principal".getBytes("UTF-8"), pFile); + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL_FILE" -> pFile.getAbsolutePath())) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL_FILE that does not exist") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL_FILE" -> "/tmp/does-not-exist")) + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Secret specified via spark.mesos.secret") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret", "my-secret") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via spark.mesos.secret.file") { + val sFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + sFile.deleteOnExit() + Files.write("my-secret".getBytes("UTF-8"), sFile); + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret.file", sFile.getAbsolutePath()) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via spark.mesos.secret.file that does not exist") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret.file", "/tmp/does-not-exist") + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specified via SPARK_MESOS_SECRET") { + val env = Map("SPARK_MESOS_SECRET" -> "my-secret") + val conf = new SparkConfWithEnv(env) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via SPARK_MESOS_SECRET_FILE") { + val sFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + sFile.deleteOnExit() + Files.write("my-secret".getBytes("UTF-8"), sFile); + + val sFilePath = sFile.getAbsolutePath() + val env = Map("SPARK_MESOS_SECRET_FILE" -> sFilePath) + val conf = new SparkConfWithEnv(env) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Secret specified with no principal") { + val conf = new SparkConf() + conf.set("spark.mesos.secret", "my-secret") + + intercept[SparkException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specification preference") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL" -> "other-principal")) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Secret specification preference") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_SECRET" -> "other-secret")) + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret", "my-secret") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } } From 557938e2839afce26a10a849a2a4be8fc4580427 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 9 Feb 2018 18:18:30 -0600 Subject: [PATCH 315/774] [MINOR][HIVE] Typo fixes ## What changes were proposed in this pull request? Typo fixes (with expanding a Hive property) ## How was this patch tested? local build. Awaiting Jenkins Author: Jacek Laskowski Closes #20550 from jaceklaskowski/hiveutils-typos. --- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 93f3f38e52aa9..c448c5a9821be 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -304,7 +304,7 @@ private[spark] object HiveUtils extends Logging { throw new IllegalArgumentException( "Builtin jars can only be used when hive execution version == hive metastore version. " + s"Execution: $builtinHiveVersion != Metastore: $hiveMetastoreVersion. " + - "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + + s"Specify a valid path to the correct hive jars using ${HIVE_METASTORE_JARS.key} " + s"or change ${HIVE_METASTORE_VERSION.key} to $builtinHiveVersion.") } @@ -324,7 +324,7 @@ private[spark] object HiveUtils extends Logging { if (jars.length == 0) { throw new IllegalArgumentException( "Unable to locate hive jars to connect to metastore. " + - "Please set spark.sql.hive.metastore.jars.") + s"Please set ${HIVE_METASTORE_JARS.key}.") } logInfo( From 6d7c38330e68c7beb10f54eee8b4f607ee3c4136 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Fri, 9 Feb 2018 16:21:47 -0800 Subject: [PATCH 316/774] [SPARK-23275][SQL] fix the thread leaking in hive/tests ## What changes were proposed in this pull request? This is a follow up of https://github.com/apache/spark/pull/20441. The two lines actually can trigger the hive metastore bug: https://issues.apache.org/jira/browse/HIVE-16844 The two configs are not in the default `ObjectStore` properties, so any run hive commands after these two lines will set the `propsChanged` flag in the `ObjectStore.setConf` and then cause thread leaks. I don't think the two lines are very useful. They can be removed safely. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20562 from liufengdb/fix-omm. --- .../main/scala/org/apache/spark/sql/hive/test/TestHive.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 59708e7a0f2ff..19028939f3673 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -530,8 +530,6 @@ private[hive] class TestHiveSparkSession( // For some reason, RESET does not reset the following variables... // https://issues.apache.org/jira/browse/HIVE-9004 metadataHive.runSqlHive("set hive.table.parameters.default=") - metadataHive.runSqlHive("set datanucleus.cache.collections=true") - metadataHive.runSqlHive("set datanucleus.cache.collections.lazy=true") // Lots of tests fail if we do not change the partition whitelist from the default. metadataHive.runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") From 97a224a855c4410b2dfb9c0bcc6aae583bd28e92 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sun, 11 Feb 2018 01:08:02 +0900 Subject: [PATCH 317/774] [SPARK-23360][SQL][PYTHON] Get local timezone from environment via pytz, or dateutil. ## What changes were proposed in this pull request? Currently we use `tzlocal()` to get Python local timezone, but it sometimes causes unexpected behavior. I changed the way to get Python local timezone to use pytz if the timezone is specified in environment variable, or timezone file via dateutil . ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20559 from ueshin/issues/SPARK-23360/master. --- python/pyspark/sql/tests.py | 28 ++++++++++++++++++++++++++++ python/pyspark/sql/types.py | 23 +++++++++++++++++++---- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6ace16955000d..1087c3fafdd16 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2868,6 +2868,34 @@ def test_create_dataframe_required_pandas_not_found(self): "d": [pd.Timestamp.now().date()]}) self.spark.createDataFrame(pdf) + # Regression test for SPARK-23360 + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) + def test_create_dateframe_from_pandas_with_dst(self): + import pandas as pd + from datetime import datetime + + pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]}) + + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) + + orig_env_tz = os.environ.get('TZ', None) + orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone') + try: + tz = 'America/Los_Angeles' + os.environ['TZ'] = tz + time.tzset() + self.spark.conf.set('spark.sql.session.timeZone', tz) + + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) + finally: + del os.environ['TZ'] + if orig_env_tz is not None: + os.environ['TZ'] = orig_env_tz + time.tzset() + self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 093dae5a22e1f..2599dc5fdc599 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1709,6 +1709,21 @@ def _check_dataframe_convert_date(pdf, schema): return pdf +def _get_local_timezone(): + """ Get local timezone using pytz with environment variable, or dateutil. + + If there is a 'TZ' environment variable, pass it to pandas to use pytz and use it as timezone + string, otherwise use the special word 'dateutil/:' which means that pandas uses dateutil and + it reads system configuration to know the system local timezone. + + See also: + - https://github.com/pandas-dev/pandas/blob/0.19.x/pandas/tslib.pyx#L1753 + - https://github.com/dateutil/dateutil/blob/2.6.1/dateutil/tz/tz.py#L1338 + """ + import os + return os.environ.get('TZ', 'dateutil/:') + + def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone @@ -1721,7 +1736,7 @@ def _check_dataframe_localize_timestamps(pdf, timezone): require_minimum_pandas_version() from pandas.api.types import is_datetime64tz_dtype - tz = timezone or 'tzlocal()' + tz = timezone or _get_local_timezone() for column, series in pdf.iteritems(): # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64tz_dtype(series.dtype): @@ -1744,7 +1759,7 @@ def _check_series_convert_timestamps_internal(s, timezone): from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64_dtype(s.dtype): - tz = timezone or 'tzlocal()' + tz = timezone or _get_local_timezone() return s.dt.tz_localize(tz).dt.tz_convert('UTC') elif is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert('UTC') @@ -1766,8 +1781,8 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): import pandas as pd from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype - from_tz = from_timezone or 'tzlocal()' - to_tz = to_timezone or 'tzlocal()' + from_tz = from_timezone or _get_local_timezone() + to_tz = to_timezone or _get_local_timezone() # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert(to_tz).dt.tz_localize(None) From 0783876c81f212e1422a1b7786c26e3ac8e84f9f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 10 Feb 2018 10:46:45 -0600 Subject: [PATCH 318/774] [SPARK-23344][PYTHON][ML] Add distanceMeasure param to KMeans ## What changes were proposed in this pull request? SPARK-22119 introduced a new parameter for KMeans, ie. `distanceMeasure`. The PR adds it also to the Python interface. ## How was this patch tested? added UTs Author: Marco Gaido Closes #20520 from mgaido91/SPARK-23344. --- python/pyspark/ml/clustering.py | 32 +++++++++++++++++++++++++++----- python/pyspark/ml/tests.py | 18 ++++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 66fb00508522e..6448b76a0da88 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -403,17 +403,23 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol typeConverter=TypeConverters.toString) initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " + "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt) + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " + + "Supported options: 'euclidean' and 'cosine'.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None): + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, + distanceMeasure="euclidean"): """ __init__(self, featuresCol="features", predictionCol="prediction", k=2, \ - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None) + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \ + distanceMeasure="euclidean") """ super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) - self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20) + self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, + distanceMeasure="euclidean") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -423,10 +429,12 @@ def _create_model(self, java_model): @keyword_only @since("1.5.0") def setParams(self, featuresCol="features", predictionCol="prediction", k=2, - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None): + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, + distanceMeasure="euclidean"): """ setParams(self, featuresCol="features", predictionCol="prediction", k=2, \ - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None) + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \ + distanceMeasure="euclidean") Sets params for KMeans. """ @@ -475,6 +483,20 @@ def getInitSteps(self): """ return self.getOrDefault(self.initSteps) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` + """ + return self.getOrDefault(self.distanceMeasure) + class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 75d04785a0710..6d6737241e06e 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -418,6 +418,9 @@ def test_kmeans_param(self): self.assertEqual(algo.getK(), 10) algo.setInitSteps(10) self.assertEqual(algo.getInitSteps(), 10) + self.assertEqual(algo.getDistanceMeasure(), "euclidean") + algo.setDistanceMeasure("cosine") + self.assertEqual(algo.getDistanceMeasure(), "cosine") def test_hasseed(self): noSeedSpecd = TestParams() @@ -1620,6 +1623,21 @@ def test_kmeans_summary(self): self.assertEqual(s.k, 2) +class KMeansTests(SparkSessionTestCase): + + def test_kmeans_cosine_distance(self): + data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),), + (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),), + (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)] + df = self.spark.createDataFrame(data, ["features"]) + kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine") + model = kmeans.fit(df) + result = model.transform(df).collect() + self.assertTrue(result[0].prediction == result[1].prediction) + self.assertTrue(result[2].prediction == result[3].prediction) + self.assertTrue(result[4].prediction == result[5].prediction) + + class OneVsRestTests(SparkSessionTestCase): def test_copy(self): From a34fce19bc0ee5a7e36c6ecba75d2aeb70fdcbc7 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sun, 11 Feb 2018 17:31:35 +0900 Subject: [PATCH 319/774] [SPARK-23314][PYTHON] Add ambiguous=False when localizing tz-naive timestamps in Arrow codepath to deal with dst ## What changes were proposed in this pull request? When tz_localize a tz-naive timetamp, pandas will throw exception if the timestamp is during daylight saving time period, e.g., `2015-11-01 01:30:00`. This PR fixes this issue by setting `ambiguous=False` when calling tz_localize, which is the same default behavior of pytz. ## How was this patch tested? Add `test_timestamp_dst` Author: Li Jin Closes #20537 from icexelloss/SPARK-23314. --- python/pyspark/sql/tests.py | 39 +++++++++++++++++++++++++++++++++++++ python/pyspark/sql/types.py | 37 ++++++++++++++++++++++++++++++++--- 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1087c3fafdd16..4bc59fd99fca5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3670,6 +3670,21 @@ def test_createDataFrame_with_int_col_names(self): self.assertEqual(pdf_col_names, df.columns) self.assertEqual(pdf_col_names, df_arrow.columns) + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + import pandas as pd + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + pdf = pd.DataFrame({'time': dt}) + + df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + df_from_pandas = self.spark.createDataFrame(pdf) + + self.assertPandasEqual(pdf, df_from_python.toPandas()) + self.assertPandasEqual(pdf, df_from_pandas.toPandas()) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -4311,6 +4326,18 @@ def test_register_vectorized_udf_basic(self): self.assertEquals(expected.collect(), res1.collect()) self.assertEquals(expected.collect(), res2.collect()) + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + from pyspark.sql.functions import pandas_udf + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + foo_udf = pandas_udf(lambda x: x, 'timestamp') + result = df.withColumn('time', foo_udf(df.time)) + self.assertEquals(df.collect(), result.collect()) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -4482,6 +4509,18 @@ def test_unsupported_types(self): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.groupby('id').apply(f).collect() + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP) + result = df.groupby('time').apply(foo_udf).sort('time') + self.assertPandasEqual(df.toPandas(), result.toPandas()) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 2599dc5fdc599..f7141b4549e4e 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1759,8 +1759,38 @@ def _check_series_convert_timestamps_internal(s, timezone): from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64_dtype(s.dtype): + # When tz_localize a tz-naive timestamp, the result is ambiguous if the tz-naive + # timestamp is during the hour when the clock is adjusted backward during due to + # daylight saving time (dst). + # E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to + # 2015-11-01 1:00 from dst-time to standard time, and therefore, when tz_localize + # a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either + # dst time (2015-01-01 1:30-0400) or standard time (2015-11-01 1:30-0500). + # + # Here we explicit choose to use standard time. This matches the default behavior of + # pytz. + # + # Here are some code to help understand this behavior: + # >>> import datetime + # >>> import pandas as pd + # >>> import pytz + # >>> + # >>> t = datetime.datetime(2015, 11, 1, 1, 30) + # >>> ts = pd.Series([t]) + # >>> tz = pytz.timezone('America/New_York') + # >>> + # >>> ts.dt.tz_localize(tz, ambiguous=True) + # 0 2015-11-01 01:30:00-04:00 + # dtype: datetime64[ns, America/New_York] + # >>> + # >>> ts.dt.tz_localize(tz, ambiguous=False) + # 0 2015-11-01 01:30:00-05:00 + # dtype: datetime64[ns, America/New_York] + # >>> + # >>> str(tz.localize(t)) + # '2015-11-01 01:30:00-05:00' tz = timezone or _get_local_timezone() - return s.dt.tz_localize(tz).dt.tz_convert('UTC') + return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC') elif is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert('UTC') else: @@ -1788,8 +1818,9 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): return s.dt.tz_convert(to_tz).dt.tz_localize(None) elif is_datetime64_dtype(s.dtype) and from_tz != to_tz: # `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT. - return s.apply(lambda ts: ts.tz_localize(from_tz).tz_convert(to_tz).tz_localize(None) - if ts is not pd.NaT else pd.NaT) + return s.apply( + lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None) + if ts is not pd.NaT else pd.NaT) else: return s From 8acb51f08b448628b65e90af3b268994f9550e45 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 11 Feb 2018 18:55:38 +0900 Subject: [PATCH 320/774] [SPARK-23084][PYTHON] Add unboundedPreceding(), unboundedFollowing() and currentRow() to PySpark ## What changes were proposed in this pull request? Added unboundedPreceding(), unboundedFollowing() and currentRow() to PySpark, also updated the rangeBetween API ## How was this patch tested? did unit test on my local. Please let me know if I need to add unit test in tests.py Author: Huaxin Gao Closes #20400 from huaxingao/spark_23084. --- python/pyspark/sql/functions.py | 30 ++++++++++++++ python/pyspark/sql/window.py | 70 ++++++++++++++++++++++++--------- 2 files changed, 82 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 05031f5ec87d7..9bb9c323a5a60 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -809,6 +809,36 @@ def ntile(n): return Column(sc._jvm.functions.ntile(int(n))) +@since(2.4) +def unboundedPreceding(): + """ + Window function: returns the special frame boundary that represents the first row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.unboundedPreceding()) + + +@since(2.4) +def unboundedFollowing(): + """ + Window function: returns the special frame boundary that represents the last row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.unboundedFollowing()) + + +@since(2.4) +def currentRow(): + """ + Window function: returns the special frame boundary that represents the current row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.currentRow()) + + # ---------------------- Date/Timestamp functions ------------------------------ @since(1.5) diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index 7ce27f9b102c0..bb841a9b9ff7c 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -16,9 +16,11 @@ # import sys +if sys.version >= '3': + long = int from pyspark import since, SparkContext -from pyspark.sql.column import _to_seq, _to_java_column +from pyspark.sql.column import Column, _to_seq, _to_java_column __all__ = ["Window", "WindowSpec"] @@ -120,20 +122,45 @@ def rangeBetween(start, end): and "5" means the five off after the current row. We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, - and ``Window.currentRow`` to specify special boundary values, rather than using integral - values directly. + ``Window.currentRow``, ``pyspark.sql.functions.unboundedPreceding``, + ``pyspark.sql.functions.unboundedFollowing`` and ``pyspark.sql.functions.currentRow`` + to specify special boundary values, rather than using integral values directly. :param start: boundary start, inclusive. - The frame is unbounded if this is ``Window.unboundedPreceding``, or + The frame is unbounded if this is ``Window.unboundedPreceding``, + a column returned by ``pyspark.sql.functions.unboundedPreceding``, or any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. - The frame is unbounded if this is ``Window.unboundedFollowing``, or + The frame is unbounded if this is ``Window.unboundedFollowing``, + a column returned by ``pyspark.sql.functions.unboundedFollowing``, or any value greater than or equal to min(sys.maxsize, 9223372036854775807). + + >>> from pyspark.sql import functions as F, SparkSession, Window + >>> spark = SparkSession.builder.getOrCreate() + >>> df = spark.createDataFrame( + ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"]) + >>> window = Window.orderBy("id").partitionBy("category").rangeBetween( + ... F.currentRow(), F.lit(1)) + >>> df.withColumn("sum", F.sum("id").over(window)).show() + +---+--------+---+ + | id|category|sum| + +---+--------+---+ + | 1| b| 3| + | 2| b| 5| + | 3| b| 3| + | 1| a| 4| + | 1| a| 4| + | 2| a| 2| + +---+--------+---+ """ - if start <= Window._PRECEDING_THRESHOLD: - start = Window.unboundedPreceding - if end >= Window._FOLLOWING_THRESHOLD: - end = Window.unboundedFollowing + if isinstance(start, (int, long)) and isinstance(end, (int, long)): + if start <= Window._PRECEDING_THRESHOLD: + start = Window.unboundedPreceding + if end >= Window._FOLLOWING_THRESHOLD: + end = Window.unboundedFollowing + elif isinstance(start, Column) and isinstance(end, Column): + start = start._jc + end = end._jc sc = SparkContext._active_spark_context jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end) return WindowSpec(jspec) @@ -208,27 +235,34 @@ def rangeBetween(self, start, end): and "5" means the five off after the current row. We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, - and ``Window.currentRow`` to specify special boundary values, rather than using integral - values directly. + ``Window.currentRow``, ``pyspark.sql.functions.unboundedPreceding``, + ``pyspark.sql.functions.unboundedFollowing`` and ``pyspark.sql.functions.currentRow`` + to specify special boundary values, rather than using integral values directly. :param start: boundary start, inclusive. - The frame is unbounded if this is ``Window.unboundedPreceding``, or + The frame is unbounded if this is ``Window.unboundedPreceding``, + a column returned by ``pyspark.sql.functions.unboundedPreceding``, or any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. - The frame is unbounded if this is ``Window.unboundedFollowing``, or + The frame is unbounded if this is ``Window.unboundedFollowing``, + a column returned by ``pyspark.sql.functions.unboundedFollowing``, or any value greater than or equal to min(sys.maxsize, 9223372036854775807). """ - if start <= Window._PRECEDING_THRESHOLD: - start = Window.unboundedPreceding - if end >= Window._FOLLOWING_THRESHOLD: - end = Window.unboundedFollowing + if isinstance(start, (int, long)) and isinstance(end, (int, long)): + if start <= Window._PRECEDING_THRESHOLD: + start = Window.unboundedPreceding + if end >= Window._FOLLOWING_THRESHOLD: + end = Window.unboundedFollowing + elif isinstance(start, Column) and isinstance(end, Column): + start = start._jc + end = end._jc return WindowSpec(self._jspec.rangeBetween(start, end)) def _test(): import doctest SparkContext('local[4]', 'PythonTest') - (failure_count, test_count) = doctest.testmod() + (failure_count, test_count) = doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) if failure_count: exit(-1) From eacb62fbbed317fd0e972102838af231385d54d8 Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Sun, 11 Feb 2018 19:23:15 +0900 Subject: [PATCH 321/774] [SPARK-22624][PYSPARK] Expose range partitioning shuffle introduced by spark-22614 ## What changes were proposed in this pull request? Expose range partitioning shuffle introduced by spark-22614 ## How was this patch tested? Unit test in dataframe.py Please review http://spark.apache.org/contributing.html before opening a pull request. Author: xubo245 <601450868@qq.com> Closes #20456 from xubo245/SPARK22624_PysparkRangePartition. --- python/pyspark/sql/dataframe.py | 45 +++++++++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 28 ++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index faee870a2d2e2..5cc8b63cdfadf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -667,6 +667,51 @@ def repartition(self, numPartitions, *cols): else: raise TypeError("numPartitions should be an int or Column") + @since("2.4.0") + def repartitionByRange(self, numPartitions, *cols): + """ + Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The + resulting DataFrame is range partitioned. + + ``numPartitions`` can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. + + At least one partition-by expression must be specified. + When no explicit sort order is specified, "ascending nulls first" is assumed. + + >>> df.repartitionByRange(2, "age").rdd.getNumPartitions() + 2 + >>> df.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + >>> df.repartitionByRange(1, "age").rdd.getNumPartitions() + 1 + >>> data = df.repartitionByRange("age") + >>> df.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + """ + if isinstance(numPartitions, int): + if len(cols) == 0: + return ValueError("At least one partition-by expression must be specified.") + else: + return DataFrame( + self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sql_ctx) + elif isinstance(numPartitions, (basestring, Column)): + cols = (numPartitions,) + cols + return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sql_ctx) + else: + raise TypeError("numPartitions should be an int, string or Column") + @since(1.3) def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4bc59fd99fca5..fe89bd0685027 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2148,6 +2148,34 @@ def test_expr(self): result = df.select(functions.expr("length(a)")).collect()[0].asDict() self.assertEqual(13, result["length(a)"]) + def test_repartitionByRange_dataframe(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + df1 = self.spark.createDataFrame( + [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema) + df2 = self.spark.createDataFrame( + [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema) + + # test repartitionByRange(numPartitions, *cols) + df3 = df1.repartitionByRange(2, "name", "age") + self.assertEqual(df3.rdd.getNumPartitions(), 2) + self.assertEqual(df3.rdd.first(), df2.rdd.first()) + self.assertEqual(df3.rdd.take(3), df2.rdd.take(3)) + + # test repartitionByRange(numPartitions, *cols) + df4 = df1.repartitionByRange(3, "name", "age") + self.assertEqual(df4.rdd.getNumPartitions(), 3) + self.assertEqual(df4.rdd.first(), df2.rdd.first()) + self.assertEqual(df4.rdd.take(3), df2.rdd.take(3)) + + # test repartitionByRange(*cols) + df5 = df1.repartitionByRange("name", "age") + self.assertEqual(df5.rdd.first(), df2.rdd.first()) + self.assertEqual(df5.rdd.take(3), df2.rdd.take(3)) + def test_replace(self): schema = StructType([ StructField("name", StringType(), True), From 4bbd7443ebb005f81ed6bc39849940ac8db3b3cc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 12 Feb 2018 00:03:49 +0800 Subject: [PATCH 322/774] [SPARK-23376][SQL] creating UnsafeKVExternalSorter with BytesToBytesMap may fail ## What changes were proposed in this pull request? This is a long-standing bug in `UnsafeKVExternalSorter` and was reported in the dev list multiple times. When creating `UnsafeKVExternalSorter` with `BytesToBytesMap`, we need to create a `UnsafeInMemorySorter` to sort the data in `BytesToBytesMap`. The data format of the sorter and the map is same, so no data movement is required. However, both the sorter and the map need a point array for some bookkeeping work. There is an optimization in `UnsafeKVExternalSorter`: reuse the point array between the sorter and the map, to avoid an extra memory allocation. This sounds like a reasonable optimization, the length of the `BytesToBytesMap` point array is at least 4 times larger than the number of keys(to avoid hash collision, the hash table size should be at least 2 times larger than the number of keys, and each key occupies 2 slots). `UnsafeInMemorySorter` needs the pointer array size to be 4 times of the number of entries, so we are safe to reuse the point array. However, the number of keys of the map doesn't equal to the number of entries in the map, because `BytesToBytesMap` supports duplicated keys. This breaks the assumption of the above optimization and we may run out of space when inserting data into the sorter, and hit error ``` java.lang.IllegalStateException: There is no space for new record at org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.insertRecord(UnsafeInMemorySorter.java:239) at org.apache.spark.sql.execution.UnsafeKVExternalSorter.(UnsafeKVExternalSorter.java:149) ... ``` This PR fixes this bug by creating a new point array if the existing one is not big enough. ## How was this patch tested? a new test Author: Wenchen Fan Closes #20561 from cloud-fan/bug. --- .../sql/execution/UnsafeKVExternalSorter.java | 31 +++++++++++---- .../UnsafeKVExternalSorterSuite.scala | 39 +++++++++++++++++++ 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index b0b5383a081a0..9eb03430a7db2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -34,6 +34,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.unsafe.sort.*; @@ -98,19 +99,33 @@ public UnsafeKVExternalSorter( numElementsForSpillThreshold, canUseRadixSort); } else { - // The array will be used to do in-place sort, which require half of the space to be empty. - // Note: each record in the map takes two entries in the array, one is record pointer, - // another is the key prefix. - assert(map.numKeys() * 2 <= map.getArray().size() / 2); - // During spilling, the array in map will not be used, so we can borrow that and use it - // as the underlying array for in-memory sorter (it's always large enough). - // Since we will not grow the array, it's fine to pass `null` as consumer. + // During spilling, the pointer array in `BytesToBytesMap` will not be used, so we can borrow + // that and use it as the pointer array for `UnsafeInMemorySorter`. + LongArray pointerArray = map.getArray(); + // `BytesToBytesMap`'s pointer array is only guaranteed to hold all the distinct keys, but + // `UnsafeInMemorySorter`'s pointer array need to hold all the entries. Since + // `BytesToBytesMap` can have duplicated keys, here we need a check to make sure the pointer + // array can hold all the entries in `BytesToBytesMap`. + // The pointer array will be used to do in-place sort, which requires half of the space to be + // empty. Note: each record in the map takes two entries in the pointer array, one is record + // pointer, another is key prefix. So the required size of pointer array is `numRecords * 4`. + // TODO: It's possible to change UnsafeInMemorySorter to have multiple entries with same key, + // so that we can always reuse the pointer array. + if (map.numValues() > pointerArray.size() / 4) { + // Here we ask the map to allocate memory, so that the memory manager won't ask the map + // to spill, if the memory is not enough. + pointerArray = map.allocateArray(map.numValues() * 4L); + } + + // Since the pointer array(either reuse the one in the map, or create a new one) is guaranteed + // to be large enough, it's fine to pass `null` as consumer because we won't allocate more + // memory. final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( null, taskMemoryManager, comparatorSupplier.get(), prefixComparator, - map.getArray(), + pointerArray, canUseRadixSort); // We cannot use the destructive iterator here because we are reusing the existing memory diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 6af9f8b77f8d3..bf588d3bb7841 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.map.BytesToBytesMap /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. @@ -205,4 +206,42 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { spill = true ) } + + test("SPARK-23376: Create UnsafeKVExternalSorter with BytesToByteMap having duplicated keys") { + val memoryManager = new TestMemoryManager(new SparkConf()) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + val map = new BytesToBytesMap(taskMemoryManager, 64, taskMemoryManager.pageSizeBytes()) + + // Key/value are a unsafe rows with a single int column + val schema = new StructType().add("i", IntegerType) + val key = new UnsafeRow(1) + key.pointTo(new Array[Byte](32), 32) + key.setInt(0, 1) + val value = new UnsafeRow(1) + value.pointTo(new Array[Byte](32), 32) + value.setInt(0, 2) + + for (_ <- 1 to 65) { + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + loc.append( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) + } + + // Make sure we can successfully create a UnsafeKVExternalSorter with a `BytesToBytesMap` + // which has duplicated keys and the number of entries exceeds its capacity. + try { + TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, null, null)) + new UnsafeKVExternalSorter( + schema, + schema, + sparkContext.env.blockManager, + sparkContext.env.serializerManager, + taskMemoryManager.pageSizeBytes(), + Int.MaxValue, + map) + } finally { + TaskContext.unset() + } + } } From c0c902aedcf9ed24e482d873d766a7df63b964cb Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 11 Feb 2018 20:15:30 -0600 Subject: [PATCH 323/774] [SPARK-22119][FOLLOWUP][ML] Use spherical KMeans with cosine distance ## What changes were proposed in this pull request? In #19340 some comments considered needed to use spherical KMeans when cosine distance measure is specified, as Matlab does; instead of the implementation based on the behavior of other tools/libraries like Rapidminer, nltk and ELKI, ie. the centroids are computed as the mean of all the points in the clusters. The PR introduce the approach used in spherical KMeans. This behavior has the nice feature to minimize the within-cluster cosine distance. ## How was this patch tested? existing/improved UTs Author: Marco Gaido Closes #20518 from mgaido91/SPARK-22119_followup. --- .../spark/mllib/clustering/KMeans.scala | 54 ++++++++++++++++--- .../spark/ml/clustering/KMeansSuite.scala | 15 +++++- 2 files changed, 62 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 607145cb59fba..3c4ba0bc60c7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -310,8 +310,7 @@ class KMeans private ( points.foreach { point => val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point) costAccum.add(cost) - val sum = sums(bestCenter) - axpy(1.0, point.vector, sum) + distanceMeasureInstance.updateClusterSum(point, sums(bestCenter)) counts(bestCenter) += 1 } @@ -319,10 +318,9 @@ class KMeans private ( }.reduceByKey { case ((sum1, count1), (sum2, count2)) => axpy(1.0, sum2, sum1) (sum1, count1 + count2) - }.mapValues { case (sum, count) => - scal(1.0 / count, sum) - new VectorWithNorm(sum) - }.collectAsMap() + }.collectAsMap().mapValues { case (sum, count) => + distanceMeasureInstance.centroid(sum, count) + } bcCenters.destroy(blocking = false) @@ -657,6 +655,26 @@ private[spark] abstract class DistanceMeasure extends Serializable { v1: VectorWithNorm, v2: VectorWithNorm): Double + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + axpy(1.0, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + new VectorWithNorm(sum) + } } @Since("2.4.0") @@ -743,6 +761,30 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure { * @return the cosine distance between the two input vectors */ override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.") 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm } + + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + axpy(1.0 / point.norm, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + override def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + val norm = Vectors.norm(sum, 2) + scal(1.0 / norm, sum) + new VectorWithNorm(sum, 1) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index e4506f23feb31..32830b39407ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.clustering import scala.util.Random -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -179,6 +179,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(predictionsMap(Vectors.dense(-1.0, 1.0)) == predictionsMap(Vectors.dense(-100.0, 90.0))) + model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) + } + + test("KMeans with cosine distance is not supported for 0-length vectors") { + val model = new KMeans().setDistanceMeasure(DistanceMeasure.COSINE).setK(2) + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(0.0, 0.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5) + )).map(v => TestRow(v))) + val e = intercept[SparkException](model.fit(df)) + assert(e.getCause.isInstanceOf[AssertionError]) + assert(e.getCause.getMessage.contains("Cosine distance is not defined")) } test("read/write") { From 6efd5d117e98074d1b16a5c991fbd38df9aa196e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 11 Feb 2018 23:46:23 -0800 Subject: [PATCH 324/774] [SPARK-23390][SQL] Flaky Test Suite: FileBasedDataSourceSuite in Spark 2.3/hadoop 2.7 ## What changes were proposed in this pull request? This test only fails with sbt on Hadoop 2.7, I can't reproduce it locally, but here is my speculation by looking at the code: 1. FileSystem.delete doesn't delete the directory entirely, somehow we can still open the file as a 0-length empty file.(just speculation) 2. ORC intentionally allow empty files, and the reader fails during reading without closing the file stream. This PR improves the test to make sure all files are deleted and can't be opened. ## How was this patch tested? N/A Author: Wenchen Fan Closes #20584 from cloud-fan/flaky-test. --- .../spark/sql/FileBasedDataSourceSuite.scala | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 640d6b1583663..2e332362ea644 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.io.FileNotFoundException + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -102,17 +104,27 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { def testIgnoreMissingFiles(): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath + Seq("0").toDF("a").write.format(format).save(new Path(basePath, "first").toString) Seq("1").toDF("a").write.format(format).save(new Path(basePath, "second").toString) + val thirdPath = new Path(basePath, "third") + val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) Seq("2").toDF("a").write.format(format).save(thirdPath.toString) + val files = fs.listStatus(thirdPath).filter(_.isFile).map(_.getPath) + val df = spark.read.format(format).load( new Path(basePath, "first").toString, new Path(basePath, "second").toString, new Path(basePath, "third").toString) - val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) + // Make sure all data files are deleted and can't be opened. + files.foreach(f => fs.delete(f, false)) assert(fs.delete(thirdPath, true)) + for (f <- files) { + intercept[FileNotFoundException](fs.open(f)) + } + checkAnswer(df, Seq(Row("0"), Row("1"))) } } From c338c8cf8253c037ecd4f39bbd58ed5a86581b37 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 12 Feb 2018 20:49:36 +0900 Subject: [PATCH 325/774] [SPARK-23352][PYTHON] Explicitly specify supported types in Pandas UDFs ## What changes were proposed in this pull request? This PR targets to explicitly specify supported types in Pandas UDFs. The main change here is to add a deduplicated and explicit type checking in `returnType` ahead with documenting this; however, it happened to fix multiple things. 1. Currently, we don't support `BinaryType` in Pandas UDFs, for example, see: ```python from pyspark.sql.functions import pandas_udf pudf = pandas_udf(lambda x: x, "binary") df = spark.createDataFrame([[bytearray(1)]]) df.select(pudf("_1")).show() ``` ``` ... TypeError: Unsupported type in conversion to Arrow: BinaryType ``` We can document this behaviour for its guide. 2. Also, the grouped aggregate Pandas UDF fails fast on `ArrayType` but seems we can support this case. ```python from pyspark.sql.functions import pandas_udf, PandasUDFType foo = pandas_udf(lambda v: v.mean(), 'array', PandasUDFType.GROUPED_AGG) df = spark.range(100).selectExpr("id", "array(id) as value") df.groupBy("id").agg(foo("value")).show() ``` ``` ... NotImplementedError: ArrayType, StructType and MapType are not supported with PandasUDFType.GROUPED_AGG ``` 3. Since we can check the return type ahead, we can fail fast before actual execution. ```python # we can fail fast at this stage because we know the schema ahead pandas_udf(lambda x: x, BinaryType()) ``` ## How was this patch tested? Manually tested and unit tests for `BinaryType` and `ArrayType(...)` were added. Author: hyukjinkwon Closes #20531 from HyukjinKwon/pudf-cleanup. --- docs/sql-programming-guide.md | 4 +- python/pyspark/sql/tests.py | 130 +++++++++++------- python/pyspark/sql/types.py | 4 + python/pyspark/sql/udf.py | 36 +++-- python/pyspark/worker.py | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 2 +- 6 files changed, 111 insertions(+), 67 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index eab4030ee25d2..6174a93b68492 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1676,7 +1676,7 @@ Using the above optimizations with Arrow will produce the same results as when A enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the DataFrame to the driver program and should be done on a small subset of the data. Not all Spark data types are currently supported and an error can be raised if a column has an unsupported type, -see [Supported Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, +see [Supported SQL Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, Spark will fall back to create the DataFrame without Arrow. ## Pandas UDFs (a.k.a. Vectorized UDFs) @@ -1734,7 +1734,7 @@ For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/p ### Supported SQL Types -Currently, all Spark SQL data types are supported by Arrow-based conversion except `MapType`, +Currently, all Spark SQL data types are supported by Arrow-based conversion except `BinaryType`, `MapType`, `ArrayType` of `TimestampType`, and nested `StructType`. ### Setting Arrow Batch Size diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fe89bd0685027..2af218a691026 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3790,10 +3790,10 @@ def foo(x): self.assertEqual(foo.returnType, schema) self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - @pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR) + @pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR) def foo(x): return x - self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.returnType, DoubleType()) self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP) @@ -3830,7 +3830,7 @@ def zero_with_type(): @pandas_udf(returnType=PandasUDFType.GROUPED_MAP) def foo(df): return df - with self.assertRaisesRegexp(ValueError, 'Invalid returnType'): + with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP) def foo(df): return df @@ -3879,7 +3879,7 @@ def random_udf(v): return random_udf def test_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, array df = self.spark.range(10).select( col('id').cast('string').alias('str'), col('id').cast('int').alias('int'), @@ -3887,7 +3887,8 @@ def test_vectorized_udf_basic(self): col('id').cast('float').alias('float'), col('id').cast('double').alias('double'), col('id').cast('decimal').alias('decimal'), - col('id').cast('boolean').alias('bool')) + col('id').cast('boolean').alias('bool'), + array(col('id')).alias('array_long')) f = lambda x: x str_f = pandas_udf(f, StringType()) int_f = pandas_udf(f, IntegerType()) @@ -3896,10 +3897,11 @@ def test_vectorized_udf_basic(self): double_f = pandas_udf(f, DoubleType()) decimal_f = pandas_udf(f, DecimalType()) bool_f = pandas_udf(f, BooleanType()) + array_long_f = pandas_udf(f, ArrayType(LongType())) res = df.select(str_f(col('str')), int_f(col('int')), long_f(col('long')), float_f(col('float')), double_f(col('double')), decimal_f('decimal'), - bool_f(col('bool'))) + bool_f(col('bool')), array_long_f('array_long')) self.assertEquals(df.collect(), res.collect()) def test_register_nondeterministic_vectorized_udf_basic(self): @@ -4104,10 +4106,11 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) - f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): - df.select(f(col('id'))).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*MapType'): + pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) def test_vectorized_udf_return_scalar(self): from pyspark.sql.functions import pandas_udf, col @@ -4142,13 +4145,18 @@ def test_vectorized_udf_varargs(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_unsupported_types(self): - from pyspark.sql.functions import pandas_udf, col - schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) - df = self.spark.createDataFrame([(None,)], schema=schema) - f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) + from pyspark.sql.functions import pandas_udf with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): - df.select(f(col('map'))).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*MapType'): + pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): + pandas_udf(lambda x: x, BinaryType()) def test_vectorized_udf_dates(self): from pyspark.sql.functions import pandas_udf, col @@ -4379,15 +4387,16 @@ def data(self): .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ .withColumn("v", explode(col('vs'))).drop('vs') - def test_simple(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data + def test_supported_types(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col + df = self.data.withColumn("arr", array(col("id"))) foo_udf = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), StructType( [StructField('id', LongType()), StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType())), StructField('v1', DoubleType()), StructField('v2', LongType())]), PandasUDFType.GROUPED_MAP @@ -4490,17 +4499,15 @@ def test_datatype_string(self): def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data - - foo = pandas_udf( - lambda pdf: pdf, - 'id long, v map', - PandasUDFType.GROUPED_MAP - ) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): - df.groupby('id').apply(foo).sort('id').toPandas() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*MapType'): + pandas_udf( + lambda pdf: pdf, + 'id long, v map', + PandasUDFType.GROUPED_MAP) def test_wrong_args(self): from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType @@ -4519,23 +4526,30 @@ def test_wrong_args(self): df.groupby('id').apply( pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) with self.assertRaisesRegexp(ValueError, 'Invalid udf'): - df.groupby('id').apply( - pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) + df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType())) with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'): df.groupby('id').apply( - pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]), - PandasUDFType.SCALAR)) + pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) def test_unsupported_types(self): - from pyspark.sql.functions import pandas_udf, col, PandasUDFType + from pyspark.sql.functions import pandas_udf, PandasUDFType schema = StructType( [StructField("id", LongType(), True), StructField("map", MapType(StringType(), IntegerType()), True)]) - df = self.spark.createDataFrame([(1, None,)], schema=schema) - f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUPED_MAP) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): - df.groupby('id').apply(f).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*MapType'): + pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) + + schema = StructType( + [StructField("id", LongType(), True), + StructField("arr_ts", ArrayType(TimestampType()), True)]) + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'): + pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) # Regression test for SPARK-23314 def test_timestamp_dst(self): @@ -4614,23 +4628,32 @@ def weighted_mean(v, w): return weighted_mean def test_manual(self): + from pyspark.sql.functions import pandas_udf, array + df = self.data sum_udf = self.pandas_agg_sum_udf mean_udf = self.pandas_agg_mean_udf - - result1 = df.groupby('id').agg(sum_udf(df.v), mean_udf(df.v)).sort('id') + mean_arr_udf = pandas_udf( + self.pandas_agg_mean_udf.func, + ArrayType(self.pandas_agg_mean_udf.returnType), + self.pandas_agg_mean_udf.evalType) + + result1 = df.groupby('id').agg( + sum_udf(df.v), + mean_udf(df.v), + mean_arr_udf(array(df.v))).sort('id') expected1 = self.spark.createDataFrame( - [[0, 245.0, 24.5], - [1, 255.0, 25.5], - [2, 265.0, 26.5], - [3, 275.0, 27.5], - [4, 285.0, 28.5], - [5, 295.0, 29.5], - [6, 305.0, 30.5], - [7, 315.0, 31.5], - [8, 325.0, 32.5], - [9, 335.0, 33.5]], - ['id', 'sum(v)', 'avg(v)']) + [[0, 245.0, 24.5, [24.5]], + [1, 255.0, 25.5, [25.5]], + [2, 265.0, 26.5, [26.5]], + [3, 275.0, 27.5, [27.5]], + [4, 285.0, 28.5, [28.5]], + [5, 295.0, 29.5, [29.5]], + [6, 305.0, 30.5, [30.5]], + [7, 315.0, 31.5, [31.5]], + [8, 325.0, 32.5, [32.5]], + [9, 335.0, 33.5, [33.5]]], + ['id', 'sum(v)', 'avg(v)', 'avg(array(v))']) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) @@ -4667,14 +4690,15 @@ def test_basic(self): self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) def test_unsupported_types(self): - from pyspark.sql.types import ArrayType, DoubleType, MapType + from pyspark.sql.types import DoubleType, MapType from pyspark.sql.functions import pandas_udf, PandasUDFType with QuietTest(self.sc): with self.assertRaisesRegexp(NotImplementedError, 'not supported'): - @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUPED_AGG) - def mean_and_std_udf(v): - return [v.mean(), v.std()] + pandas_udf( + lambda x: x, + ArrayType(ArrayType(TimestampType())), + PandasUDFType.GROUPED_AGG) with QuietTest(self.sc): with self.assertRaisesRegexp(NotImplementedError, 'not supported'): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index f7141b4549e4e..e25941cd37595 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1638,6 +1638,8 @@ def to_arrow_type(dt): # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read arrow_type = pa.timestamp('us', tz='UTC') elif type(dt) == ArrayType: + if type(dt.elementType) == TimestampType: + raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) arrow_type = pa.list_(to_arrow_type(dt.elementType)) else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) @@ -1680,6 +1682,8 @@ def from_arrow_type(at): elif types.is_timestamp(at): spark_type = TimestampType() elif types.is_list(at): + if types.is_timestamp(at.value_type): + raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) spark_type = ArrayType(from_arrow_type(at.value_type)) else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 08c6b9e521e82..e5b35fc60e167 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -23,7 +23,7 @@ from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ - _parse_datatype_string + _parse_datatype_string, to_arrow_type, to_arrow_schema __all__ = ["UDFRegistration"] @@ -112,15 +112,31 @@ def returnType(self): else: self._returnType_placeholder = _parse_datatype_string(self._returnType) - if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \ - and not isinstance(self._returnType_placeholder, StructType): - raise ValueError("Invalid returnType: returnType must be a StructType for " - "pandas_udf with function type GROUPED_MAP") - elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF \ - and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)): - raise NotImplementedError( - "ArrayType, StructType and MapType are not supported with " - "PandasUDFType.GROUPED_AGG") + if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF: + try: + to_arrow_type(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with scalar Pandas UDFs: %s is " + "not supported" % str(self._returnType_placeholder)) + elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + if isinstance(self._returnType_placeholder, StructType): + try: + to_arrow_schema(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with grouped map Pandas UDFs: " + "%s is not supported" % str(self._returnType_placeholder)) + else: + raise TypeError("Invalid returnType for grouped map Pandas " + "UDFs: returnType must be a StructType.") + elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: + try: + to_arrow_type(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with grouped aggregate Pandas UDFs: " + "%s is not supported" % str(self._returnType_placeholder)) return self._returnType_placeholder diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 121b3dd1aeec9..89a3a92bc66d6 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -116,7 +116,7 @@ def wrap_grouped_agg_pandas_udf(f, return_type): def wrapped(*series): import pandas as pd result = f(*series) - return pd.Series(result) + return pd.Series([result]) return lambda *a: (wrapped(*a), arrow_return_type) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1e2501ee7757d..7835dbaa58439 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1064,7 +1064,7 @@ object SQLConf { "for use with pyspark.sql.DataFrame.toPandas, and " + "pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame. " + "The following data types are unsupported: " + - "MapType, ArrayType of TimestampType, and nested StructType.") + "BinaryType, MapType, ArrayType of TimestampType, and nested StructType.") .booleanConf .createWithDefault(false) From caeb108e25e5bfb7cffcf09ef9abbb1abcfa355d Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Mon, 12 Feb 2018 22:05:27 +0800 Subject: [PATCH 326/774] [MINOR][TEST] spark.testing` No effect on the SparkFunSuite unit test ## What changes were proposed in this pull request? Currently, we use SBT and MAVN to spark unit test, are affected by the parameters of `spark.testing`. However, when using the IDE test tool, `spark.testing` support is not very good, sometimes need to be manually added to the beforeEach. example: HiveSparkSubmitSuite RPackageUtilsSuite SparkSubmitSuite. The PR unified `spark.testing` parameter extraction to SparkFunSuite, support IDE test tool, and the test code is more compact. ## How was this patch tested? the existed test cases. Author: caoxuewen Closes #20582 from heary-cao/sparktesting. --- core/src/test/scala/org/apache/spark/SparkFunSuite.scala | 1 + .../test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala | 1 - .../test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 1 - .../spark/network/netty/NettyBlockTransferServiceSuite.scala | 1 + .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 1 - 5 files changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 3af9d82393bc4..31289026b0027 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -59,6 +59,7 @@ abstract class SparkFunSuite protected val enableAutoThreadAudit = true protected override def beforeAll(): Unit = { + System.setProperty("spark.testing", "true") if (enableAutoThreadAudit) { doThreadPreAudit() } diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 32dd3ecc2f027..ef947eb074647 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -66,7 +66,6 @@ class RPackageUtilsSuite override def beforeEach(): Unit = { super.beforeEach() - System.setProperty("spark.testing", "true") lineBuffer.clear() } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 27dd435332348..803a38d77fb82 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -107,7 +107,6 @@ class SparkSubmitSuite override def beforeEach() { super.beforeEach() - System.setProperty("spark.testing", "true") } // scalastyle:off println diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index f7bc3725d7278..78423ee68a0ec 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -80,6 +80,7 @@ class NettyBlockTransferServiceSuite private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = { actualPort should be >= expectedPort // avoid testing equality in case of simultaneous tests + // if `spark.testing` is true, // the default value for `spark.port.maxRetries` is 100 under test actualPort should be <= (expectedPort + 100) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 10204f4694663..2d31781132edc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -50,7 +50,6 @@ class HiveSparkSubmitSuite override def beforeEach() { super.beforeEach() - System.setProperty("spark.testing", "true") } test("temporary Hive UDF: define a UDF and use it") { From 0e2c266de7189473177f45aa68ea6a45c7e47ec3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 12 Feb 2018 22:07:59 +0800 Subject: [PATCH 327/774] [SPARK-22977][SQL] fix web UI SQL tab for CTAS ## What changes were proposed in this pull request? This is a regression in Spark 2.3. In Spark 2.2, we have a fragile UI support for SQL data writing commands. We only track the input query plan of `FileFormatWriter` and display its metrics. This is not ideal because we don't know who triggered the writing(can be table insertion, CTAS, etc.), but it's still useful to see the metrics of the input query. In Spark 2.3, we introduced a new mechanism: `DataWritigCommand`, to fix the UI issue entirely. Now these writing commands have real children, and we don't need to hack into the `FileFormatWriter` for the UI. This also helps with `explain`, now `explain` can show the physical plan of the input query, while in 2.2 the physical writing plan is simply `ExecutedCommandExec` and it has no child. However there is a regression in CTAS. CTAS commands don't extend `DataWritigCommand`, and we don't have the UI hack in `FileFormatWriter` anymore, so the UI for CTAS is just an empty node. See https://issues.apache.org/jira/browse/SPARK-22977 for more information about this UI issue. To fix it, we should apply the `DataWritigCommand` mechanism to CTAS commands. TODO: In the future, we should refactor this part and create some physical layer code pieces for data writing, and reuse them in different writing commands. We should have different logical nodes for different operators, even some of them share some same logic, e.g. CTAS, CREATE TABLE, INSERT TABLE. Internally we can share the same physical logic. ## How was this patch tested? manually tested. For data source table 1 For hive table 2 Author: Wenchen Fan Closes #20521 from cloud-fan/UI. --- .../command/createDataSourceTables.scala | 21 +++---- .../execution/datasources/DataSource.scala | 44 +++++++++++++-- .../datasources/DataSourceStrategy.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 2 +- .../CreateHiveTableAsSelectCommand.scala | 55 ++++++++++--------- .../sql/hive/execution/HiveExplainSuite.scala | 26 --------- 6 files changed, 80 insertions(+), 70 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 306f43dc4214a..e9747769dfcfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -21,7 +21,9 @@ import java.net.URI import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.StructType @@ -136,12 +138,11 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo case class CreateDataSourceTableAsSelectCommand( table: CatalogTable, mode: SaveMode, - query: LogicalPlan) - extends RunnableCommand { - - override protected def innerChildren: Seq[LogicalPlan] = Seq(query) + query: LogicalPlan, + outputColumns: Seq[Attribute]) + extends DataWritingCommand { - override def run(sparkSession: SparkSession): Seq[Row] = { + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { assert(table.tableType != CatalogTableType.VIEW) assert(table.provider.isDefined) @@ -163,7 +164,7 @@ case class CreateDataSourceTableAsSelectCommand( } saveDataIntoTable( - sparkSession, table, table.storage.locationUri, query, SaveMode.Append, tableExists = true) + sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true) } else { assert(table.schema.isEmpty) @@ -173,7 +174,7 @@ case class CreateDataSourceTableAsSelectCommand( table.storage.locationUri } val result = saveDataIntoTable( - sparkSession, table, tableLocation, query, SaveMode.Overwrite, tableExists = false) + sparkSession, table, tableLocation, child, SaveMode.Overwrite, tableExists = false) val newTable = table.copy( storage = table.storage.copy(locationUri = tableLocation), // We will use the schema of resolved.relation as the schema of the table (instead of @@ -198,10 +199,10 @@ case class CreateDataSourceTableAsSelectCommand( session: SparkSession, table: CatalogTable, tableLocation: Option[URI], - data: LogicalPlan, + physicalPlan: SparkPlan, mode: SaveMode, tableExists: Boolean): BaseRelation = { - // Create the relation based on the input logical plan: `data`. + // Create the relation based on the input logical plan: `query`. val pathOption = tableLocation.map("path" -> CatalogUtils.URIToString(_)) val dataSource = DataSource( session, @@ -212,7 +213,7 @@ case class CreateDataSourceTableAsSelectCommand( catalogTable = if (tableExists) Some(table) else None) try { - dataSource.writeAndRead(mode, query) + dataSource.writeAndRead(mode, query, outputColumns, physicalPlan) } catch { case ex: AnalysisException => logError(s"Failed to write to table ${table.identifier.unquotedString}", ex) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 25e1210504273..6e1b5727e3fd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -31,8 +31,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -435,10 +437,11 @@ case class DataSource( } /** - * Writes the given [[LogicalPlan]] out in this [[FileFormat]]. + * Creates a command node to write the given [[LogicalPlan]] out to the given [[FileFormat]]. + * The returned command is unresolved and need to be analyzed. */ private def planForWritingFileFormat( - format: FileFormat, mode: SaveMode, data: LogicalPlan): LogicalPlan = { + format: FileFormat, mode: SaveMode, data: LogicalPlan): InsertIntoHadoopFsRelationCommand = { // Don't glob path for the write path. The contracts here are: // 1. Only one output path can be specified on the write path; // 2. Output path must be a legal HDFS style file system path; @@ -482,9 +485,24 @@ case class DataSource( /** * Writes the given [[LogicalPlan]] out to this [[DataSource]] and returns a [[BaseRelation]] for * the following reading. + * + * @param mode The save mode for this writing. + * @param data The input query plan that produces the data to be written. Note that this plan + * is analyzed and optimized. + * @param outputColumns The original output columns of the input query plan. The optimizer may not + * preserve the output column's names' case, so we need this parameter + * instead of `data.output`. + * @param physicalPlan The physical plan of the input query plan. We should run the writing + * command with this physical plan instead of creating a new physical plan, + * so that the metrics can be correctly linked to the given physical plan and + * shown in the web UI. */ - def writeAndRead(mode: SaveMode, data: LogicalPlan): BaseRelation = { - if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { + def writeAndRead( + mode: SaveMode, + data: LogicalPlan, + outputColumns: Seq[Attribute], + physicalPlan: SparkPlan): BaseRelation = { + if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } @@ -493,9 +511,23 @@ case class DataSource( dataSource.createRelation( sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data)) case format: FileFormat => - sparkSession.sessionState.executePlan(planForWritingFileFormat(format, mode, data)).toRdd + val cmd = planForWritingFileFormat(format, mode, data) + val resolvedPartCols = cmd.partitionColumns.map { col => + // The partition columns created in `planForWritingFileFormat` should always be + // `UnresolvedAttribute` with a single name part. + assert(col.isInstanceOf[UnresolvedAttribute]) + val unresolved = col.asInstanceOf[UnresolvedAttribute] + assert(unresolved.nameParts.length == 1) + val name = unresolved.nameParts.head + outputColumns.find(a => equality(a.name, name)).getOrElse { + throw new AnalysisException( + s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]") + } + } + val resolved = cmd.copy(partitionColumns = resolvedPartCols, outputColumns = outputColumns) + resolved.run(sparkSession, physicalPlan) // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring - copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() + copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation() case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index d94c5bbccdd84..3f41612c08065 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -139,7 +139,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast case CreateTable(tableDesc, mode, Some(query)) if query.resolved && DDLUtils.isDatasourceTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema)) - CreateDataSourceTableAsSelectCommand(tableDesc, mode, query) + CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output) case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _), parts, query, overwrite, false) if parts.isEmpty => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index ab857b9055720..8df05cbb20361 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -157,7 +157,7 @@ object HiveAnalysis extends Rule[LogicalPlan] { case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) - CreateHiveTableAsSelectCommand(tableDesc, query, mode) + CreateHiveTableAsSelectCommand(tableDesc, query, query.output, mode) case InsertIntoDir(isLocal, storage, provider, child, overwrite) if DDLUtils.isHiveTable(provider) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 65e8b4e3c725c..1e801fe1845c4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.hive.execution import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SaveMode, SparkSession} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} -import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.DataWritingCommand /** @@ -36,15 +37,15 @@ import org.apache.spark.sql.execution.command.RunnableCommand case class CreateHiveTableAsSelectCommand( tableDesc: CatalogTable, query: LogicalPlan, + outputColumns: Seq[Attribute], mode: SaveMode) - extends RunnableCommand { + extends DataWritingCommand { private val tableIdentifier = tableDesc.identifier - override def innerChildren: Seq[LogicalPlan] = Seq(query) - - override def run(sparkSession: SparkSession): Seq[Row] = { - if (sparkSession.sessionState.catalog.tableExists(tableIdentifier)) { + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + if (catalog.tableExists(tableIdentifier)) { assert(mode != SaveMode.Overwrite, s"Expect the table $tableIdentifier has been dropped when the save mode is Overwrite") @@ -56,34 +57,36 @@ case class CreateHiveTableAsSelectCommand( return Seq.empty } - sparkSession.sessionState.executePlan( - InsertIntoTable( - UnresolvedRelation(tableIdentifier), - Map(), - query, - overwrite = false, - ifPartitionNotExists = false)).toRdd + InsertIntoHiveTable( + tableDesc, + Map.empty, + query, + overwrite = false, + ifPartitionNotExists = false, + outputColumns = outputColumns).run(sparkSession, child) } else { // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. assert(tableDesc.schema.isEmpty) - sparkSession.sessionState.catalog.createTable( - tableDesc.copy(schema = query.schema), ignoreIfExists = false) + catalog.createTable(tableDesc.copy(schema = query.schema), ignoreIfExists = false) try { - sparkSession.sessionState.executePlan( - InsertIntoTable( - UnresolvedRelation(tableIdentifier), - Map(), - query, - overwrite = true, - ifPartitionNotExists = false)).toRdd + // Read back the metadata of the table which was created just now. + val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier) + // For CTAS, there is no static partition values to insert. + val partition = createdTableMeta.partitionColumnNames.map(_ -> None).toMap + InsertIntoHiveTable( + createdTableMeta, + partition, + query, + overwrite = true, + ifPartitionNotExists = false, + outputColumns = outputColumns).run(sparkSession, child) } catch { case NonFatal(e) => // drop the created table. - sparkSession.sessionState.catalog.dropTable(tableIdentifier, ignoreIfNotExists = true, - purge = false) + catalog.dropTable(tableIdentifier, ignoreIfNotExists = true, purge = false) throw e } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index f84d188075b72..5d56f89c2271c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -128,32 +128,6 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "src") } - test("SPARK-17409: The EXPLAIN output of CTAS only shows the analyzed plan") { - withTempView("jt") { - val ds = (1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""").toDS() - spark.read.json(ds).createOrReplaceTempView("jt") - val outputs = sql( - s""" - |EXPLAIN EXTENDED - |CREATE TABLE t1 - |AS - |SELECT * FROM jt - """.stripMargin).collect().map(_.mkString).mkString - - val shouldContain = - "== Parsed Logical Plan ==" :: "== Analyzed Logical Plan ==" :: "Subquery" :: - "== Optimized Logical Plan ==" :: "== Physical Plan ==" :: - "CreateHiveTableAsSelect" :: "InsertIntoHiveTable" :: "jt" :: Nil - for (key <- shouldContain) { - assert(outputs.contains(key), s"$key doesn't exist in result") - } - - val physicalIndex = outputs.indexOf("== Physical Plan ==") - assert(outputs.substring(physicalIndex).contains("Subquery"), - "Physical Plan should contain SubqueryAlias since the query should not be optimized") - } - } - test("explain output of physical plan should contain proper codegen stage ID") { checkKeywordsExist(sql( """ From 4a4dd4f36f65410ef5c87f7b61a960373f044e61 Mon Sep 17 00:00:00 2001 From: liuxian Date: Mon, 12 Feb 2018 08:49:45 -0600 Subject: [PATCH 328/774] [SPARK-23391][CORE] It may lead to overflow for some integer multiplication ## What changes were proposed in this pull request? In the `getBlockData`,`blockId.reduceId` is the `Int` type, when it is greater than 2^28, `blockId.reduceId*8` will overflow In the `decompress0`, `len` and `unitSize` are Int type, so `len * unitSize` may lead to overflow ## How was this patch tested? N/A Author: liuxian Closes #20581 from 10110346/overflow2. --- .../org/apache/spark/shuffle/IndexShuffleBlockResolver.scala | 4 ++-- .../execution/columnar/compression/compressionSchemes.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d88b25cc7e258..d3f1c7ec1bbee 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -202,13 +202,13 @@ private[spark] class IndexShuffleBlockResolver( // class of issue from re-occurring in the future which is why they are left here even though // SPARK-22982 is fixed. val channel = Files.newByteChannel(indexFile.toPath) - channel.position(blockId.reduceId * 8) + channel.position(blockId.reduceId * 8L) val in = new DataInputStream(Channels.newInputStream(channel)) try { val offset = in.readLong() val nextOffset = in.readLong() val actualPosition = channel.position() - val expectedPosition = blockId.reduceId * 8 + 16 + val expectedPosition = blockId.reduceId * 8L + 16 if (actualPosition != expectedPosition) { throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + s"expected $expectedPosition but actual position was $actualPosition.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index 79dcf3a6105ce..00a1d54b41709 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -116,7 +116,7 @@ private[columnar] case object PassThrough extends CompressionScheme { while (pos < capacity) { if (pos != nextNullIndex) { val len = nextNullIndex - pos - assert(len * unitSize < Int.MaxValue) + assert(len * unitSize.toLong < Int.MaxValue) putFunction(columnVector, pos, bufferPos, len) bufferPos += len * unitSize pos += len From 5bb11411aec18b8d623e54caba5397d7cb8e89f0 Mon Sep 17 00:00:00 2001 From: James Thompson Date: Mon, 12 Feb 2018 11:34:56 -0800 Subject: [PATCH 329/774] [SPARK-23388][SQL] Support for Parquet Binary DecimalType in VectorizedColumnReader ## What changes were proposed in this pull request? Re-add support for parquet binary DecimalType in VectorizedColumnReader ## How was this patch tested? Existing test suite Author: James Thompson Closes #20580 from jamesthomp/jt/add-back-binary-decimal. --- .../execution/datasources/parquet/VectorizedColumnReader.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index c120863152a96..47dd625f4b154 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -444,7 +444,8 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; - if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType) { + if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType + || DecimalType.isByteArrayDecimalType(column.dataType())) { defColumn.readBinarys(num, column, rowId, maxDefLevel, data); } else if (column.dataType() == DataTypes.TimestampType) { if (!shouldConvertTimestamps()) { From 0c66fe4f22f8af4932893134bb0fd56f00fabeae Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 12 Feb 2018 12:20:29 -0800 Subject: [PATCH 330/774] [SPARK-22002][SQL][FOLLOWUP][TEST] Add a test to check if the original schema doesn't have metadata. ## What changes were proposed in this pull request? This is a follow-up pr of #19231 which modified the behavior to remove metadata from JDBC table schema. This pr adds a test to check if the schema doesn't have metadata. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20585 from ueshin/issues/SPARK-22002/fup1. --- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index cb2df0ac54f4c..5238adce4a699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1168,4 +1168,26 @@ class JDBCSuite extends SparkFunSuite val df3 = sql("SELECT * FROM test_sessionInitStatement") assert(df3.collect() === Array(Row(21519, 1234))) } + + test("jdbc data source shouldn't have unnecessary metadata in its schema") { + val schema = StructType(Seq( + StructField("NAME", StringType, true), StructField("THEID", IntegerType, true))) + + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("DbTaBle", "TEST.PEOPLE") + .load() + assert(df.schema === schema) + + withTempView("people_view") { + sql( + s""" + |CREATE TEMPORARY VIEW people_view + |USING org.apache.spark.sql.jdbc + |OPTIONS (uRl '$url', DbTaBlE 'TEST.PEOPLE', User 'testUser', PassWord 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + assert(sql("select * from people_view").schema === schema) + } + } } From fba01b9a65e5d9438d35da0bd807c179ba741911 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 12 Feb 2018 14:58:31 -0800 Subject: [PATCH 331/774] [SPARK-23378][SQL] move setCurrentDatabase from HiveExternalCatalog to HiveClientImpl ## What changes were proposed in this pull request? This removes the special case that `alterPartitions` call from `HiveExternalCatalog` can reset the current database in the hive client as a side effect. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20564 from liufengdb/move. --- .../spark/sql/hive/HiveExternalCatalog.scala | 5 ---- .../sql/hive/client/HiveClientImpl.scala | 26 ++++++++++++++----- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 3b8a8ca301c27..1ee1d57b8ebe1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -1107,11 +1107,6 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - // Note: Before altering table partitions in Hive, you *must* set the current database - // to the one that contains the table of interest. Otherwise you will end up with the - // most helpful error message ever: "Unable to alter partition. alter is not possible." - // See HIVE-2742 for more detail. - client.setCurrentDatabase(db) client.alterPartitions(db, table, withStatsProps) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 6c0f4144992ae..c223f51b1be75 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -291,14 +291,18 @@ private[hive] class HiveClientImpl( state.err = stream } - override def setCurrentDatabase(databaseName: String): Unit = withHiveState { - if (databaseExists(databaseName)) { - state.setCurrentDatabase(databaseName) + private def setCurrentDatabaseRaw(db: String): Unit = { + if (databaseExists(db)) { + state.setCurrentDatabase(db) } else { - throw new NoSuchDatabaseException(databaseName) + throw new NoSuchDatabaseException(db) } } + override def setCurrentDatabase(databaseName: String): Unit = withHiveState { + setCurrentDatabaseRaw(databaseName) + } + override def createDatabase( database: CatalogDatabase, ignoreIfExists: Boolean): Unit = withHiveState { @@ -598,8 +602,18 @@ private[hive] class HiveClientImpl( db: String, table: String, newParts: Seq[CatalogTablePartition]): Unit = withHiveState { - val hiveTable = toHiveTable(getTable(db, table), Some(userName)) - shim.alterPartitions(client, table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) + // Note: Before altering table partitions in Hive, you *must* set the current database + // to the one that contains the table of interest. Otherwise you will end up with the + // most helpful error message ever: "Unable to alter partition. alter is not possible." + // See HIVE-2742 for more detail. + val original = state.getCurrentDatabase + try { + setCurrentDatabaseRaw(db) + val hiveTable = toHiveTable(getTable(db, table), Some(userName)) + shim.alterPartitions(client, table, newParts.map { toHivePartition(_, hiveTable) }.asJava) + } finally { + state.setCurrentDatabase(original) + } } /** From 6cb59708c70c03696c772fbb5d158eed57fe67d4 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 12 Feb 2018 15:26:37 -0800 Subject: [PATCH 332/774] [SPARK-23313][DOC] Add a migration guide for ORC ## What changes were proposed in this pull request? This PR adds a migration guide documentation for ORC. ![orc-guide](https://user-images.githubusercontent.com/9700541/36123859-ec165cae-1002-11e8-90b7-7313be7a81a5.png) ## How was this patch tested? N/A. Author: Dongjoon Hyun Closes #20484 from dongjoon-hyun/SPARK-23313. --- docs/sql-programming-guide.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 6174a93b68492..0f9f01e18682f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1776,6 +1776,35 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.2 to 2.3 + - Since Spark 2.3, Spark supports a vectorized ORC reader with a new ORC file format for ORC files. To do that, the following configurations are newly added or change their default values. The vectorized reader is used for the native ORC tables (e.g., the ones created using the clause `USING ORC`) when `spark.sql.orc.impl` is set to `native` and `spark.sql.orc.enableVectorizedReader` is set to `true`. For the Hive ORC serde table (e.g., the ones created using the clause `USING HIVE OPTIONS (fileFormat 'ORC')`), the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is set to `true`. + + - New configurations + +
spark.mesos.principal (none) - Set the principal with which Spark framework will use to authenticate with Mesos. + Set the principal with which Spark framework will use to authenticate with Mesos. You can also specify this via the environment variable `SPARK_MESOS_PRINCIPAL`. +
spark.mesos.principal.file(none) + Set the file containing the principal with which Spark framework will use to authenticate with Mesos. Allows specifying the principal indirectly in more security conscious deployments. The file must be readable by the user launching the job and be UTF-8 encoded plaintext. You can also specify this via the environment variable `SPARK_MESOS_PRINCIPAL_FILE`.
(none) Set the secret with which Spark framework will use to authenticate with Mesos. Used, for example, when - authenticating with the registry. + authenticating with the registry. You can also specify this via the environment variable `SPARK_MESOS_SECRET`. +
spark.mesos.secret.file(none) + Set the file containing the secret with which Spark framework will use to authenticate with Mesos. Used, for example, when + authenticating with the registry. Allows for specifying the secret indirectly in more security conscious deployments. The file must be readable by the user launching the job and be UTF-8 encoded plaintext. You can also specify this via the environment variable `SPARK_MESOS_SECRET_FILE`.
+ + + + + + + + + + + +
Property NameDefaultMeaning
spark.sql.orc.implnativeThe name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1 which is used prior to Spark 2.3.
spark.sql.orc.enableVectorizedReadertrueEnables vectorized orc decoding in native implementation. If false, a new non-vectorized ORC reader is used in native implementation. For hive implementation, this is ignored.
+ + - Changed configurations + + + + + + + + +
Property NameDefaultMeaning
spark.sql.orc.filterPushdowntrueEnables filter pushdown for ORC files. It is false by default prior to Spark 2.3.
+ - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. - Since Spark 2.3, the Join/Filter's deterministic predicates that are after the first non-deterministic predicates are also pushed down/through the child operators, if possible. In prior Spark versions, these filters are not eligible for predicate pushdown. From 4104b68e958cd13975567a96541dac7cccd8195c Mon Sep 17 00:00:00 2001 From: sychen Date: Mon, 12 Feb 2018 16:00:47 -0800 Subject: [PATCH 333/774] [SPARK-23230][SQL] When hive.default.fileformat is other kinds of file types, create textfile table cause a serde error When hive.default.fileformat is other kinds of file types, create textfile table cause a serde error. We should take the default type of textfile and sequencefile both as org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe. ``` set hive.default.fileformat=orc; create table tbl( i string ) stored as textfile; desc formatted tbl; Serde Library org.apache.hadoop.hive.ql.io.orc.OrcSerde InputFormat org.apache.hadoop.mapred.TextInputFormat OutputFormat org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat ``` Author: sychen Closes #20406 from cxzl25/default_serde. --- .../apache/spark/sql/internal/HiveSerDe.scala | 6 ++++-- .../sql/hive/execution/HiveSerDeSuite.scala | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index dac463641cfab..eca612f06f9bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -31,7 +31,8 @@ object HiveSerDe { "sequencefile" -> HiveSerDe( inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat"), + serde = Option("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")), "rcfile" -> HiveSerDe( @@ -54,7 +55,8 @@ object HiveSerDe { "textfile" -> HiveSerDe( inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), + outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"), + serde = Option("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")), "avro" -> HiveSerDe( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 1c9f00141ae1d..d7752e987cb4b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -100,6 +100,25 @@ class HiveSerDeSuite extends HiveComparisonTest with PlanTest with BeforeAndAfte assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) } + + withSQLConf("hive.default.fileformat" -> "orc") { + val (desc, exists) = extractTableDesc( + "CREATE TABLE IF NOT EXISTS fileformat_test (id int) STORED AS textfile") + assert(exists) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + } + + withSQLConf("hive.default.fileformat" -> "orc") { + val (desc, exists) = extractTableDesc( + "CREATE TABLE IF NOT EXISTS fileformat_test (id int) STORED AS sequencefile") + assert(exists) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.SequenceFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.mapred.SequenceFileOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + } } test("create hive serde table with new syntax - basic") { From c1bcef876c1415e39e624cfbca9c9bdeae24cbb9 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 13 Feb 2018 11:40:34 +0800 Subject: [PATCH 334/774] [SPARK-23323][SQL] Support commit coordinator for DataSourceV2 writes ## What changes were proposed in this pull request? DataSourceV2 batch writes should use the output commit coordinator if it is required by the data source. This adds a new method, `DataWriterFactory#useCommitCoordinator`, that determines whether the coordinator will be used. If the write factory returns true, `WriteToDataSourceV2` will use the coordinator for batch writes. ## How was this patch tested? This relies on existing write tests, which now use the commit coordinator. Author: Ryan Blue Closes #20490 from rdblue/SPARK-23323-add-commit-coordinator. --- .../sources/v2/writer/DataSourceWriter.java | 19 +++++++-- .../datasources/v2/WriteToDataSourceV2.scala | 41 +++++++++++++++---- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index e3f682bf96a66..0a0fd8db58035 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -63,6 +63,16 @@ public interface DataSourceWriter { */ DataWriterFactory createWriterFactory(); + /** + * Returns whether Spark should use the commit coordinator to ensure that at most one attempt for + * each task commits. + * + * @return true if commit coordinator should be used, false otherwise. + */ + default boolean useCommitCoordinator() { + return true; + } + /** * Handles a commit message on receiving from a successful data writer. * @@ -79,10 +89,11 @@ default void onDataWriterCommit(WriterCommitMessage message) {} * failed, and {@link #abort(WriterCommitMessage[])} would be called. The state of the destination * is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it. * - * Note that, one partition may have multiple committed data writers because of speculative tasks. - * Spark will pick the first successful one and get its commit message. Implementations should be - * aware of this and handle it correctly, e.g., have a coordinator to make sure only one data - * writer can commit, or have a way to clean up the data of already-committed writers. + * Note that speculative execution may cause multiple tasks to run for a partition. By default, + * Spark uses the commit coordinator to allow at most one attempt to commit. Implementations can + * disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple + * attempts may have committed successfully and one successful commit message per task will be + * passed to this commit method. The remaining commit messages are ignored by Spark. */ void commit(WriterCommitMessage[] messages); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index eefbcf4c0e087..535e7962d7439 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row @@ -53,6 +54,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) } + val useCommitCoordinator = writer.useCommitCoordinator val rdd = query.execute() val messages = new Array[WriterCommitMessage](rdd.partitions.length) @@ -73,7 +75,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e DataWritingSparkTask.runContinuous(writeTask, context, iter) case _ => (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter) + DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator) } sparkContext.runJob( @@ -116,21 +118,44 @@ object DataWritingSparkTask extends Logging { def run( writeTask: DataWriterFactory[InternalRow], context: TaskContext, - iter: Iterator[InternalRow]): WriterCommitMessage = { - val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) + iter: Iterator[InternalRow], + useCommitCoordinator: Boolean): WriterCommitMessage = { + val stageId = context.stageId() + val partId = context.partitionId() + val attemptId = context.attemptNumber() + val dataWriter = writeTask.createDataWriter(partId, attemptId) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { iter.foreach(dataWriter.write) - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") + + val msg = if (useCommitCoordinator) { + val coordinator = SparkEnv.get.outputCommitCoordinator + val commitAuthorized = coordinator.canCommit(context.stageId(), partId, attemptId) + if (commitAuthorized) { + logInfo(s"Writer for stage $stageId, task $partId.$attemptId is authorized to commit.") + dataWriter.commit() + } else { + val message = s"Stage $stageId, task $partId.$attemptId: driver did not authorize commit" + logInfo(message) + // throwing CommitDeniedException will trigger the catch block for abort + throw new CommitDeniedException(message, stageId, partId, attemptId) + } + + } else { + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + dataWriter.commit() + } + + logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.") + msg + })(catchBlock = { // If there is an error, abort this writer - logError(s"Writer for partition ${context.partitionId()} is aborting.") + logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.") dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") + logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") }) } From ed4e78bd606e7defc2cd01a5c2e9b47954baa424 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 12 Feb 2018 20:57:26 -0800 Subject: [PATCH 335/774] [SPARK-23379][SQL] skip when setting the same current database in HiveClientImpl ## What changes were proposed in this pull request? If the target database name is as same as the current database, we should be able to skip one metastore access. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20565 from liufengdb/remove-redundant. --- .../apache/spark/sql/hive/client/HiveClientImpl.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index c223f51b1be75..146fa54a1bce4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -292,10 +292,12 @@ private[hive] class HiveClientImpl( } private def setCurrentDatabaseRaw(db: String): Unit = { - if (databaseExists(db)) { - state.setCurrentDatabase(db) - } else { - throw new NoSuchDatabaseException(db) + if (state.getCurrentDatabase != db) { + if (databaseExists(db)) { + state.setCurrentDatabase(db) + } else { + throw new NoSuchDatabaseException(db) + } } } From f17b936f0ddb7d46d1349bd42f9a64c84c06e48d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 12 Feb 2018 21:12:22 -0800 Subject: [PATCH 336/774] [SPARK-23303][SQL] improve the explain result for data source v2 relations ## What changes were proposed in this pull request? The current explain result for data source v2 relation is unreadable: ``` == Parsed Logical Plan == 'Filter ('i > 6) +- AnalysisBarrier +- Project [j#1] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Analyzed Logical Plan == j: int Project [j#1] +- Filter (i#0 > 6) +- Project [j#1, i#0] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Optimized Logical Plan == Project [j#1] +- Filter isnotnull(i#0) +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Physical Plan == *(1) Project [j#1] +- *(1) Filter isnotnull(i#0) +- *(1) DataSourceV2Scan [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 ``` after this PR ``` == Parsed Logical Plan == 'Project [unresolvedalias('j, None)] +- AnalysisBarrier +- Relation AdvancedDataSourceV2[i#0, j#1] == Analyzed Logical Plan == j: int Project [j#1] +- Relation AdvancedDataSourceV2[i#0, j#1] == Optimized Logical Plan == Relation AdvancedDataSourceV2[j#1] == Physical Plan == *(1) Scan AdvancedDataSourceV2[j#1] ``` ------- ``` == Analyzed Logical Plan == i: int, j: int Filter (i#88 > 3) +- Relation JavaAdvancedDataSourceV2[i#88, j#89] == Optimized Logical Plan == Filter isnotnull(i#88) +- Relation JavaAdvancedDataSourceV2[i#88, j#89] (PushedFilter: [GreaterThan(i,3)]) == Physical Plan == *(1) Filter isnotnull(i#88) +- *(1) Scan JavaAdvancedDataSourceV2[i#88, j#89] (PushedFilter: [GreaterThan(i,3)]) ``` an example for streaming query ``` == Parsed Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming Relation FakeDataSourceV2$[value#25] == Analyzed Logical Plan == value: string, count(1): bigint Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming Relation FakeDataSourceV2$[value#25] == Optimized Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject value#25.toString, obj#4: java.lang.String +- Streaming Relation FakeDataSourceV2$[value#25] == Physical Plan == *(4) HashAggregate(keys=[value#6], functions=[count(1)], output=[value#6, count(1)#11L]) +- StateStoreSave [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5], Complete, 0 +- *(3) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- StateStoreRestore [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5] +- *(2) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- Exchange hashpartitioning(value#6, 5) +- *(1) HashAggregate(keys=[value#6], functions=[partial_count(1)], output=[value#6, count#16L]) +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- *(1) MapElements , obj#5: java.lang.String +- *(1) DeserializeToObject value#25.toString, obj#4: java.lang.String +- *(1) Scan FakeDataSourceV2$[value#25] ``` ## How was this patch tested? N/A Author: Wenchen Fan Closes #20477 from cloud-fan/explain. --- .../kafka010/KafkaContinuousSourceSuite.scala | 18 +--- .../sql/kafka010/KafkaContinuousTest.scala | 3 +- .../spark/sql/kafka010/KafkaSourceSuite.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 8 +- .../v2/DataSourceReaderHolder.scala | 64 ------------- .../v2/DataSourceV2QueryPlan.scala | 96 +++++++++++++++++++ .../datasources/v2/DataSourceV2Relation.scala | 26 +++-- .../datasources/v2/DataSourceV2ScanExec.scala | 6 +- .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../v2/PushDownOperatorsToDataSource.scala | 4 +- .../streaming/MicroBatchExecution.scala | 22 +++-- .../continuous/ContinuousExecution.scala | 9 +- .../spark/sql/streaming/StreamSuite.scala | 8 +- .../spark/sql/streaming/StreamTest.scala | 2 +- .../continuous/ContinuousSuite.scala | 11 +-- 15 files changed, 157 insertions(+), 127 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index a7083fa4e3417..72ee0c551ec3d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,20 +17,9 @@ package org.apache.spark.sql.kafka010 -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.time.SpanSugar._ -import scala.collection.mutable -import scala.util.Random - -import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.Dataset import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.sql.streaming.Trigger // Run tests in KafkaSourceSuiteBase in continuous execution mode. class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest @@ -71,7 +60,8 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => + r.reader.asInstanceOf[KafkaContinuousReader] }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index 5a1a14f7a307a..d34458ac81014 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -47,7 +47,8 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => + r.reader.asInstanceOf[KafkaContinuousReader] }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 02c87643568bd..cb09cce75ff6f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -117,7 +117,8 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => + r.reader.asInstanceOf[KafkaContinuousReader] } }) if (sources.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index fcaf8d618c168..984b6510f2dbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -189,11 +189,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() + val ds = cls.newInstance().asInstanceOf[DataSourceV2] val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = sparkSession.sessionState.conf)).asJava) + DataSourceV2Utils.extractSessionConfigs(ds, sparkSession.sessionState.conf)).asJava) // Streaming also uses the data source V2 API. So it may be that the data source implements // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading @@ -221,7 +219,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (reader == null) { loadV1Source(paths: _*) } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + Dataset.ofRows(sparkSession, DataSourceV2Relation(ds, reader)) } } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala deleted file mode 100644 index 81219e9771bd8..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import java.util.Objects - -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.sources.v2.reader._ - -/** - * A base class for data source reader holder with customized equals/hashCode methods. - */ -trait DataSourceReaderHolder { - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The held data source reader. - */ - def reader: DataSourceReader - - /** - * The metadata of this data source reader that can be used for equality test. - */ - private def metadata: Seq[Any] = { - val filters: Any = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Nil - } - Seq(output, reader.getClass, filters) - } - - def canEqual(other: Any): Boolean - - override def equals(other: Any): Boolean = other match { - case other: DataSourceReaderHolder => - canEqual(other) && metadata.length == other.metadata.length && - metadata.zip(other.metadata).forall { case (l, r) => l == r } - case _ => false - } - - override def hashCode(): Int = { - metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala new file mode 100644 index 0000000000000..1e0d088f3a57c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Objects + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.util.Utils + +/** + * A base class for data source v2 related query plan(both logical and physical). It defines the + * equals/hashCode methods, and provides a string representation of the query plan, according to + * some common information. + */ +trait DataSourceV2QueryPlan { + + /** + * The output of the data source reader, w.r.t. column pruning. + */ + def output: Seq[Attribute] + + /** + * The instance of this data source implementation. Note that we only consider its class in + * equals/hashCode, not the instance itself. + */ + def source: DataSourceV2 + + /** + * The created data source reader. Here we use it to get the filters that has been pushed down + * so far, itself doesn't take part in the equals/hashCode. + */ + def reader: DataSourceReader + + private lazy val filters = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Set.empty + } + + /** + * The metadata of this data source query plan that can be used for equality check. + */ + private def metadata: Seq[Any] = Seq(output, source.getClass, filters) + + def canEqual(other: Any): Boolean + + override def equals(other: Any): Boolean = other match { + case other: DataSourceV2QueryPlan => canEqual(other) && metadata == other.metadata + case _ => false + } + + override def hashCode(): Int = { + metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } + + def metadataString: String = { + val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] + if (filters.nonEmpty) entries += "PushedFilter" -> filters.mkString("[", ", ", "]") + + val outputStr = Utils.truncatedString(output, "[", ", ", "]") + + val entriesStr = if (entries.nonEmpty) { + Utils.truncatedString(entries.map { + case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100) + }, " (", ", ", ")") + } else { + "" + } + + s"${source.getClass.getSimpleName}$outputStr$entriesStr" + } + + private def redact(text: String): String = { + Utils.redact(SQLConf.get.stringRedationPattern, text) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 38f6b15224788..cd97e0cab6b5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -20,15 +20,23 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( output: Seq[AttributeReference], - reader: DataSourceReader) - extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { + source: DataSourceV2, + reader: DataSourceReader, + override val isStreaming: Boolean) + extends LeafNode with MultiInstanceRelation with DataSourceV2QueryPlan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] + override def simpleString: String = { + val streamingHeader = if (isStreaming) "Streaming " else "" + s"${streamingHeader}Relation $metadataString" + } + override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) @@ -41,18 +49,8 @@ case class DataSourceV2Relation( } } -/** - * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical - * to the non-streaming relation. - */ -class StreamingDataSourceV2Relation( - output: Seq[AttributeReference], - reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { - override def isStreaming: Boolean = true -} - object DataSourceV2Relation { - def apply(reader: DataSourceReader): DataSourceV2Relation = { - new DataSourceV2Relation(reader.readSchema().toAttributes, reader) + def apply(source: DataSourceV2, reader: DataSourceReader): DataSourceV2Relation = { + new DataSourceV2Relation(reader.readSchema().toAttributes, source, reader, isStreaming = false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 7d9581be4db89..c99d535efcf81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType @@ -36,11 +37,14 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( output: Seq[AttributeReference], + @transient source: DataSourceV2, @transient reader: DataSourceReader) - extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { + extends LeafExecNode with DataSourceV2QueryPlan with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] + override def simpleString: String = s"Scan $metadataString" + override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => new DataSourcePartitioning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index df5b524485f54..fb61e6f32b1f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case DataSourceV2Relation(output, reader) => - DataSourceV2ScanExec(output, reader) :: Nil + case r: DataSourceV2Relation => + DataSourceV2ScanExec(r.output, r.source, r.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 1ca6cbf061b4e..4cfdd50e8f46b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -39,11 +39,11 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: Ideally column pruning should be implemented via a plan property that is propagated // top-down, then we can simplify the logic here and only collect target operators. val filterPushed = plan transformUp { - case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => + case FilterAndProject(fields, condition, r: DataSourceV2Relation) => val (candidates, nonDeterministic) = splitConjunctivePredicates(condition).partition(_.deterministic) - val stayUpFilters: Seq[Expression] = reader match { + val stayUpFilters: Seq[Expression] = r.reader match { case r: SupportsPushDownCatalystFilters => r.pushCatalystFilters(candidates.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 812533313332e..84564b6639ac9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -52,6 +52,8 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty + private val readerToDataSourceMap = MutableMap.empty[MicroBatchReader, DataSourceV2] + private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -90,6 +92,7 @@ class MicroBatchExecution( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 + readerToDataSourceMap(reader) = source StreamingExecutionRelation(reader, output)(sparkSession) }) case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => @@ -405,12 +408,15 @@ class MicroBatchExecution( case v1: SerializedOffset => reader.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - reader.setOffsetRange( - toJava(current), - Optional.of(availableV2)) + reader.setOffsetRange(toJava(current), Optional.of(availableV2)) logDebug(s"Retrieving data from $reader: $current -> $availableV2") - Some(reader -> - new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) + Some(reader -> new DataSourceV2Relation( + reader.readSchema().toAttributes, + // Provide a fake value here just in case something went wrong, e.g. the reader gives + // a wrong `equals` implementation. + readerToDataSourceMap.getOrElse(reader, FakeDataSourceV2), + reader, + isStreaming = true)) case _ => None } } @@ -500,3 +506,5 @@ class MicroBatchExecution( Optional.ofNullable(scalaOption.orNull) } } + +object FakeDataSourceV2 extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index c3294d64b10cd..f87d57d0b3209 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} @@ -167,7 +167,7 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { - case ContinuousExecutionRelation(_, _, output) => + case ContinuousExecutionRelation(ds, _, output) => val reader = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = reader.readSchema().toAttributes @@ -180,7 +180,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - new StreamingDataSourceV2Relation(newOutput, reader) + new DataSourceV2Relation(newOutput, ds, reader, isStreaming = true) } // Rewire the plan to use the new attributes that were returned by the source. @@ -201,7 +201,8 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case DataSourceV2Relation(_, r: ContinuousReader) => r + case r: DataSourceV2Relation if r.reader.isInstanceOf[ContinuousReader] => + r.reader.asInstanceOf[ContinuousReader] }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index d1a04833390f5..70eb9f0ac66d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,16 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("Streaming Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) + assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) + assert("Streaming Relation".r.findAllMatchIn(explainWithExtended).size === 3) + assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 37fe595529baf..254394685857b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -605,7 +605,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case StreamingExecutionRelation(s, _) => s - case DataSourceV2Relation(_, r) => r + case d: DataSourceV2Relation => d.reader } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 4b4ed82dc6520..9ee9aaf87f87c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.streaming.continuous -import java.util.UUID - -import org.apache.spark.{SparkContext, SparkEnv, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -43,7 +40,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 407f67249639709c40c46917700ed6dd736daa7d Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 13 Feb 2018 15:05:13 +0900 Subject: [PATCH 337/774] [SPARK-20090][FOLLOW-UP] Revert the deprecation of `names` in PySpark ## What changes were proposed in this pull request? Deprecating the field `name` in PySpark is not expected. This PR is to revert the change. ## How was this patch tested? N/A Author: gatorsmile Closes #20595 from gatorsmile/removeDeprecate. --- python/pyspark/sql/types.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index e25941cd37595..cd857402db8f7 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -455,9 +455,6 @@ class StructType(DataType): Iterating a :class:`StructType` will iterate its :class:`StructField`\\s. A contained :class:`StructField` can be accessed by name or position. - .. note:: `names` attribute is deprecated in 2.3. Use `fieldNames` method instead - to get a list of field names. - >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct1["f1"] StructField(f1,StringType,true) From 9dae715168a8e72e318ab231c34a1069bfa342a6 Mon Sep 17 00:00:00 2001 From: Arseniy Tashoyan Date: Tue, 13 Feb 2018 06:20:34 -0600 Subject: [PATCH 338/774] [SPARK-23318][ML] FP-growth: WARN FPGrowth: Input data is not cached ## What changes were proposed in this pull request? Cache the RDD of items in ml.FPGrowth before passing it to mllib.FPGrowth. Cache only when the user did not cache the input dataset of transactions. This fixes the warning about uncached data emerging from mllib.FPGrowth. ## How was this patch tested? Manually: 1. Run ml.FPGrowthExample - warning is there 2. Apply the fix 3. Run ml.FPGrowthExample again - no warning anymore Author: Arseniy Tashoyan Closes #20578 from tashoyan/SPARK-23318. --- .../scala/org/apache/spark/ml/fpm/FPGrowth.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index aa7871d6ff29d..3d041fc80eb7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -32,6 +32,7 @@ import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel /** * Common params for FPGrowth and FPGrowthModel @@ -158,18 +159,30 @@ class FPGrowth @Since("2.2.0") ( } private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { + val handlePersistence = dataset.storageLevel == StorageLevel.NONE + val data = dataset.select($(itemsCol)) - val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) + val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) if (isSet(numPartitions)) { mllibFP.setNumPartitions($(numPartitions)) } + + if (handlePersistence) { + items.persist(StorageLevel.MEMORY_AND_DISK) + } + val parentModel = mllibFP.run(items) val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) val schema = StructType(Seq( StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false), StructField("freq", LongType, nullable = false))) val frequentItems = dataset.sparkSession.createDataFrame(rows, schema) + + if (handlePersistence) { + items.unpersist() + } + copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) } From 300c40f50ab4258d697f06a814d1491dc875c847 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Tue, 13 Feb 2018 06:23:10 -0600 Subject: [PATCH 339/774] [SPARK-23384][WEB-UI] When it has no incomplete(completed) applications found, the last updated time is not formatted and client local time zone is not show in history server web ui. ## What changes were proposed in this pull request? When it has no incomplete(completed) applications found, the last updated time is not formatted and client local time zone is not show in history server web ui. It is a bug. fix before: ![1](https://user-images.githubusercontent.com/26266482/36070635-264d7cf0-0f3a-11e8-8426-14135ffedb16.png) fix after: ![2](https://user-images.githubusercontent.com/26266482/36070651-8ec3800e-0f3a-11e8-991c-6122cc9539fe.png) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20573 from guoxiaolongzte/SPARK-23384. --- .../scala/org/apache/spark/deploy/history/HistoryPage.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 5d62a7d8bebb4..6fc12d721e6f1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,7 +37,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = - + ++ +
    @@ -65,7 +66,6 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") if (allAppsSize > 0) { ++
    ++ - ++ ++ } else if (requestedIncomplete) { From 116c581d2658571d38f8b9b27a516ef517170589 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 13 Feb 2018 06:54:15 -0800 Subject: [PATCH 340/774] [SPARK-20659][CORE] Removing sc.getExecutorStorageStatus and making StorageStatus private MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In this PR StorageStatus is made to private and simplified a bit moreover SparkContext.getExecutorStorageStatus method is removed. The reason of keeping StorageStatus is that it is usage from SparkContext.getRDDStorageInfo. Instead of the method SparkContext.getExecutorStorageStatus executor infos are extended with additional memory metrics such as usedOnHeapStorageMemory, usedOffHeapStorageMemory, totalOnHeapStorageMemory, totalOffHeapStorageMemory. ## How was this patch tested? By running existing unit tests. Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #20546 from attilapiros/SPARK-20659. --- .../org/apache/spark/SparkExecutorInfo.java | 4 + .../scala/org/apache/spark/SparkContext.scala | 19 +- .../org/apache/spark/SparkStatusTracker.scala | 9 +- .../org/apache/spark/StatusAPIImpl.scala | 6 +- .../apache/spark/storage/StorageUtils.scala | 119 +--------- .../org/apache/spark/DistributedSuite.scala | 7 +- .../StandaloneDynamicAllocationSuite.scala | 2 +- .../apache/spark/storage/StorageSuite.scala | 219 ------------------ project/MimaExcludes.scala | 14 ++ .../spark/repl/SingletonReplSuite.scala | 6 +- 10 files changed, 44 insertions(+), 361 deletions(-) diff --git a/core/src/main/java/org/apache/spark/SparkExecutorInfo.java b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java index dc3e826475987..2b93385adf103 100644 --- a/core/src/main/java/org/apache/spark/SparkExecutorInfo.java +++ b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java @@ -30,4 +30,8 @@ public interface SparkExecutorInfo extends Serializable { int port(); long cacheSize(); int numRunningTasks(); + long usedOnHeapStorageMemory(); + long usedOffHeapStorageMemory(); + long totalOnHeapStorageMemory(); + long totalOffHeapStorageMemory(); } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3828d4f703247..c4f74c4f1f9c2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1715,7 +1715,13 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def getRDDStorageInfo(filter: RDD[_] => Boolean): Array[RDDInfo] = { assertNotStopped() val rddInfos = persistentRdds.values.filter(filter).map(RDDInfo.fromRdd).toArray - StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) + rddInfos.foreach { rddInfo => + val rddId = rddInfo.id + val rddStorageInfo = statusStore.asOption(statusStore.rdd(rddId)) + rddInfo.numCachedPartitions = rddStorageInfo.map(_.numCachedPartitions).getOrElse(0) + rddInfo.memSize = rddStorageInfo.map(_.memoryUsed).getOrElse(0L) + rddInfo.diskSize = rddStorageInfo.map(_.diskUsed).getOrElse(0L) + } rddInfos.filter(_.isCached) } @@ -1726,17 +1732,6 @@ class SparkContext(config: SparkConf) extends Logging { */ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap - /** - * :: DeveloperApi :: - * Return information about blocks stored in all of the slaves - */ - @DeveloperApi - @deprecated("This method may change or be removed in a future release.", "2.2.0") - def getExecutorStorageStatus: Array[StorageStatus] = { - assertNotStopped() - env.blockManager.master.getStorageStatus - } - /** * :: DeveloperApi :: * Return pools for fair scheduler diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index 70865cb58c571..815237eba0174 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -97,7 +97,8 @@ class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore } /** - * Returns information of all known executors, including host, port, cacheSize, numRunningTasks. + * Returns information of all known executors, including host, port, cacheSize, numRunningTasks + * and memory metrics. */ def getExecutorInfos: Array[SparkExecutorInfo] = { store.executorList(true).map { exec => @@ -113,7 +114,11 @@ class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore host, port, cachedMem, - exec.activeTasks) + exec.activeTasks, + exec.memoryMetrics.map(_.usedOffHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.usedOnHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.totalOffHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.totalOnHeapStorageMemory).getOrElse(0L)) }.toArray } } diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala index c1f24a6377788..6a888c1e9e772 100644 --- a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala +++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala @@ -38,5 +38,9 @@ private class SparkExecutorInfoImpl( val host: String, val port: Int, val cacheSize: Long, - val numRunningTasks: Int) + val numRunningTasks: Int, + val usedOnHeapStorageMemory: Long, + val usedOffHeapStorageMemory: Long, + val totalOnHeapStorageMemory: Long, + val totalOffHeapStorageMemory: Long) extends SparkExecutorInfo diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index e9694fdbca2de..adc406bb1c441 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -24,19 +24,15 @@ import scala.collection.mutable import sun.nio.ch.DirectBuffer -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging /** - * :: DeveloperApi :: * Storage information for each BlockManager. * * This class assumes BlockId and BlockStatus are immutable, such that the consumers of this * class cannot mutate the source of the information. Accesses are not thread-safe. */ -@DeveloperApi -@deprecated("This class may be removed or made private in a future release.", "2.2.0") -class StorageStatus( +private[spark] class StorageStatus( val blockManagerId: BlockManagerId, val maxMemory: Long, val maxOnHeapMem: Option[Long], @@ -44,9 +40,6 @@ class StorageStatus( /** * Internal representation of the blocks stored in this block manager. - * - * We store RDD blocks and non-RDD blocks separately to allow quick retrievals of RDD blocks. - * These collections should only be mutated through the add/update/removeBlock methods. */ private val _rddBlocks = new mutable.HashMap[Int, mutable.Map[BlockId, BlockStatus]] private val _nonRddBlocks = new mutable.HashMap[BlockId, BlockStatus] @@ -87,9 +80,6 @@ class StorageStatus( */ def rddBlocks: Map[BlockId, BlockStatus] = _rddBlocks.flatMap { case (_, blocks) => blocks } - /** Return the blocks that belong to the given RDD stored in this block manager. */ - def rddBlocksById(rddId: Int): Map[BlockId, BlockStatus] = _rddBlocks.getOrElse(rddId, Map.empty) - /** Add the given block to this storage status. If it already exists, overwrite it. */ private[spark] def addBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { updateStorageInfo(blockId, blockStatus) @@ -101,46 +91,6 @@ class StorageStatus( } } - /** Update the given block in this storage status. If it doesn't already exist, add it. */ - private[spark] def updateBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { - addBlock(blockId, blockStatus) - } - - /** Remove the given block from this storage status. */ - private[spark] def removeBlock(blockId: BlockId): Option[BlockStatus] = { - updateStorageInfo(blockId, BlockStatus.empty) - blockId match { - case RDDBlockId(rddId, _) => - // Actually remove the block, if it exists - if (_rddBlocks.contains(rddId)) { - val removed = _rddBlocks(rddId).remove(blockId) - // If the given RDD has no more blocks left, remove the RDD - if (_rddBlocks(rddId).isEmpty) { - _rddBlocks.remove(rddId) - } - removed - } else { - None - } - case _ => - _nonRddBlocks.remove(blockId) - } - } - - /** - * Return whether the given block is stored in this block manager in O(1) time. - * - * @note This is much faster than `this.blocks.contains`, which is O(blocks) time. - */ - def containsBlock(blockId: BlockId): Boolean = { - blockId match { - case RDDBlockId(rddId, _) => - _rddBlocks.get(rddId).exists(_.contains(blockId)) - case _ => - _nonRddBlocks.contains(blockId) - } - } - /** * Return the given block stored in this block manager in O(1) time. * @@ -155,37 +105,12 @@ class StorageStatus( } } - /** - * Return the number of blocks stored in this block manager in O(RDDs) time. - * - * @note This is much faster than `this.blocks.size`, which is O(blocks) time. - */ - def numBlocks: Int = _nonRddBlocks.size + numRddBlocks - - /** - * Return the number of RDD blocks stored in this block manager in O(RDDs) time. - * - * @note This is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. - */ - def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum - - /** - * Return the number of blocks that belong to the given RDD in O(1) time. - * - * @note This is much faster than `this.rddBlocksById(rddId).size`, which is - * O(blocks in this RDD) time. - */ - def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) - /** Return the max memory can be used by this block manager. */ def maxMem: Long = maxMemory /** Return the memory remaining in this block manager. */ def memRemaining: Long = maxMem - memUsed - /** Return the memory used by caching RDDs */ - def cacheSize: Long = onHeapCacheSize.getOrElse(0L) + offHeapCacheSize.getOrElse(0L) - /** Return the memory used by this block manager. */ def memUsed: Long = onHeapMemUsed.getOrElse(0L) + offHeapMemUsed.getOrElse(0L) @@ -220,15 +145,9 @@ class StorageStatus( /** Return the disk space used by this block manager. */ def diskUsed: Long = _nonRddStorageInfo.diskUsage + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum - /** Return the memory used by the given RDD in this block manager in O(1) time. */ - def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.memoryUsage).getOrElse(0L) - /** Return the disk space used by the given RDD in this block manager in O(1) time. */ def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.diskUsage).getOrElse(0L) - /** Return the storage level, if any, used by the given RDD in this block manager. */ - def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_.level) - /** * Update the relevant storage info, taking into account any existing status for this block. */ @@ -295,40 +214,4 @@ private[spark] object StorageUtils extends Logging { cleaner.clean() } } - - /** - * Update the given list of RDDInfo with the given list of storage statuses. - * This method overwrites the old values stored in the RDDInfo's. - */ - def updateRddInfo(rddInfos: Seq[RDDInfo], statuses: Seq[StorageStatus]): Unit = { - rddInfos.foreach { rddInfo => - val rddId = rddInfo.id - // Assume all blocks belonging to the same RDD have the same storage level - val storageLevel = statuses - .flatMap(_.rddStorageLevel(rddId)).headOption.getOrElse(StorageLevel.NONE) - val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum - val memSize = statuses.map(_.memUsedByRdd(rddId)).sum - val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum - - rddInfo.storageLevel = storageLevel - rddInfo.numCachedPartitions = numCachedPartitions - rddInfo.memSize = memSize - rddInfo.diskSize = diskSize - } - } - - /** - * Return a mapping from block ID to its locations for each block that belongs to the given RDD. - */ - def getRddBlockLocations(rddId: Int, statuses: Seq[StorageStatus]): Map[BlockId, Seq[String]] = { - val blockLocations = new mutable.HashMap[BlockId, mutable.ListBuffer[String]] - statuses.foreach { status => - status.rddBlocksById(rddId).foreach { case (bid, _) => - val location = status.blockManagerId.hostPort - blockLocations.getOrElseUpdate(bid, mutable.ListBuffer.empty) += location - } - } - blockLocations - } - } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index e09d5f59817b9..28ea0c6f0bdba 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -160,11 +160,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val data = sc.parallelize(1 to 1000, 10) val cachedData = data.persist(storageLevel) assert(cachedData.count === 1000) - assert(sc.getExecutorStorageStatus.map(_.rddBlocksById(cachedData.id).size).sum === - storageLevel.replication * data.getNumPartitions) - assert(cachedData.count === 1000) - assert(cachedData.count === 1000) - + assert(sc.getRDDStorageInfo.filter(_.id == cachedData.id).map(_.numCachedPartitions).sum === + data.getNumPartitions) // Get all the locations of the first partition and try to fetch the partitions // from those locations. val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index bf7480d79f8a1..c21ee7d26f8ca 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -610,7 +610,7 @@ class StandaloneDynamicAllocationSuite * we submit a request to kill them. This must be called before each kill request. */ private def syncExecutors(sc: SparkContext): Unit = { - val driverExecutors = sc.getExecutorStorageStatus + val driverExecutors = sc.env.blockManager.master.getStorageStatus .map(_.blockManagerId.executorId) .filter { _ != SparkContext.DRIVER_IDENTIFIER} val masterExecutors = getExecutorIds(sc) diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index da198f946fd64..ca352387055f4 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -51,27 +51,6 @@ class StorageSuite extends SparkFunSuite { assert(status.diskUsed === 60L) } - test("storage status update non-RDD blocks") { - val status = storageStatus1 - status.updateBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 50L, 100L)) - status.updateBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 100L, 20L)) - assert(status.blocks.size === 3) - assert(status.memUsed === 160L) - assert(status.memRemaining === 840L) - assert(status.diskUsed === 140L) - } - - test("storage status remove non-RDD blocks") { - val status = storageStatus1 - status.removeBlock(TestBlockId("foo")) - status.removeBlock(TestBlockId("faa")) - assert(status.blocks.size === 1) - assert(status.blocks.contains(TestBlockId("fee"))) - assert(status.memUsed === 10L) - assert(status.memRemaining === 990L) - assert(status.diskUsed === 20L) - } - // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD blocks private def storageStatus2: StorageStatus = { val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) @@ -95,85 +74,6 @@ class StorageSuite extends SparkFunSuite { assert(status.rddBlocks.contains(RDDBlockId(2, 2))) assert(status.rddBlocks.contains(RDDBlockId(2, 3))) assert(status.rddBlocks.contains(RDDBlockId(2, 4))) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(0).contains(RDDBlockId(0, 0))) - assert(status.rddBlocksById(1).size === 1) - assert(status.rddBlocksById(1).contains(RDDBlockId(1, 1))) - assert(status.rddBlocksById(2).size === 3) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 2))) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 3))) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 4))) - assert(status.memUsedByRdd(0) === 10L) - assert(status.memUsedByRdd(1) === 100L) - assert(status.memUsedByRdd(2) === 30L) - assert(status.diskUsedByRdd(0) === 20L) - assert(status.diskUsedByRdd(1) === 200L) - assert(status.diskUsedByRdd(2) === 80L) - assert(status.rddStorageLevel(0) === Some(memAndDisk)) - assert(status.rddStorageLevel(1) === Some(memAndDisk)) - assert(status.rddStorageLevel(2) === Some(memAndDisk)) - - // Verify default values for RDDs that don't exist - assert(status.rddBlocksById(10).isEmpty) - assert(status.memUsedByRdd(10) === 0L) - assert(status.diskUsedByRdd(10) === 0L) - assert(status.rddStorageLevel(10) === None) - } - - test("storage status update RDD blocks") { - val status = storageStatus2 - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 5000L, 0L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 0L, 1000L)) - assert(status.blocks.size === 7) - assert(status.rddBlocks.size === 5) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(1).size === 1) - assert(status.rddBlocksById(2).size === 3) - assert(status.memUsedByRdd(0) === 0L) - assert(status.memUsedByRdd(1) === 100L) - assert(status.memUsedByRdd(2) === 20L) - assert(status.diskUsedByRdd(0) === 0L) - assert(status.diskUsedByRdd(1) === 200L) - assert(status.diskUsedByRdd(2) === 1060L) - } - - test("storage status remove RDD blocks") { - val status = storageStatus2 - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(1, 1)) - status.removeBlock(RDDBlockId(2, 2)) - status.removeBlock(RDDBlockId(2, 4)) - assert(status.blocks.size === 3) - assert(status.rddBlocks.size === 2) - assert(status.rddBlocks.contains(RDDBlockId(0, 0))) - assert(status.rddBlocks.contains(RDDBlockId(2, 3))) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(0).contains(RDDBlockId(0, 0))) - assert(status.rddBlocksById(1).size === 0) - assert(status.rddBlocksById(2).size === 1) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 3))) - assert(status.memUsedByRdd(0) === 10L) - assert(status.memUsedByRdd(1) === 0L) - assert(status.memUsedByRdd(2) === 10L) - assert(status.diskUsedByRdd(0) === 20L) - assert(status.diskUsedByRdd(1) === 0L) - assert(status.diskUsedByRdd(2) === 20L) - } - - test("storage status containsBlock") { - val status = storageStatus2 - // blocks that actually exist - assert(status.blocks.contains(TestBlockId("dan")) === status.containsBlock(TestBlockId("dan"))) - assert(status.blocks.contains(TestBlockId("man")) === status.containsBlock(TestBlockId("man"))) - assert(status.blocks.contains(RDDBlockId(0, 0)) === status.containsBlock(RDDBlockId(0, 0))) - assert(status.blocks.contains(RDDBlockId(1, 1)) === status.containsBlock(RDDBlockId(1, 1))) - assert(status.blocks.contains(RDDBlockId(2, 2)) === status.containsBlock(RDDBlockId(2, 2))) - assert(status.blocks.contains(RDDBlockId(2, 3)) === status.containsBlock(RDDBlockId(2, 3))) - assert(status.blocks.contains(RDDBlockId(2, 4)) === status.containsBlock(RDDBlockId(2, 4))) - // blocks that don't exist - assert(status.blocks.contains(TestBlockId("fan")) === status.containsBlock(TestBlockId("fan"))) - assert(status.blocks.contains(RDDBlockId(100, 0)) === status.containsBlock(RDDBlockId(100, 0))) } test("storage status getBlock") { @@ -191,40 +91,6 @@ class StorageSuite extends SparkFunSuite { assert(status.blocks.get(RDDBlockId(100, 0)) === status.getBlock(RDDBlockId(100, 0))) } - test("storage status num[Rdd]Blocks") { - val status = storageStatus2 - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - status.addBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 0L)) - status.addBlock(RDDBlockId(4, 4), BlockStatus(memAndDisk, 0L, 0L)) - status.addBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - status.updateBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 10L)) - status.updateBlock(RDDBlockId(4, 0), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(10, 10), BlockStatus(memAndDisk, 0L, 0L)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - assert(status.rddBlocksById(100).size === status.numRddBlocksById(100)) - status.removeBlock(RDDBlockId(4, 0)) - status.removeBlock(RDDBlockId(10, 10)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - // remove a block that doesn't exist - status.removeBlock(RDDBlockId(1000, 999)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - assert(status.rddBlocksById(1000).size === status.numRddBlocksById(1000)) - } test("storage status memUsed, diskUsed, externalBlockStoreUsed") { val status = storageStatus2 @@ -237,17 +103,6 @@ class StorageSuite extends SparkFunSuite { status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 4L, 5L)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) - status.removeBlock(TestBlockId("fire")) - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(2, 2)) - status.removeBlock(RDDBlockId(2, 3)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) } // For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations @@ -273,65 +128,6 @@ class StorageSuite extends SparkFunSuite { Seq(info0, info1) } - test("StorageUtils.updateRddInfo") { - val storageStatuses = stockStorageStatuses - val rddInfos = stockRDDInfos - StorageUtils.updateRddInfo(rddInfos, storageStatuses) - assert(rddInfos(0).storageLevel === memAndDisk) - assert(rddInfos(0).numCachedPartitions === 5) - assert(rddInfos(0).memSize === 5L) - assert(rddInfos(0).diskSize === 10L) - assert(rddInfos(0).externalBlockStoreSize === 0L) - assert(rddInfos(1).storageLevel === memAndDisk) - assert(rddInfos(1).numCachedPartitions === 3) - assert(rddInfos(1).memSize === 3L) - assert(rddInfos(1).diskSize === 6L) - assert(rddInfos(1).externalBlockStoreSize === 0L) - } - - test("StorageUtils.getRddBlockLocations") { - val storageStatuses = stockStorageStatuses - val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses) - val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses) - assert(blockLocations0.size === 5) - assert(blockLocations1.size === 3) - assert(blockLocations0.contains(RDDBlockId(0, 0))) - assert(blockLocations0.contains(RDDBlockId(0, 1))) - assert(blockLocations0.contains(RDDBlockId(0, 2))) - assert(blockLocations0.contains(RDDBlockId(0, 3))) - assert(blockLocations0.contains(RDDBlockId(0, 4))) - assert(blockLocations1.contains(RDDBlockId(1, 0))) - assert(blockLocations1.contains(RDDBlockId(1, 1))) - assert(blockLocations1.contains(RDDBlockId(1, 2))) - assert(blockLocations0(RDDBlockId(0, 0)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 1)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 2)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 3)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 4)) === Seq("cat:3")) - assert(blockLocations1(RDDBlockId(1, 0)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 1)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) - } - - test("StorageUtils.getRddBlockLocations with multiple locations") { - val storageStatuses = stockStorageStatuses - storageStatuses(0).addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L)) - storageStatuses(0).addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L)) - storageStatuses(2).addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L)) - val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses) - val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses) - assert(blockLocations0.size === 5) - assert(blockLocations1.size === 3) - assert(blockLocations0(RDDBlockId(0, 0)) === Seq("dog:1", "cat:3")) - assert(blockLocations0(RDDBlockId(0, 1)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 2)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 3)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 4)) === Seq("dog:1", "cat:3")) - assert(blockLocations1(RDDBlockId(1, 0)) === Seq("dog:1", "duck:2")) - assert(blockLocations1(RDDBlockId(1, 1)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) - } - private val offheap = StorageLevel.OFF_HEAP // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD onheap // and offheap blocks @@ -373,21 +169,6 @@ class StorageSuite extends SparkFunSuite { status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(offheap, 4L, 0L)) - status.updateBlock(RDDBlockId(1, 1), BlockStatus(offheap, 4L, 0L)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) - assert(status.onHeapMemUsed.get === actualOnHeapMemUsed) - assert(status.offHeapMemUsed.get === actualOffHeapMemUsed) - - status.removeBlock(TestBlockId("fire")) - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(2, 2)) - status.removeBlock(RDDBlockId(2, 3)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) } private def storageStatus4: StorageStatus = { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d35c50e1d00fe..381f7b5be1ddf 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,20 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-20659] Remove StorageStatus, or make it private + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOnHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOnHeapStorageMemory"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.getExecutorStorageStatus"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numBlocks"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocks"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.containsBlock"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddBlocksById"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocksById"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.memUsedByRdd"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel") ) // Exclude rules for 2.3.x diff --git a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala index ec3d790255ad3..d49e0fd85229f 100644 --- a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala @@ -350,7 +350,7 @@ class SingletonReplSuite extends SparkFunSuite { """ |val timeout = 60000 // 60 seconds |val start = System.currentTimeMillis - |while(sc.getExecutorStorageStatus.size != 3 && + |while(sc.statusTracker.getExecutorInfos.size != 3 && | (System.currentTimeMillis - start) < timeout) { | Thread.sleep(10) |} @@ -361,11 +361,11 @@ class SingletonReplSuite extends SparkFunSuite { |case class Foo(i: Int) |val ret = sc.parallelize((1 to 100).map(Foo), 10).persist(MEMORY_AND_DISK_2) |ret.count() - |val res = sc.getExecutorStorageStatus.map(s => s.rddBlocksById(ret.id).size).sum + |val res = sc.getRDDStorageInfo.filter(_.id == ret.id).map(_.numCachedPartitions).sum """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) - assertContains("res: Int = 20", output) + assertContains("res: Int = 10", output) } test("should clone and clean line object in ClosureCleaner") { From d6e1958a2472898e60bd013902c2f35111596e40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 13 Feb 2018 09:54:52 -0600 Subject: [PATCH 341/774] [SPARK-23189][CORE][WEB UI] Reflect stage level blacklisting on executor tab MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The purpose of this PR to reflect the stage level blacklisting on the executor tab for the currently active stages. After this change in the executor tab at the Status column one of the following label will be: - "Blacklisted" when the executor is blacklisted application level (old flag) - "Dead" when the executor is not Blacklisted and not Active - "Blacklisted in Stages: [...]" when the executor is Active but the there are active blacklisted stages for the executor. Within the [] coma separated active stageIDs are listed. - "Active" when the executor is Active and there is no active blacklisted stages for the executor ## How was this patch tested? Both with unit tests and manually. #### Manual test Spark was started as: ```bash bin/spark-shell --master "local-cluster[2,1,1024]" --conf "spark.blacklist.enabled=true" --conf "spark.blacklist.stage.maxFailedTasksPerExecutor=1" --conf "spark.blacklist.application.maxFailedTasksPerExecutor=10" ``` And the job was: ```scala import org.apache.spark.SparkEnv val pairs = sc.parallelize(1 to 10000, 10).map { x => if (SparkEnv.get.executorId.toInt == 0) throw new RuntimeException("Bad executor") else { Thread.sleep(10) (x % 10, x) } } val all = pairs.cogroup(pairs) all.collect() ``` UI screenshots about the running: - One executor is blacklisted in the two stages: ![One executor is blacklisted in two stages](https://issues.apache.org/jira/secure/attachment/12908314/multiple_stages_1.png) - One stage completes the other one is still running: ![One stage completes the other is still running](https://issues.apache.org/jira/secure/attachment/12908315/multiple_stages_2.png) - Both stages are completed: ![Both stages are completed](https://issues.apache.org/jira/secure/attachment/12908316/multiple_stages_3.png) ### Unit tests In AppStatusListenerSuite.scala both the node blacklisting for a stage and the executor blacklisting for stage are tested. Author: “attilapiros” Closes #20408 from attilapiros/SPARK-23189. --- .../apache/spark/ui/static/executorspage.js | 21 +++++--- .../spark/status/AppStatusListener.scala | 49 ++++++++++++++----- .../org/apache/spark/status/LiveEntity.scala | 7 ++- .../org/apache/spark/status/api/v1/api.scala | 3 +- .../executor_list_json_expectation.json | 3 +- .../executor_memory_usage_expectation.json | 15 ++++-- ...xecutor_node_blacklisting_expectation.json | 15 ++++-- ...acklisting_unblacklisting_expectation.json | 15 ++++-- .../spark/status/AppStatusListenerSuite.scala | 21 ++++++++ 9 files changed, 113 insertions(+), 36 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index d430d8c5fb35a..6717af3ac4daf 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -25,12 +25,18 @@ function getThreadDumpEnabled() { return threadDumpEnabled; } -function formatStatus(status, type) { +function formatStatus(status, type, row) { + if (row.isBlacklisted) { + return "Blacklisted"; + } + if (status) { - return "Active" - } else { - return "Dead" + if (row.blacklistedInStages.length == 0) { + return "Active" + } + return "Active (Blacklisted in Stages: [" + row.blacklistedInStages.join(", ") + "])"; } + return "Dead" } jQuery.extend(jQuery.fn.dataTableExt.oSort, { @@ -415,9 +421,10 @@ $(document).ready(function () { } }, {data: 'hostPort'}, - {data: 'isActive', render: function (data, type, row) { - if (row.isBlacklisted) return "Blacklisted"; - else return formatStatus (data, type); + { + data: 'isActive', + render: function (data, type, row) { + return formatStatus (data, type, row); } }, {data: 'rddBlocks'}, diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index ab01cddfca5b0..79a17e26665fd 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -213,11 +213,13 @@ private[spark] class AppStatusListener( override def onExecutorBlacklistedForStage( event: SparkListenerExecutorBlacklistedForStage): Unit = { + val now = System.nanoTime() + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => - val now = System.nanoTime() - val esummary = stage.executorSummary(event.executorId) - esummary.isBlacklisted = true - maybeUpdate(esummary, now) + setStageBlackListStatus(stage, now, event.executorId) + } + liveExecutors.get(event.executorId).foreach { exec => + addBlackListedStageTo(exec, event.stageId, now) } } @@ -226,16 +228,29 @@ private[spark] class AppStatusListener( // Implicitly blacklist every available executor for the stage associated with this node Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => - liveExecutors.values.foreach { exec => - if (exec.hostname == event.hostId) { - val esummary = stage.executorSummary(exec.executorId) - esummary.isBlacklisted = true - maybeUpdate(esummary, now) - } - } + val executorIds = liveExecutors.values.filter(_.host == event.hostId).map(_.executorId).toSeq + setStageBlackListStatus(stage, now, executorIds: _*) + } + liveExecutors.values.filter(_.hostname == event.hostId).foreach { exec => + addBlackListedStageTo(exec, event.stageId, now) } } + private def addBlackListedStageTo(exec: LiveExecutor, stageId: Int, now: Long): Unit = { + exec.blacklistedInStages += stageId + liveUpdate(exec, now) + } + + private def setStageBlackListStatus(stage: LiveStage, now: Long, executorIds: String*): Unit = { + executorIds.foreach { executorId => + val executorStageSummary = stage.executorSummary(executorId) + executorStageSummary.isBlacklisted = true + maybeUpdate(executorStageSummary, now) + } + stage.blackListedExecutors ++= executorIds + maybeUpdate(stage, now) + } + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { updateBlackListStatus(event.executorId, false) } @@ -594,12 +609,24 @@ private[spark] class AppStatusListener( stage.executorSummaries.values.foreach(update(_, now)) update(stage, now, last = true) + + val executorIdsForStage = stage.blackListedExecutors + executorIdsForStage.foreach { executorId => + liveExecutors.get(executorId).foreach { exec => + removeBlackListedStageFrom(exec, event.stageInfo.stageId, now) + } + } } appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1) kvstore.write(appSummary) } + private def removeBlackListedStageFrom(exec: LiveExecutor, stageId: Int, now: Long) = { + exec.blacklistedInStages -= stageId + liveUpdate(exec, now) + } + override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { // This needs to set fields that are already set by onExecutorAdded because the driver is // considered an "executor" in the UI, but does not have a SparkListenerExecutorAdded event. diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index d5f9e19ffdcd0..79e3f13b826ce 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -20,6 +20,7 @@ package org.apache.spark.status import java.util.Date import java.util.concurrent.atomic.AtomicInteger +import scala.collection.immutable.{HashSet, TreeSet} import scala.collection.mutable.HashMap import com.google.common.collect.Interners @@ -254,6 +255,7 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE var totalShuffleRead = 0L var totalShuffleWrite = 0L var isBlacklisted = false + var blacklistedInStages: Set[Int] = TreeSet() var executorLogs = Map[String, String]() @@ -299,7 +301,8 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE Option(removeTime), Option(removeReason), executorLogs, - memoryMetrics) + memoryMetrics, + blacklistedInStages) new ExecutorSummaryWrapper(info) } @@ -371,6 +374,8 @@ private class LiveStage extends LiveEntity { val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() + var blackListedExecutors = new HashSet[String]() + // Used for cleanup of tasks after they reach the configured limit. Not written to the store. @volatile var cleaning = false var savedTasks = new AtomicInteger(0) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 550eac3952bbb..a333f1aaf6325 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -95,7 +95,8 @@ class ExecutorSummary private[spark]( val removeTime: Option[Date], val removeReason: Option[String], val executorLogs: Map[String, String], - val memoryMetrics: Option[MemoryMetrics]) + val memoryMetrics: Option[MemoryMetrics], + val blacklistedInStages: Set[Int]) class MemoryMetrics private[spark]( val usedOnHeapStorageMemory: Long, diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index 942e6d8f04363..7bb8fe8fd8f98 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -19,5 +19,6 @@ "isBlacklisted" : false, "maxMemory" : 278302556, "addTime" : "2015-02-03T16:43:00.906GMT", - "executorLogs" : { } + "executorLogs" : { }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json index ed33c90dd39ba..dd5b1dcb7372b 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -25,7 +25,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -56,7 +57,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] } ,{ "id" : "2", "hostPort" : "172.22.0.167:51487", @@ -87,7 +89,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -118,7 +121,8 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -149,5 +153,6 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index 73519f1d9e2e4..3e55d3d9d7eb9 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -25,7 +25,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -56,7 +57,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "2", "hostPort" : "172.22.0.167:51487", @@ -87,7 +89,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -118,7 +121,8 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -149,5 +153,6 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json index 6931fead3d2ff..e87f3e78f2dc8 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json @@ -19,7 +19,8 @@ "isBlacklisted" : false, "maxMemory" : 384093388, "addTime" : "2016-11-15T23:20:38.836GMT", - "executorLogs" : { } + "executorLogs" : { }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.111:64543", @@ -44,7 +45,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "2", "hostPort" : "172.22.0.111:64539", @@ -69,7 +71,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.111:64541", @@ -94,7 +97,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.111:64540", @@ -119,5 +123,6 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stderr" - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index b74d6ee2ec836..749502709b5c8 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -273,6 +273,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(exec.info.isBlacklistedForStage === expectedBlacklistedFlag) } + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set(stages.head.stageId)) + } + // Blacklisting node for stage time += 1 listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage( @@ -439,6 +443,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(stage.info.numCompleteTasks === pending.size) } + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set()) + } + // Submit stage 2. time += 1 stages.last.submissionTime = Some(time) @@ -453,6 +461,19 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(stage.info.submissionTime === Some(new Date(stages.last.submissionTime.get))) } + // Blacklisting node for stage + time += 1 + listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage( + time = time, + hostId = "1.example.com", + executorFailures = 1, + stageId = stages.last.stageId, + stageAttemptId = stages.last.attemptId)) + + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set(stages.last.stageId)) + } + // Start and fail all tasks of stage 2. time += 1 val s2Tasks = createTasks(4, execIds) From 091a000d27f324de8c5c527880854ecfcf5de9a4 Mon Sep 17 00:00:00 2001 From: huangtengfei Date: Tue, 13 Feb 2018 09:59:21 -0600 Subject: [PATCH 342/774] [SPARK-23053][CORE] taskBinarySerialization and task partitions calculate in DagScheduler.submitMissingTasks should keep the same RDD checkpoint status ## What changes were proposed in this pull request? When we run concurrent jobs using the same rdd which is marked to do checkpoint. If one job has finished running the job, and start the process of RDD.doCheckpoint, while another job is submitted, then submitStage and submitMissingTasks will be called. In [submitMissingTasks](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L961), will serialize taskBinaryBytes and calculate task partitions which are both affected by the status of checkpoint, if the former is calculated before doCheckpoint finished, while the latter is calculated after doCheckpoint finished, when run task, rdd.compute will be called, for some rdds with particular partition type such as [UnionRDD](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala) who will do partition type cast, will get a ClassCastException because the part params is actually a CheckpointRDDPartition. This error occurs because rdd.doCheckpoint occurs in the same thread that called sc.runJob, while the task serialization occurs in the DAGSchedulers event loop. ## How was this patch tested? the exist uts and also add a test case in DAGScheduerSuite to show the exception case. Author: huangtengfei Closes #20244 from ivoson/branch-taskpart-mistype. --- .../apache/spark/scheduler/DAGScheduler.scala | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 199937b8c27af..8c46a84323392 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -39,7 +39,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -1016,15 +1016,24 @@ class DAGScheduler( // might modify state of objects referenced in their closures. This is necessary in Hadoop // where the JobConf/Configuration object is not thread-safe. var taskBinary: Broadcast[Array[Byte]] = null + var partitions: Array[Partition] = null try { // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). // For ResultTask, serialize and broadcast (rdd, func). - val taskBinaryBytes: Array[Byte] = stage match { - case stage: ShuffleMapStage => - JavaUtils.bufferToArray( - closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) - case stage: ResultStage => - JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) + var taskBinaryBytes: Array[Byte] = null + // taskBinaryBytes and partitions are both effected by the checkpoint status. We need + // this synchronization in case another concurrent job is checkpointing this RDD, so we get a + // consistent view of both variables. + RDDCheckpointData.synchronized { + taskBinaryBytes = stage match { + case stage: ShuffleMapStage => + JavaUtils.bufferToArray( + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) + case stage: ResultStage => + JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) + } + + partitions = stage.rdd.partitions } taskBinary = sc.broadcast(taskBinaryBytes) @@ -1049,7 +1058,7 @@ class DAGScheduler( stage.pendingPartitions.clear() partitionsToCompute.map { id => val locs = taskIdToLocations(id) - val part = stage.rdd.partitions(id) + val part = partitions(id) stage.pendingPartitions += id new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), @@ -1059,7 +1068,7 @@ class DAGScheduler( case stage: ResultStage => partitionsToCompute.map { id => val p: Int = stage.partitions(id) - val part = stage.rdd.partitions(p) + val part = partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, id, properties, serializedTaskMetrics, From bd24731722a9142c90cf3d76008115f308203844 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Tue, 13 Feb 2018 11:39:33 -0600 Subject: [PATCH 343/774] [SPARK-23382][WEB-UI] Spark Streaming ui about the contents of the for need to have hidden and show features, when the table records very much. ## What changes were proposed in this pull request? Spark Streaming ui about the contents of the for need to have hidden and show features, when the table records very much. please refer to https://github.com/apache/spark/pull/20216 fix after: ![1](https://user-images.githubusercontent.com/26266482/36068644-df029328-0f14-11e8-8350-cfdde9733ffc.png) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20570 from guoxiaolongzte/SPARK-23382. --- .../org/apache/spark/ui/static/webui.js | 2 + .../spark/streaming/ui/StreamingPage.scala | 37 ++++++++++++++++--- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index e575c4c78970d..83009df91d30a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -80,4 +80,6 @@ $(function() { collapseTablePageLoad('collapse-aggregated-poolActiveStages','aggregated-poolActiveStages'); collapseTablePageLoad('collapse-aggregated-tasks','aggregated-tasks'); collapseTablePageLoad('collapse-aggregated-rdds','aggregated-rdds'); + collapseTablePageLoad('collapse-aggregated-activeBatches','aggregated-activeBatches'); + collapseTablePageLoad('collapse-aggregated-completedBatches','aggregated-completedBatches'); }); \ No newline at end of file diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 7abafd6ba7908..3a176f64cdd60 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -490,15 +490,40 @@ private[ui] class StreamingPage(parent: StreamingTab) sortBy(_.batchTime.milliseconds).reverse val activeBatchesContent = { -

    Active Batches ({runningBatches.size + waitingBatches.size})

    ++ - new ActiveBatchTable(runningBatches, waitingBatches, listener.batchDuration).toNodeSeq +
    +
    + +

    + + Active Batches ({runningBatches.size + waitingBatches.size}) +

    +
    +
    + {new ActiveBatchTable(runningBatches, waitingBatches, listener.batchDuration).toNodeSeq} +
    +
    +
    } val completedBatchesContent = { -

    - Completed Batches (last {completedBatches.size} out of {listener.numTotalCompletedBatches}) -

    ++ - new CompletedBatchTable(completedBatches, listener.batchDuration).toNodeSeq +
    +
    + +

    + + Completed Batches (last {completedBatches.size} + out of {listener.numTotalCompletedBatches}) +

    +
    +
    + {new CompletedBatchTable(completedBatches, listener.batchDuration).toNodeSeq} +
    +
    +
    } activeBatchesContent ++ completedBatchesContent From 263531466f4a7e223c94caa8705e6e8394a12054 Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Tue, 13 Feb 2018 11:45:20 -0600 Subject: [PATCH 344/774] [SPARK-23392][TEST] Add some test cases for images feature ## What changes were proposed in this pull request? Add some test cases for images feature ## How was this patch tested? Add some test cases in ImageSchemaSuite Author: xubo245 <601450868@qq.com> Closes #20583 from xubo245/CARBONDATA23392_AddTestForImage. --- .../spark/ml/image/ImageSchemaSuite.scala | 62 ++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala index a8833c615865d..527b3f8955968 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -65,11 +65,71 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(count50 > 0 && count50 < countTotal) } + test("readImages test: recursive = false") { + val df = readImages(imagePath, null, false, 3, true, 1.0, 0) + assert(df.count() === 0) + } + + test("readImages test: read jpg image") { + val df = readImages(imagePath + "/kittens/DP153539.jpg", null, false, 3, true, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: read png image") { + val df = readImages(imagePath + "/multi-channel/BGRA.png", null, false, 3, true, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: read non image") { + val df = readImages(imagePath + "/kittens/not-image.txt", null, false, 3, true, 1.0, 0) + assert(df.schema("image").dataType == columnSchema, "data do not fit ImageSchema") + assert(df.count() === 0) + } + + test("readImages test: read non image and dropImageFailures is false") { + val df = readImages(imagePath + "/kittens/not-image.txt", null, false, 3, false, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: sampleRatio > 1") { + val e = intercept[IllegalArgumentException] { + readImages(imagePath, null, true, 3, true, 1.1, 0) + } + assert(e.getMessage.contains("sampleRatio")) + } + + test("readImages test: sampleRatio < 0") { + val e = intercept[IllegalArgumentException] { + readImages(imagePath, null, true, 3, true, -0.1, 0) + } + assert(e.getMessage.contains("sampleRatio")) + } + + test("readImages test: sampleRatio = 0") { + val df = readImages(imagePath, null, true, 3, true, 0.0, 0) + assert(df.count() === 0) + } + + test("readImages test: with sparkSession") { + val df = readImages(imagePath, sparkSession = spark, true, 3, true, 1.0, 0) + assert(df.count() === 8) + } + test("readImages partition test") { val df = readImages(imagePath, null, true, 3, true, 1.0, 0) assert(df.rdd.getNumPartitions === 3) } + test("readImages partition test: < 0") { + val df = readImages(imagePath, null, true, -3, true, 1.0, 0) + assert(df.rdd.getNumPartitions === spark.sparkContext.defaultParallelism) + } + + test("readImages partition test: = 0") { + val df = readImages(imagePath, null, true, 0, true, 1.0, 0) + assert(df.rdd.getNumPartitions === spark.sparkContext.defaultParallelism) + } + // Images with the different number of channels test("readImages pixel values test") { @@ -93,7 +153,7 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { // - default representation for 3-channel RGB images is BGR row-wise: // (B00, G00, R00, B10, G10, R10, ...) // - default representation for 4-channel RGB images is BGRA row-wise: - // (B00, G00, R00, A00, B10, G10, R10, A00, ...) + // (B00, G00, R00, A00, B10, G10, R10, A10, ...) private val firstBytes20 = Map( "grayscale.jpg" -> (("CV_8UC1", Array[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, -67, -73, -73, -62, From 05d051293fe46938e9cb012342fea6e8a3715cd4 Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Tue, 13 Feb 2018 09:49:52 -0800 Subject: [PATCH 345/774] [SPARK-23316][SQL] AnalysisException after max iteration reached for IN query ## What changes were proposed in this pull request? Added flag ignoreNullability to DataType.equalsStructurally. The previous semantic is for ignoreNullability=false. When ignoreNullability=true equalsStructurally ignores nullability of contained types (map key types, value types, array element types, structure field types). In.checkInputTypes calls equalsStructurally to check if the children types match. They should match regardless of nullability (which is just a hint), so it is now called with ignoreNullability=true. ## How was this patch tested? New test in SubquerySuite Author: Bogdan Raducanu Closes #20548 from bogdanrdc/SPARK-23316. --- .../sql/catalyst/expressions/predicates.scala | 3 ++- .../org/apache/spark/sql/types/DataType.scala | 18 ++++++++++++------ .../org/apache/spark/sql/SubquerySuite.scala | 5 +++++ 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b469f5cb7586a..a6d41ea7d00d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -157,7 +157,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") override def checkInputDataTypes(): TypeCheckResult = { - val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType)) + val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType, + ignoreNullability = true)) if (mismatchOpt.isDefined) { list match { case ListQuery(_, _, _, childOutputs) :: Nil => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index d6e0df12218ad..0bef11659fc9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -295,25 +295,31 @@ object DataType { } /** - * Returns true if the two data types share the same "shape", i.e. the types (including - * nullability) are the same, but the field names don't need to be the same. + * Returns true if the two data types share the same "shape", i.e. the types + * are the same, but the field names don't need to be the same. + * + * @param ignoreNullability whether to ignore nullability when comparing the types */ - def equalsStructurally(from: DataType, to: DataType): Boolean = { + def equalsStructurally( + from: DataType, + to: DataType, + ignoreNullability: Boolean = false): Boolean = { (from, to) match { case (left: ArrayType, right: ArrayType) => equalsStructurally(left.elementType, right.elementType) && - left.containsNull == right.containsNull + (ignoreNullability || left.containsNull == right.containsNull) case (left: MapType, right: MapType) => equalsStructurally(left.keyType, right.keyType) && equalsStructurally(left.valueType, right.valueType) && - left.valueContainsNull == right.valueContainsNull + (ignoreNullability || left.valueContainsNull == right.valueContainsNull) case (StructType(fromFields), StructType(toFields)) => fromFields.length == toFields.length && fromFields.zip(toFields) .forall { case (l, r) => - equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable + equalsStructurally(l.dataType, r.dataType) && + (ignoreNullability || l.nullable == r.nullable) } case (fromDataType, toDataType) => fromDataType == toDataType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 8673dc14f7597..31e8b0e8dede0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -950,4 +950,9 @@ class SubquerySuite extends QueryTest with SharedSQLContext { assert(join.duplicateResolved) assert(optimizedPlan.resolved) } + + test("SPARK-23316: AnalysisException after max iteration reached for IN query") { + // before the fix this would throw AnalysisException + spark.range(10).where("(id,id) in (select id, null from range(3))").count + } } From 4e0fb010ccdf13fe411f2a4796bbadc385b01520 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 13 Feb 2018 11:51:19 -0600 Subject: [PATCH 346/774] [SPARK-23217][ML] Add cosine distance measure to ClusteringEvaluator ## What changes were proposed in this pull request? The PR provided an implementation of ClusteringEvaluator using the cosine distance measure. This allows to evaluate clustering results created using the cosine distance, introduced in SPARK-22119. In the corresponding JIRA, there is a design document for the algorithm implemented here. ## How was this patch tested? Added UT which compares the result to the one provided by python sklearn. Author: Marco Gaido Closes #20396 from mgaido91/SPARK-23217. --- .../ml/evaluation/ClusteringEvaluator.scala | 334 ++++++++++++++---- .../evaluation/ClusteringEvaluatorSuite.scala | 32 +- 2 files changed, 300 insertions(+), 66 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index d6ec5223237bb..8d4ae562b3d2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -20,11 +20,12 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, + SchemaUtils} +import org.apache.spark.sql.{Column, DataFrame, Dataset} import org.apache.spark.sql.functions.{avg, col, udf} import org.apache.spark.sql.types.DoubleType @@ -32,15 +33,11 @@ import org.apache.spark.sql.types.DoubleType * :: Experimental :: * * Evaluator for clustering results. - * The metric computes the Silhouette measure - * using the squared Euclidean distance. - * - * The Silhouette is a measure for the validation - * of the consistency within clusters. It ranges - * between 1 and -1, where a value close to 1 - * means that the points in a cluster are close - * to the other points in the same cluster and - * far from the points of the other clusters. + * The metric computes the Silhouette measure using the specified distance measure. + * + * The Silhouette is a measure for the validation of the consistency within clusters. It ranges + * between 1 and -1, where a value close to 1 means that the points in a cluster are close to the + * other points in the same cluster and far from the points of the other clusters. */ @Experimental @Since("2.3.0") @@ -84,18 +81,40 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str @Since("2.3.0") def setMetricName(value: String): this.type = set(metricName, value) - setDefault(metricName -> "silhouette") + /** + * param for distance measure to be used in evaluation + * (supports `"squaredEuclidean"` (default), `"cosine"`) + * @group param + */ + @Since("2.4.0") + val distanceMeasure: Param[String] = { + val availableValues = Array("squaredEuclidean", "cosine") + val allowedParams = ParamValidators.inArray(availableValues) + new Param(this, "distanceMeasure", "distance measure in evaluation. Supported options: " + + availableValues.mkString("'", "', '", "'"), allowedParams) + } + + /** @group getParam */ + @Since("2.4.0") + def getDistanceMeasure: String = $(distanceMeasure) + + /** @group setParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + + setDefault(metricName -> "silhouette", distanceMeasure -> "squaredEuclidean") @Since("2.3.0") override def evaluate(dataset: Dataset[_]): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) SchemaUtils.checkNumericType(dataset.schema, $(predictionCol)) - $(metricName) match { - case "silhouette" => + ($(metricName), $(distanceMeasure)) match { + case ("silhouette", "squaredEuclidean") => SquaredEuclideanSilhouette.computeSilhouetteScore( - dataset, $(predictionCol), $(featuresCol) - ) + dataset, $(predictionCol), $(featuresCol)) + case ("silhouette", "cosine") => + CosineSilhouette.computeSilhouetteScore(dataset, $(predictionCol), $(featuresCol)) } } } @@ -111,6 +130,48 @@ object ClusteringEvaluator } +private[evaluation] abstract class Silhouette { + + /** + * It computes the Silhouette coefficient for a point. + */ + def pointSilhouetteCoefficient( + clusterIds: Set[Double], + pointClusterId: Double, + pointClusterNumOfPoints: Long, + averageDistanceToCluster: (Double) => Double): Double = { + // Here we compute the average dissimilarity of the current point to any cluster of which the + // point is not a member. + // The cluster with the lowest average dissimilarity - i.e. the nearest cluster to the current + // point - is said to be the "neighboring cluster". + val otherClusterIds = clusterIds.filter(_ != pointClusterId) + val neighboringClusterDissimilarity = otherClusterIds.map(averageDistanceToCluster).min + + // adjustment for excluding the node itself from the computation of the average dissimilarity + val currentClusterDissimilarity = if (pointClusterNumOfPoints == 1) { + 0.0 + } else { + averageDistanceToCluster(pointClusterId) * pointClusterNumOfPoints / + (pointClusterNumOfPoints - 1) + } + + if (currentClusterDissimilarity < neighboringClusterDissimilarity) { + 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity) + } else if (currentClusterDissimilarity > neighboringClusterDissimilarity) { + (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1 + } else { + 0.0 + } + } + + /** + * Compute the mean Silhouette values of all samples. + */ + def overallScore(df: DataFrame, scoreColumn: Column): Double = { + df.select(avg(scoreColumn)).collect()(0).getDouble(0) + } +} + /** * SquaredEuclideanSilhouette computes the average of the * Silhouette over all the data of the dataset, which is @@ -259,7 +320,7 @@ object ClusteringEvaluator * `N` is the number of points in the dataset and `W` is the number * of worker nodes. */ -private[evaluation] object SquaredEuclideanSilhouette { +private[evaluation] object SquaredEuclideanSilhouette extends Silhouette { private[this] var kryoRegistrationPerformed: Boolean = false @@ -336,18 +397,19 @@ private[evaluation] object SquaredEuclideanSilhouette { * It computes the Silhouette coefficient for a point. * * @param broadcastedClustersMap A map of the precomputed values for each cluster. - * @param features The [[org.apache.spark.ml.linalg.Vector]] representing the current point. + * @param point The [[org.apache.spark.ml.linalg.Vector]] representing the current point. * @param clusterId The id of the cluster the current point belongs to. * @param squaredNorm The `$\Xi_{X}$` (which is the squared norm) precomputed for the point. * @return The Silhouette for the point. */ def computeSilhouetteCoefficient( broadcastedClustersMap: Broadcast[Map[Double, ClusterStats]], - features: Vector, + point: Vector, clusterId: Double, squaredNorm: Double): Double = { - def compute(squaredNorm: Double, point: Vector, clusterStats: ClusterStats): Double = { + def compute(targetClusterId: Double): Double = { + val clusterStats = broadcastedClustersMap.value(targetClusterId) val pointDotClusterFeaturesSum = BLAS.dot(point, clusterStats.featureSum) squaredNorm + @@ -355,41 +417,14 @@ private[evaluation] object SquaredEuclideanSilhouette { 2 * pointDotClusterFeaturesSum / clusterStats.numOfPoints } - // Here we compute the average dissimilarity of the - // current point to any cluster of which the point - // is not a member. - // The cluster with the lowest average dissimilarity - // - i.e. the nearest cluster to the current point - - // is said to be the "neighboring cluster". - var neighboringClusterDissimilarity = Double.MaxValue - broadcastedClustersMap.value.keySet.foreach { - c => - if (c != clusterId) { - val dissimilarity = compute(squaredNorm, features, broadcastedClustersMap.value(c)) - if(dissimilarity < neighboringClusterDissimilarity) { - neighboringClusterDissimilarity = dissimilarity - } - } - } - val currentCluster = broadcastedClustersMap.value(clusterId) - // adjustment for excluding the node itself from - // the computation of the average dissimilarity - val currentClusterDissimilarity = if (currentCluster.numOfPoints == 1) { - 0 - } else { - compute(squaredNorm, features, currentCluster) * currentCluster.numOfPoints / - (currentCluster.numOfPoints - 1) - } - - (currentClusterDissimilarity compare neighboringClusterDissimilarity).signum match { - case -1 => 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity) - case 1 => (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1 - case 0 => 0.0 - } + pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet, + clusterId, + broadcastedClustersMap.value(clusterId).numOfPoints, + compute) } /** - * Compute the mean Silhouette values of all samples. + * Compute the Silhouette score of the dataset using squared Euclidean distance measure. * * @param dataset The input dataset (previously clustered) on which compute the Silhouette. * @param predictionCol The name of the column which contains the predicted cluster id @@ -412,7 +447,7 @@ private[evaluation] object SquaredEuclideanSilhouette { val clustersStatsMap = SquaredEuclideanSilhouette .computeClusterStats(dfWithSquaredNorm, predictionCol, featuresCol) - // Silhouette is reasonable only when the number of clusters is grater then 1 + // Silhouette is reasonable only when the number of clusters is greater then 1 assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap) @@ -421,13 +456,190 @@ private[evaluation] object SquaredEuclideanSilhouette { computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double) } - val silhouetteScore = dfWithSquaredNorm - .select(avg( - computeSilhouetteCoefficientUDF( - col(featuresCol), col(predictionCol).cast(DoubleType), col("squaredNorm")) - )) - .collect()(0) - .getDouble(0) + val silhouetteScore = overallScore(dfWithSquaredNorm, + computeSilhouetteCoefficientUDF(col(featuresCol), col(predictionCol).cast(DoubleType), + col("squaredNorm"))) + + bClustersStatsMap.destroy() + + silhouetteScore + } +} + + +/** + * The algorithm which is implemented in this object, instead, is an efficient and parallel + * implementation of the Silhouette using the cosine distance measure. The cosine distance + * measure is defined as `1 - s` where `s` is the cosine similarity between two points. + * + * The total distance of the point `X` to the points `$C_{i}$` belonging to the cluster `$\Gamma$` + * is: + * + *
    + * $$ + * \sum\limits_{i=1}^N d(X, C_{i} ) = + * \sum\limits_{i=1}^N \Big( 1 - \frac{\sum\limits_{j=1}^D x_{j}c_{ij} }{ \|X\|\|C_{i}\|} \Big) + * = \sum\limits_{i=1}^N 1 - \sum\limits_{i=1}^N \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} + * \frac{c_{ij}}{\|C_{i}\|} + * = N - \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} \Big( \sum\limits_{i=1}^N + * \frac{c_{ij}}{\|C_{i}\|} \Big) + * $$ + *
    + * + * where `$x_{j}$` is the `j`-th dimension of the point `X` and `$c_{ij}$` is the `j`-th dimension + * of the `i`-th point in cluster `$\Gamma$`. + * + * Then, we can define the vector: + * + *
    + * $$ + * \xi_{X} : \xi_{X i} = \frac{x_{i}}{\|X\|}, i = 1, ..., D + * $$ + *
    + * + * which can be precomputed for each point and the vector + * + *
    + * $$ + * \Omega_{\Gamma} : \Omega_{\Gamma i} = \sum\limits_{j=1}^N \xi_{C_{j}i}, i = 1, ..., D + * $$ + *
    + * + * which can be precomputed too for each cluster `$\Gamma$` by its points `$C_{i}$`. + * + * With these definitions, the numerator becomes: + * + *
    + * $$ + * N - \sum\limits_{j=1}^D \xi_{X j} \Omega_{\Gamma j} + * $$ + *
    + * + * Thus the average distance of a point `X` to the points of the cluster `$\Gamma$` is: + * + *
    + * $$ + * 1 - \frac{\sum\limits_{j=1}^D \xi_{X j} \Omega_{\Gamma j}}{N} + * $$ + *
    + * + * In the implementation, the precomputed values for the clusters are distributed among the worker + * nodes via broadcasted variables, because we can assume that the clusters are limited in number. + * + * The main strengths of this algorithm are the low computational complexity and the intrinsic + * parallelism. The precomputed information for each point and for each cluster can be computed + * with a computational complexity which is `O(N/W)`, where `N` is the number of points in the + * dataset and `W` is the number of worker nodes. After that, every point can be analyzed + * independently from the others. + * + * For every point we need to compute the average distance to all the clusters. Since the formula + * above requires `O(D)` operations, this phase has a computational complexity which is + * `O(C*D*N/W)` where `C` is the number of clusters (which we assume quite low), `D` is the number + * of dimensions, `N` is the number of points in the dataset and `W` is the number of worker + * nodes. + */ +private[evaluation] object CosineSilhouette extends Silhouette { + + private[this] val normalizedFeaturesColName = "normalizedFeatures" + + /** + * The method takes the input dataset and computes the aggregated values + * about a cluster which are needed by the algorithm. + * + * @param df The DataFrame which contains the input data + * @param predictionCol The name of the column which contains the predicted cluster id + * for the point. + * @return A [[scala.collection.immutable.Map]] which associates each cluster id to a + * its statistics (ie. the precomputed values `N` and `$\Omega_{\Gamma}$`). + */ + def computeClusterStats(df: DataFrame, predictionCol: String): Map[Double, (Vector, Long)] = { + val numFeatures = df.select(col(normalizedFeaturesColName)).first().getAs[Vector](0).size + val clustersStatsRDD = df.select( + col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName)) + .rdd + .map { row => (row.getDouble(0), row.getAs[Vector](1)) } + .aggregateByKey[(DenseVector, Long)]((Vectors.zeros(numFeatures).toDense, 0L))( + seqOp = { + case ((normalizedFeaturesSum: DenseVector, numOfPoints: Long), (normalizedFeatures)) => + BLAS.axpy(1.0, normalizedFeatures, normalizedFeaturesSum) + (normalizedFeaturesSum, numOfPoints + 1) + }, + combOp = { + case ((normalizedFeaturesSum1, numOfPoints1), (normalizedFeaturesSum2, numOfPoints2)) => + BLAS.axpy(1.0, normalizedFeaturesSum2, normalizedFeaturesSum1) + (normalizedFeaturesSum1, numOfPoints1 + numOfPoints2) + } + ) + + clustersStatsRDD + .collectAsMap() + .toMap + } + + /** + * It computes the Silhouette coefficient for a point. + * + * @param broadcastedClustersMap A map of the precomputed values for each cluster. + * @param normalizedFeatures The [[org.apache.spark.ml.linalg.Vector]] representing the + * normalized features of the current point. + * @param clusterId The id of the cluster the current point belongs to. + */ + def computeSilhouetteCoefficient( + broadcastedClustersMap: Broadcast[Map[Double, (Vector, Long)]], + normalizedFeatures: Vector, + clusterId: Double): Double = { + + def compute(targetClusterId: Double): Double = { + val (normalizedFeatureSum, numOfPoints) = broadcastedClustersMap.value(targetClusterId) + 1 - BLAS.dot(normalizedFeatures, normalizedFeatureSum) / numOfPoints + } + + pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet, + clusterId, + broadcastedClustersMap.value(clusterId)._2, + compute) + } + + /** + * Compute the Silhouette score of the dataset using the cosine distance measure. + * + * @param dataset The input dataset (previously clustered) on which compute the Silhouette. + * @param predictionCol The name of the column which contains the predicted cluster id + * for the point. + * @param featuresCol The name of the column which contains the feature vector of the point. + * @return The average of the Silhouette values of the clustered data. + */ + def computeSilhouetteScore( + dataset: Dataset[_], + predictionCol: String, + featuresCol: String): Double = { + val normalizeFeatureUDF = udf { + features: Vector => { + val norm = Vectors.norm(features, 2.0) + features match { + case d: DenseVector => Vectors.dense(d.values.map(_ / norm)) + case s: SparseVector => Vectors.sparse(s.size, s.indices, s.values.map(_ / norm)) + } + } + } + val dfWithNormalizedFeatures = dataset.withColumn(normalizedFeaturesColName, + normalizeFeatureUDF(col(featuresCol))) + + // compute aggregate values for clusters needed by the algorithm + val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, predictionCol) + + // Silhouette is reasonable only when the number of clusters is greater then 1 + assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") + + val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap) + + val computeSilhouetteCoefficientUDF = udf { + computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double) + } + + val silhouetteScore = overallScore(dfWithNormalizedFeatures, + computeSilhouetteCoefficientUDF(col(normalizedFeaturesColName), + col(predictionCol).cast(DoubleType))) bClustersStatsMap.destroy() diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 677ce49a903ab..3bf34770f5687 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -66,16 +66,38 @@ class ClusteringEvaluatorSuite assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5) } - test("number of clusters must be greater than one") { - val singleClusterDataset = irisDataset.where($"label" === 0.0) + /* + Use the following python code to load the data and evaluate it using scikit-learn package. + + from sklearn import datasets + from sklearn.metrics import silhouette_score + iris = datasets.load_iris() + round(silhouette_score(iris.data, iris.target, metric='cosine'), 10) + + 0.7222369298 + */ + test("cosine Silhouette") { val evaluator = new ClusteringEvaluator() .setFeaturesCol("features") .setPredictionCol("label") + .setDistanceMeasure("cosine") + + assert(evaluator.evaluate(irisDataset) ~== 0.7222369298 relTol 1e-5) + } + + test("number of clusters must be greater than one") { + val singleClusterDataset = irisDataset.where($"label" === 0.0) + Seq("squaredEuclidean", "cosine").foreach { distanceMeasure => + val evaluator = new ClusteringEvaluator() + .setFeaturesCol("features") + .setPredictionCol("label") + .setDistanceMeasure(distanceMeasure) - val e = intercept[AssertionError]{ - evaluator.evaluate(singleClusterDataset) + val e = intercept[AssertionError] { + evaluator.evaluate(singleClusterDataset) + } + assert(e.getMessage.contains("Number of clusters must be greater than one")) } - assert(e.getMessage.contains("Number of clusters must be greater than one")) } } From d58fe28836639e68e262812d911f167cb071007b Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 13 Feb 2018 11:18:45 -0800 Subject: [PATCH 347/774] [SPARK-23154][ML][DOC] Document backwards compatibility guarantees for ML persistence ## What changes were proposed in this pull request? Added documentation about what MLlib guarantees in terms of loading ML models and Pipelines from old Spark versions. Discussed & confirmed on linked JIRA. Author: Joseph K. Bradley Closes #20592 from jkbradley/SPARK-23154-backwards-compat-doc. --- docs/ml-pipeline.md | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index aa92c0a37c0f4..e22e9003c30f6 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -188,9 +188,36 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. -## Saving and Loading Pipelines +## ML persistence: Saving and Loading Pipelines -Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. Most basic transformers are supported as well as some of the more basic ML models. Please refer to the algorithm's API documentation to see if saving and loading is supported. +Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. +As of Spark 2.3, the DataFrame-based API in `spark.ml` and `pyspark.ml` has complete coverage. + +ML persistence works across Scala, Java and Python. However, R currently uses a modified format, +so models saved in R can only be loaded back in R; this should be fixed in the future and is +tracked in [SPARK-15572](https://issues.apache.org/jira/browse/SPARK-15572). + +### Backwards compatibility for ML persistence + +In general, MLlib maintains backwards compatibility for ML persistence. I.e., if you save an ML +model or Pipeline in one version of Spark, then you should be able to load it back and use it in a +future version of Spark. However, there are rare exceptions, described below. + +Model persistence: Is a model or Pipeline saved using Apache Spark ML persistence in Spark +version X loadable by Spark version Y? + +* Major versions: No guarantees, but best-effort. +* Minor and patch versions: Yes; these are backwards compatible. +* Note about the format: There are no guarantees for a stable persistence format, but model loading itself is designed to be backwards compatible. + +Model behavior: Does a model or Pipeline in Spark version X behave identically in Spark version Y? + +* Major versions: No guarantees, but best-effort. +* Minor and patch versions: Identical behavior, except for bug fixes. + +For both model persistence and model behavior, any breaking changes across a minor version or patch +version are reported in the Spark version release notes. If a breakage is not reported in release +notes, then it should be treated as a bug to be fixed. # Code examples From 2ee76c22b6e48e643694c9475e5f0d37124215e7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 13 Feb 2018 11:56:49 -0800 Subject: [PATCH 348/774] [SPARK-23400][SQL] Add a constructors for ScalaUDF ## What changes were proposed in this pull request? In this upcoming 2.3 release, we changed the interface of `ScalaUDF`. Unfortunately, some Spark packages (e.g., spark-deep-learning) are using our internal class `ScalaUDF`. In the release 2.3, we added new parameters into this class. The users hit the binary compatibility issues and got the exception: ``` > java.lang.NoSuchMethodError: org.apache.spark.sql.catalyst.expressions.ScalaUDF.<init>(Ljava/lang/Object;Lorg/apache/spark/sql/types/DataType;Lscala/collection/Seq;Lscala/collection/Seq;Lscala/Option;)V ``` This PR is to improve the backward compatibility. However, we definitely should not encourage the external packages to use our internal classes. This might make us hard to maintain/develop the codes in Spark. ## How was this patch tested? N/A Author: gatorsmile Closes #20591 from gatorsmile/scalaUDF. --- .../spark/sql/catalyst/expressions/ScalaUDF.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 388ef42883ad3..989c02305620a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -49,6 +49,17 @@ case class ScalaUDF( udfDeterministic: Boolean = true) extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { + // The constructor for SPARK 2.1 and 2.2 + def this( + function: AnyRef, + dataType: DataType, + children: Seq[Expression], + inputTypes: Seq[DataType], + udfName: Option[String]) = { + this( + function, dataType, children, inputTypes, udfName, nullable = true, udfDeterministic = true) + } + override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = From a5a4b83501526e02d0e3cd0056e4a5c0e1c8284f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 13 Feb 2018 16:46:43 -0600 Subject: [PATCH 349/774] [SPARK-23235][CORE] Add executor Threaddump to api MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Extending api with the executor thread dump data. For this new REST URL is introduced: - GET http://localhost:4040/api/v1/applications/{applicationId}/executors/{executorId}/threads
    Example response: ``` javascript [ { "threadId" : 52, "threadName" : "context-cleaner-periodic-gc", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1385411893})", "holdingLocks" : [ ] }, { "threadId" : 48, "threadName" : "dag-scheduler-event-loop", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingDeque.takeFirst(LinkedBlockingDeque.java:492)\njava.util.concurrent.LinkedBlockingDeque.take(LinkedBlockingDeque.java:680)\norg.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:46)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1138053349})", "holdingLocks" : [ ] }, { "threadId" : 17, "threadName" : "dispatcher-event-loop-0", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker832743930})" ] }, { "threadId" : 18, "threadName" : "dispatcher-event-loop-1", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker834153999})" ] }, { "threadId" : 19, "threadName" : "dispatcher-event-loop-2", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker664836465})" ] }, { "threadId" : 20, "threadName" : "dispatcher-event-loop-3", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1645557354})" ] }, { "threadId" : 21, "threadName" : "dispatcher-event-loop-4", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1188871851})" ] }, { "threadId" : 22, "threadName" : "dispatcher-event-loop-5", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker920926249})" ] }, { "threadId" : 23, "threadName" : "dispatcher-event-loop-6", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker355222677})" ] }, { "threadId" : 24, "threadName" : "dispatcher-event-loop-7", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1589745212})" ] }, { "threadId" : 49, "threadName" : "driver-heartbeater", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1602885835})", "holdingLocks" : [ ] }, { "threadId" : 53, "threadName" : "element-tracking-store-worker", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1439439099})", "holdingLocks" : [ ] }, { "threadId" : 3, "threadName" : "Finalizer", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:143)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:164)\njava.lang.ref.Finalizer$FinalizerThread.run(Finalizer.java:209)", "blockedByLock" : "Lock(java.lang.ref.ReferenceQueue$Lock1213098236})", "holdingLocks" : [ ] }, { "threadId" : 15, "threadName" : "ForkJoinPool-1-worker-13", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\nscala.concurrent.forkjoin.ForkJoinPool.scan(ForkJoinPool.java:2075)\nscala.concurrent.forkjoin.ForkJoinPool.runWorker(ForkJoinPool.java:1979)\nscala.concurrent.forkjoin.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:107)", "blockedByLock" : "Lock(scala.concurrent.forkjoin.ForkJoinPool380286413})", "holdingLocks" : [ ] }, { "threadId" : 45, "threadName" : "heartbeat-receiver-event-loop-thread", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject715135812})", "holdingLocks" : [ ] }, { "threadId" : 1, "threadName" : "main", "threadState" : "RUNNABLE", "stackTrace" : "java.io.FileInputStream.read0(Native Method)\njava.io.FileInputStream.read(FileInputStream.java:207)\nscala.tools.jline_embedded.internal.NonBlockingInputStream.read(NonBlockingInputStream.java:169) => holding Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})\nscala.tools.jline_embedded.internal.NonBlockingInputStream.read(NonBlockingInputStream.java:137)\nscala.tools.jline_embedded.internal.NonBlockingInputStream.read(NonBlockingInputStream.java:246)\nscala.tools.jline_embedded.internal.InputStreamReader.read(InputStreamReader.java:261) => holding Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})\nscala.tools.jline_embedded.internal.InputStreamReader.read(InputStreamReader.java:198) => holding Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})\nscala.tools.jline_embedded.console.ConsoleReader.readCharacter(ConsoleReader.java:2145)\nscala.tools.jline_embedded.console.ConsoleReader.readLine(ConsoleReader.java:2349)\nscala.tools.jline_embedded.console.ConsoleReader.readLine(ConsoleReader.java:2269)\nscala.tools.nsc.interpreter.jline_embedded.InteractiveReader.readOneLine(JLineReader.scala:57)\nscala.tools.nsc.interpreter.InteractiveReader$$anonfun$readLine$2.apply(InteractiveReader.scala:37)\nscala.tools.nsc.interpreter.InteractiveReader$$anonfun$readLine$2.apply(InteractiveReader.scala:37)\nscala.tools.nsc.interpreter.InteractiveReader$.restartSysCalls(InteractiveReader.scala:44)\nscala.tools.nsc.interpreter.InteractiveReader$class.readLine(InteractiveReader.scala:37)\nscala.tools.nsc.interpreter.jline_embedded.InteractiveReader.readLine(JLineReader.scala:28)\nscala.tools.nsc.interpreter.ILoop.readOneLine(ILoop.scala:404)\nscala.tools.nsc.interpreter.ILoop.loop(ILoop.scala:413)\nscala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply$mcZ$sp(ILoop.scala:923)\nscala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply(ILoop.scala:909)\nscala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply(ILoop.scala:909)\nscala.reflect.internal.util.ScalaClassLoader$.savingContextLoader(ScalaClassLoader.scala:97)\nscala.tools.nsc.interpreter.ILoop.process(ILoop.scala:909)\norg.apache.spark.repl.Main$.doMain(Main.scala:76)\norg.apache.spark.repl.Main$.main(Main.scala:56)\norg.apache.spark.repl.Main.main(Main.scala)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\norg.apache.spark.deploy.JavaMainApplication.start(SparkApplication.scala:52)\norg.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:879)\norg.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:197)\norg.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:227)\norg.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:136)\norg.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})" ] }, { "threadId" : 26, "threadName" : "map-output-dispatcher-0", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1791280119})" ] }, { "threadId" : 27, "threadName" : "map-output-dispatcher-1", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1947378744})" ] }, { "threadId" : 28, "threadName" : "map-output-dispatcher-2", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker507507251})" ] }, { "threadId" : 29, "threadName" : "map-output-dispatcher-3", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1016408627})" ] }, { "threadId" : 30, "threadName" : "map-output-dispatcher-4", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1879219501})" ] }, { "threadId" : 31, "threadName" : "map-output-dispatcher-5", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker290509937})" ] }, { "threadId" : 32, "threadName" : "map-output-dispatcher-6", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1889468930})" ] }, { "threadId" : 33, "threadName" : "map-output-dispatcher-7", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1699637904})" ] }, { "threadId" : 47, "threadName" : "netty-rpc-env-timeout", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject977194847})", "holdingLocks" : [ ] }, { "threadId" : 14, "threadName" : "NonBlockingInputStreamThread", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\nscala.tools.jline_embedded.internal.NonBlockingInputStream.run(NonBlockingInputStream.java:278)\njava.lang.Thread.run(Thread.java:748)", "blockedByThreadId" : 1, "blockedByLock" : "Lock(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})", "holdingLocks" : [ ] }, { "threadId" : 2, "threadName" : "Reference Handler", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.Object.wait(Object.java:502)\njava.lang.ref.Reference.tryHandlePending(Reference.java:191)\njava.lang.ref.Reference$ReferenceHandler.run(Reference.java:153)", "blockedByLock" : "Lock(java.lang.ref.Reference$Lock1359433302})", "holdingLocks" : [ ] }, { "threadId" : 35, "threadName" : "refresh progress", "threadState" : "TIMED_WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.util.TimerThread.mainLoop(Timer.java:552)\njava.util.TimerThread.run(Timer.java:505)", "blockedByLock" : "Lock(java.util.TaskQueue44276328})", "holdingLocks" : [ ] }, { "threadId" : 34, "threadName" : "RemoteBlock-temp-file-clean-thread", "threadState" : "TIMED_WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:143)\norg.apache.spark.storage.BlockManager$RemoteBlockTempFileManager.org$apache$spark$storage$BlockManager$RemoteBlockTempFileManager$$keepCleaning(BlockManager.scala:1630)\norg.apache.spark.storage.BlockManager$RemoteBlockTempFileManager$$anon$1.run(BlockManager.scala:1608)", "blockedByLock" : "Lock(java.lang.ref.ReferenceQueue$Lock391748181})", "holdingLocks" : [ ] }, { "threadId" : 25, "threadName" : "rpc-server-3-1", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl2057702496})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nio.netty.channel.nio.SelectedSelectionKeySetSelector.select(SelectedSelectionKeySetSelector.java:62)\nio.netty.channel.nio.NioEventLoop.select(NioEventLoop.java:753)\nio.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:409)\nio.netty.util.concurrent.SingleThreadEventExecutor$5.run(SingleThreadEventExecutor.java:858)\nio.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:138)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(io.netty.channel.nio.SelectedSelectionKeySet1066929256})", "Monitor(java.util.Collections$UnmodifiableSet561426729})", "Monitor(sun.nio.ch.KQueueSelectorImpl2057702496})" ] }, { "threadId" : 50, "threadName" : "shuffle-server-5-1", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl1401522546})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nio.netty.channel.nio.SelectedSelectionKeySetSelector.select(SelectedSelectionKeySetSelector.java:62)\nio.netty.channel.nio.NioEventLoop.select(NioEventLoop.java:753)\nio.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:409)\nio.netty.util.concurrent.SingleThreadEventExecutor$5.run(SingleThreadEventExecutor.java:858)\nio.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:138)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(io.netty.channel.nio.SelectedSelectionKeySet385972319})", "Monitor(java.util.Collections$UnmodifiableSet477937109})", "Monitor(sun.nio.ch.KQueueSelectorImpl1401522546})" ] }, { "threadId" : 4, "threadName" : "Signal Dispatcher", "threadState" : "RUNNABLE", "stackTrace" : "", "blockedByLock" : "", "holdingLocks" : [ ] }, { "threadId" : 51, "threadName" : "Spark Context Cleaner", "threadState" : "TIMED_WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:143)\norg.apache.spark.ContextCleaner$$anonfun$org$apache$spark$ContextCleaner$$keepCleaning$1.apply$mcV$sp(ContextCleaner.scala:181)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.ContextCleaner.org$apache$spark$ContextCleaner$$keepCleaning(ContextCleaner.scala:178)\norg.apache.spark.ContextCleaner$$anon$1.run(ContextCleaner.scala:73)", "blockedByLock" : "Lock(java.lang.ref.ReferenceQueue$Lock1739420764})", "holdingLocks" : [ ] }, { "threadId" : 16, "threadName" : "spark-listener-group-appStatus", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.scheduler.AsyncEventQueue$$anonfun$org$apache$spark$scheduler$AsyncEventQueue$$dispatch$1.apply(AsyncEventQueue.scala:94)\nscala.util.DynamicVariable.withValue(DynamicVariable.scala:58)\norg.apache.spark.scheduler.AsyncEventQueue.org$apache$spark$scheduler$AsyncEventQueue$$dispatch(AsyncEventQueue.scala:83)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1$$anonfun$run$1.apply$mcV$sp(AsyncEventQueue.scala:79)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1.run(AsyncEventQueue.scala:78)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1287190987})", "holdingLocks" : [ ] }, { "threadId" : 44, "threadName" : "spark-listener-group-executorManagement", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.scheduler.AsyncEventQueue$$anonfun$org$apache$spark$scheduler$AsyncEventQueue$$dispatch$1.apply(AsyncEventQueue.scala:94)\nscala.util.DynamicVariable.withValue(DynamicVariable.scala:58)\norg.apache.spark.scheduler.AsyncEventQueue.org$apache$spark$scheduler$AsyncEventQueue$$dispatch(AsyncEventQueue.scala:83)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1$$anonfun$run$1.apply$mcV$sp(AsyncEventQueue.scala:79)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1.run(AsyncEventQueue.scala:78)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject943262890})", "holdingLocks" : [ ] }, { "threadId" : 54, "threadName" : "spark-listener-group-shared", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.scheduler.AsyncEventQueue$$anonfun$org$apache$spark$scheduler$AsyncEventQueue$$dispatch$1.apply(AsyncEventQueue.scala:94)\nscala.util.DynamicVariable.withValue(DynamicVariable.scala:58)\norg.apache.spark.scheduler.AsyncEventQueue.org$apache$spark$scheduler$AsyncEventQueue$$dispatch(AsyncEventQueue.scala:83)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1$$anonfun$run$1.apply$mcV$sp(AsyncEventQueue.scala:79)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1.run(AsyncEventQueue.scala:78)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject334604425})", "holdingLocks" : [ ] }, { "threadId" : 37, "threadName" : "SparkUI-37", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\norg.spark_project.jetty.util.BlockingArrayQueue.poll(BlockingArrayQueue.java:392)\norg.spark_project.jetty.util.thread.QueuedThreadPool.idleJobPoll(QueuedThreadPool.java:563)\norg.spark_project.jetty.util.thread.QueuedThreadPool.access$800(QueuedThreadPool.java:48)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:626)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1503479572})", "holdingLocks" : [ ] }, { "threadId" : 38, "threadName" : "SparkUI-38", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl841741934})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3873523986})", "Monitor(java.util.Collections$UnmodifiableSet1769333189})", "Monitor(sun.nio.ch.KQueueSelectorImpl841741934})" ] }, { "threadId" : 40, "threadName" : "SparkUI-40-acceptor-034929380-Spark3a557b62{HTTP/1.1,[http/1.1]}{0.0.0.0:4040}", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.ServerSocketChannelImpl.accept0(Native Method)\nsun.nio.ch.ServerSocketChannelImpl.accept(ServerSocketChannelImpl.java:422)\nsun.nio.ch.ServerSocketChannelImpl.accept(ServerSocketChannelImpl.java:250) => holding Monitor(java.lang.Object1134240909})\norg.spark_project.jetty.server.ServerConnector.accept(ServerConnector.java:371)\norg.spark_project.jetty.server.AbstractConnector$Acceptor.run(AbstractConnector.java:601)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(java.lang.Object1134240909})" ] }, { "threadId" : 43, "threadName" : "SparkUI-43", "threadState" : "RUNNABLE", "stackTrace" : "sun.management.ThreadImpl.dumpThreads0(Native Method)\nsun.management.ThreadImpl.dumpAllThreads(ThreadImpl.java:454)\norg.apache.spark.util.Utils$.getThreadDump(Utils.scala:2170)\norg.apache.spark.SparkContext.getExecutorThreadDump(SparkContext.scala:596)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1$$anonfun$apply$1.apply(OneApplicationResource.scala:66)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1$$anonfun$apply$1.apply(OneApplicationResource.scala:65)\nscala.Option.flatMap(Option.scala:171)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1.apply(OneApplicationResource.scala:65)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1.apply(OneApplicationResource.scala:58)\norg.apache.spark.status.api.v1.BaseAppResource$$anonfun$withUI$1.apply(ApiRootResource.scala:139)\norg.apache.spark.status.api.v1.BaseAppResource$$anonfun$withUI$1.apply(ApiRootResource.scala:134)\norg.apache.spark.ui.SparkUI.withSparkUI(SparkUI.scala:106)\norg.apache.spark.status.api.v1.BaseAppResource$class.withUI(ApiRootResource.scala:134)\norg.apache.spark.status.api.v1.AbstractApplicationResource.withUI(OneApplicationResource.scala:32)\norg.apache.spark.status.api.v1.AbstractApplicationResource.threadDump(OneApplicationResource.scala:58)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\norg.glassfish.jersey.server.model.internal.ResourceMethodInvocationHandlerFactory$1.invoke(ResourceMethodInvocationHandlerFactory.java:81)\norg.glassfish.jersey.server.model.internal.AbstractJavaResourceMethodDispatcher$1.run(AbstractJavaResourceMethodDispatcher.java:144)\norg.glassfish.jersey.server.model.internal.AbstractJavaResourceMethodDispatcher.invoke(AbstractJavaResourceMethodDispatcher.java:161)\norg.glassfish.jersey.server.model.internal.JavaResourceMethodDispatcherProvider$TypeOutInvoker.doDispatch(JavaResourceMethodDispatcherProvider.java:205)\norg.glassfish.jersey.server.model.internal.AbstractJavaResourceMethodDispatcher.dispatch(AbstractJavaResourceMethodDispatcher.java:99)\norg.glassfish.jersey.server.model.ResourceMethodInvoker.invoke(ResourceMethodInvoker.java:389)\norg.glassfish.jersey.server.model.ResourceMethodInvoker.apply(ResourceMethodInvoker.java:347)\norg.glassfish.jersey.server.model.ResourceMethodInvoker.apply(ResourceMethodInvoker.java:102)\norg.glassfish.jersey.server.ServerRuntime$2.run(ServerRuntime.java:326)\norg.glassfish.jersey.internal.Errors$1.call(Errors.java:271)\norg.glassfish.jersey.internal.Errors$1.call(Errors.java:267)\norg.glassfish.jersey.internal.Errors.process(Errors.java:315)\norg.glassfish.jersey.internal.Errors.process(Errors.java:297)\norg.glassfish.jersey.internal.Errors.process(Errors.java:267)\norg.glassfish.jersey.process.internal.RequestScope.runInScope(RequestScope.java:317)\norg.glassfish.jersey.server.ServerRuntime.process(ServerRuntime.java:305)\norg.glassfish.jersey.server.ApplicationHandler.handle(ApplicationHandler.java:1154)\norg.glassfish.jersey.servlet.WebComponent.serviceImpl(WebComponent.java:473)\norg.glassfish.jersey.servlet.WebComponent.service(WebComponent.java:427)\norg.glassfish.jersey.servlet.ServletContainer.service(ServletContainer.java:388)\norg.glassfish.jersey.servlet.ServletContainer.service(ServletContainer.java:341)\norg.glassfish.jersey.servlet.ServletContainer.service(ServletContainer.java:228)\norg.spark_project.jetty.servlet.ServletHolder.handle(ServletHolder.java:848)\norg.spark_project.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:584)\norg.spark_project.jetty.server.handler.ContextHandler.doHandle(ContextHandler.java:1180)\norg.spark_project.jetty.servlet.ServletHandler.doScope(ServletHandler.java:512)\norg.spark_project.jetty.server.handler.ContextHandler.doScope(ContextHandler.java:1112)\norg.spark_project.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:141)\norg.spark_project.jetty.server.handler.gzip.GzipHandler.handle(GzipHandler.java:493)\norg.spark_project.jetty.server.handler.ContextHandlerCollection.handle(ContextHandlerCollection.java:213)\norg.spark_project.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:134)\norg.spark_project.jetty.server.Server.handle(Server.java:534)\norg.spark_project.jetty.server.HttpChannel.handle(HttpChannel.java:320)\norg.spark_project.jetty.server.HttpConnection.onFillable(HttpConnection.java:251)\norg.spark_project.jetty.io.AbstractConnection$ReadCallback.succeeded(AbstractConnection.java:283)\norg.spark_project.jetty.io.FillInterest.fillable(FillInterest.java:108)\norg.spark_project.jetty.io.SelectChannelEndPoint$2.run(SelectChannelEndPoint.java:93)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:303)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ ] }, { "threadId" : 67, "threadName" : "SparkUI-67", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl1837806480})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3881415814})", "Monitor(java.util.Collections$UnmodifiableSet62050480})", "Monitor(sun.nio.ch.KQueueSelectorImpl1837806480})" ] }, { "threadId" : 68, "threadName" : "SparkUI-68", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl223607814})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3543145185})", "Monitor(java.util.Collections$UnmodifiableSet897441546})", "Monitor(sun.nio.ch.KQueueSelectorImpl223607814})" ] }, { "threadId" : 71, "threadName" : "SparkUI-71", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\norg.spark_project.jetty.util.BlockingArrayQueue.poll(BlockingArrayQueue.java:392)\norg.spark_project.jetty.util.thread.QueuedThreadPool.idleJobPoll(QueuedThreadPool.java:563)\norg.spark_project.jetty.util.thread.QueuedThreadPool.access$800(QueuedThreadPool.java:48)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:626)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1503479572})", "holdingLocks" : [ ] }, { "threadId" : 77, "threadName" : "SparkUI-77", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\norg.spark_project.jetty.util.BlockingArrayQueue.poll(BlockingArrayQueue.java:392)\norg.spark_project.jetty.util.thread.QueuedThreadPool.idleJobPoll(QueuedThreadPool.java:563)\norg.spark_project.jetty.util.thread.QueuedThreadPool.access$800(QueuedThreadPool.java:48)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:626)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1503479572})", "holdingLocks" : [ ] }, { "threadId" : 78, "threadName" : "SparkUI-78", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl403077801})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3261312406})", "Monitor(java.util.Collections$UnmodifiableSet852901260})", "Monitor(sun.nio.ch.KQueueSelectorImpl403077801})" ] }, { "threadId" : 72, "threadName" : "SparkUI-JettyScheduler", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1587346642})", "holdingLocks" : [ ] }, { "threadId" : 63, "threadName" : "task-result-getter-0", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 64, "threadName" : "task-result-getter-1", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 65, "threadName" : "task-result-getter-2", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 66, "threadName" : "task-result-getter-3", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 46, "threadName" : "Timer-0", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.Object.wait(Object.java:502)\njava.util.TimerThread.mainLoop(Timer.java:526)\njava.util.TimerThread.run(Timer.java:505)", "blockedByLock" : "Lock(java.util.TaskQueue635634547})", "holdingLocks" : [ ] } ] ```
    ## How was this patch tested? It was tested manually. Old executor page with thread dumps: screen shot 2018-02-01 at 14 31 19 New api: screen shot 2018-02-01 at 14 31 56 Testing error cases. Initial state: ![screen shot 2018-02-06 at 13 05 05](https://user-images.githubusercontent.com/2017933/35858990-ad2982be-0b3e-11e8-879b-656112065c7f.png) Dead executor: ```bash $ curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/1/threads Executor is not active. 400 ``` Never existed (but well formatted: number) executor ID: ```bash $ curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/42/threads Executor does not exist. 404 ``` Not available stacktrace (dead executor but UI has not registered as dead yet): ```bash $ kill -9 ; curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/2/threads No thread dump is available. 404 ``` Invalid executor ID format: ```bash $ curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/something6/threads Invalid executorId: neither 'driver' nor number. 400 ``` Author: “attilapiros” Closes #20474 from attilapiros/SPARK-23235. --- .../scala/org/apache/spark/SparkContext.scala | 1 + .../spark/status/api/v1/ApiRootResource.scala | 8 +++++ .../api/v1/OneApplicationResource.scala | 29 +++++++++++++++-- .../org/apache/spark/status/api/v1/api.scala | 9 ++++++ .../ui/exec/ExecutorThreadDumpPage.scala | 13 +------- .../apache/spark/util/ThreadStackTrace.scala | 31 ------------------- .../scala/org/apache/spark/util/Utils.scala | 18 ++++++++++- docs/monitoring.md | 7 +++++ 8 files changed, 69 insertions(+), 47 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c4f74c4f1f9c2..dc531e3337014 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -54,6 +54,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend} import org.apache.spark.scheduler.local.LocalSchedulerBackend import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.ThreadStackTrace import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index ed9bdc6e1e3c2..7127397f6205c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -157,6 +157,14 @@ private[v1] class NotFoundException(msg: String) extends WebApplicationException .build() ) +private[v1] class ServiceUnavailable(msg: String) extends WebApplicationException( + new ServiceUnavailableException(msg), + Response + .status(Response.Status.SERVICE_UNAVAILABLE) + .entity(ErrorWrapper(msg)) + .build() +) + private[v1] class BadParameterException(msg: String) extends WebApplicationException( new IllegalArgumentException(msg), Response diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index bd4df07e7afc6..974697890dd03 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -19,13 +19,13 @@ package org.apache.spark.status.api.v1 import java.io.OutputStream import java.util.{List => JList} import java.util.zip.ZipOutputStream -import javax.ws.rs.{GET, Path, PathParam, Produces, QueryParam} +import javax.ws.rs._ import javax.ws.rs.core.{MediaType, Response, StreamingOutput} import scala.util.control.NonFatal -import org.apache.spark.JobExecutionStatus -import org.apache.spark.ui.SparkUI +import org.apache.spark.{JobExecutionStatus, SparkContext} +import org.apache.spark.ui.UIUtils @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class AbstractApplicationResource extends BaseAppResource { @@ -51,6 +51,29 @@ private[v1] class AbstractApplicationResource extends BaseAppResource { @Path("executors") def executorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(true)) + @GET + @Path("executors/{executorId}/threads") + def threadDump(@PathParam("executorId") execId: String): Array[ThreadStackTrace] = withUI { ui => + if (execId != SparkContext.DRIVER_IDENTIFIER && !execId.forall(Character.isDigit)) { + throw new BadParameterException( + s"Invalid executorId: neither '${SparkContext.DRIVER_IDENTIFIER}' nor number.") + } + + val safeSparkContext = ui.sc.getOrElse { + throw new ServiceUnavailable("Thread dumps not available through the history server.") + } + + ui.store.asOption(ui.store.executorSummary(execId)) match { + case Some(executorSummary) if executorSummary.isActive => + val safeThreadDump = safeSparkContext.getExecutorThreadDump(execId).getOrElse { + throw new NotFoundException("No thread dump is available.") + } + safeThreadDump + case Some(_) => throw new BadParameterException("Executor is not active.") + case _ => throw new NotFoundException("Executor does not exist.") + } + } + @GET @Path("allexecutors") def allExecutorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(false)) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index a333f1aaf6325..369e98b683b1a 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -316,3 +316,12 @@ class RuntimeInfo private[spark]( val javaVersion: String, val javaHome: String, val scalaVersion: String) + +case class ThreadStackTrace( + val threadId: Long, + val threadName: String, + val threadState: Thread.State, + val stackTrace: String, + val blockedByThreadId: Option[Long], + val blockedByLock: String, + val holdingLocks: Seq[String]) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index f4686ea3cf91f..7a9aaf29a8b05 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -17,7 +17,6 @@ package org.apache.spark.ui.exec -import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Text} @@ -41,17 +40,7 @@ private[ui] class ExecutorThreadDumpPage( val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) val content = maybeThreadDump.map { threadDump => - val dumpRows = threadDump.sortWith { - case (threadTrace1, threadTrace2) => - val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 - val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 - if (v1 == v2) { - threadTrace1.threadName.toLowerCase(Locale.ROOT) < - threadTrace2.threadName.toLowerCase(Locale.ROOT) - } else { - v1 > v2 - } - }.map { thread => + val dumpRows = threadDump.map { thread => val threadId = thread.threadId val blockedBy = thread.blockedByThreadId match { case Some(_) => diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala deleted file mode 100644 index b1217980faf1f..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -/** - * Used for shipping per-thread stacktraces from the executors to driver. - */ -private[spark] case class ThreadStackTrace( - threadId: Long, - threadName: String, - threadState: Thread.State, - stackTrace: String, - blockedByThreadId: Option[Long], - blockedByLock: String, - holdingLocks: Seq[String]) - diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5853302973140..d493663f0b168 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -63,6 +63,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import org.apache.spark.status.api.v1.ThreadStackTrace /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -2168,7 +2169,22 @@ private[spark] object Utils extends Logging { // We need to filter out null values here because dumpAllThreads() may return null array // elements for threads that are dead / don't exist. val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) - threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace) + threadInfos.sortWith { case (threadTrace1, threadTrace2) => + val v1 = if (threadTrace1.getThreadName.contains("Executor task launch")) 1 else 0 + val v2 = if (threadTrace2.getThreadName.contains("Executor task launch")) 1 else 0 + if (v1 == v2) { + val name1 = threadTrace1.getThreadName().toLowerCase(Locale.ROOT) + val name2 = threadTrace2.getThreadName().toLowerCase(Locale.ROOT) + val nameCmpRes = name1.compareTo(name2) + if (nameCmpRes == 0) { + threadTrace1.getThreadId < threadTrace2.getThreadId + } else { + nameCmpRes < 0 + } + } else { + v1 > v2 + } + }.map(threadInfoToThreadStackTrace) } def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = { diff --git a/docs/monitoring.md b/docs/monitoring.md index 6f6cfc1288d73..d5f7ffcc260a1 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -347,6 +347,13 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[app-id]/executors A list of all active executors for the given application. + + /applications/[app-id]/executors/[executor-id]/threads + + Stack traces of all the threads running within the given active executor. + Not available via the history server. + + /applications/[app-id]/allexecutors A list of all(active and dead) executors for the given application. From d6f5e172b480c62165be168deae0deff8062f476 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 13 Feb 2018 16:21:17 -0800 Subject: [PATCH 350/774] Revert "[SPARK-23303][SQL] improve the explain result for data source v2 relations" This reverts commit f17b936f0ddb7d46d1349bd42f9a64c84c06e48d. --- .../kafka010/KafkaContinuousSourceSuite.scala | 18 +++- .../sql/kafka010/KafkaContinuousTest.scala | 3 +- .../spark/sql/kafka010/KafkaSourceSuite.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 8 +- .../v2/DataSourceReaderHolder.scala | 64 +++++++++++++ .../v2/DataSourceV2QueryPlan.scala | 96 ------------------- .../datasources/v2/DataSourceV2Relation.scala | 26 ++--- .../datasources/v2/DataSourceV2ScanExec.scala | 6 +- .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../v2/PushDownOperatorsToDataSource.scala | 4 +- .../streaming/MicroBatchExecution.scala | 22 ++--- .../continuous/ContinuousExecution.scala | 9 +- .../spark/sql/streaming/StreamSuite.scala | 8 +- .../spark/sql/streaming/StreamTest.scala | 2 +- .../continuous/ContinuousSuite.scala | 11 ++- 15 files changed, 127 insertions(+), 157 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index 72ee0c551ec3d..a7083fa4e3417 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,9 +17,20 @@ package org.apache.spark.sql.kafka010 -import org.apache.spark.sql.Dataset +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.time.SpanSugar._ +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} // Run tests in KafkaSourceSuiteBase in continuous execution mode. class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest @@ -60,8 +71,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => - r.reader.asInstanceOf[KafkaContinuousReader] + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index d34458ac81014..5a1a14f7a307a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -47,8 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => - r.reader.asInstanceOf[KafkaContinuousReader] + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index cb09cce75ff6f..02c87643568bd 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -117,8 +117,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => - r.reader.asInstanceOf[KafkaContinuousReader] + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader } }) if (sources.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 984b6510f2dbe..fcaf8d618c168 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -189,9 +189,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance().asInstanceOf[DataSourceV2] + val ds = cls.newInstance() val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs(ds, sparkSession.sessionState.conf)).asJava) + DataSourceV2Utils.extractSessionConfigs( + ds = ds.asInstanceOf[DataSourceV2], + conf = sparkSession.sessionState.conf)).asJava) // Streaming also uses the data source V2 API. So it may be that the data source implements // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading @@ -219,7 +221,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (reader == null) { loadV1Source(paths: _*) } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(ds, reader)) + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) } } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala new file mode 100644 index 0000000000000..81219e9771bd8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Objects + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.sources.v2.reader._ + +/** + * A base class for data source reader holder with customized equals/hashCode methods. + */ +trait DataSourceReaderHolder { + + /** + * The output of the data source reader, w.r.t. column pruning. + */ + def output: Seq[Attribute] + + /** + * The held data source reader. + */ + def reader: DataSourceReader + + /** + * The metadata of this data source reader that can be used for equality test. + */ + private def metadata: Seq[Any] = { + val filters: Any = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Nil + } + Seq(output, reader.getClass, filters) + } + + def canEqual(other: Any): Boolean + + override def equals(other: Any): Boolean = other match { + case other: DataSourceReaderHolder => + canEqual(other) && metadata.length == other.metadata.length && + metadata.zip(other.metadata).forall { case (l, r) => l == r } + case _ => false + } + + override def hashCode(): Int = { + metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala deleted file mode 100644 index 1e0d088f3a57c..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import java.util.Objects - -import org.apache.commons.lang3.StringUtils - -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.util.Utils - -/** - * A base class for data source v2 related query plan(both logical and physical). It defines the - * equals/hashCode methods, and provides a string representation of the query plan, according to - * some common information. - */ -trait DataSourceV2QueryPlan { - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The instance of this data source implementation. Note that we only consider its class in - * equals/hashCode, not the instance itself. - */ - def source: DataSourceV2 - - /** - * The created data source reader. Here we use it to get the filters that has been pushed down - * so far, itself doesn't take part in the equals/hashCode. - */ - def reader: DataSourceReader - - private lazy val filters = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Set.empty - } - - /** - * The metadata of this data source query plan that can be used for equality check. - */ - private def metadata: Seq[Any] = Seq(output, source.getClass, filters) - - def canEqual(other: Any): Boolean - - override def equals(other: Any): Boolean = other match { - case other: DataSourceV2QueryPlan => canEqual(other) && metadata == other.metadata - case _ => false - } - - override def hashCode(): Int = { - metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) - } - - def metadataString: String = { - val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] - if (filters.nonEmpty) entries += "PushedFilter" -> filters.mkString("[", ", ", "]") - - val outputStr = Utils.truncatedString(output, "[", ", ", "]") - - val entriesStr = if (entries.nonEmpty) { - Utils.truncatedString(entries.map { - case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100) - }, " (", ", ", ")") - } else { - "" - } - - s"${source.getClass.getSimpleName}$outputStr$entriesStr" - } - - private def redact(text: String): String = { - Utils.redact(SQLConf.get.stringRedationPattern, text) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index cd97e0cab6b5c..38f6b15224788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -20,23 +20,15 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( output: Seq[AttributeReference], - source: DataSourceV2, - reader: DataSourceReader, - override val isStreaming: Boolean) - extends LeafNode with MultiInstanceRelation with DataSourceV2QueryPlan { + reader: DataSourceReader) + extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] - override def simpleString: String = { - val streamingHeader = if (isStreaming) "Streaming " else "" - s"${streamingHeader}Relation $metadataString" - } - override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) @@ -49,8 +41,18 @@ case class DataSourceV2Relation( } } +/** + * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical + * to the non-streaming relation. + */ +class StreamingDataSourceV2Relation( + output: Seq[AttributeReference], + reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { + override def isStreaming: Boolean = true +} + object DataSourceV2Relation { - def apply(source: DataSourceV2, reader: DataSourceReader): DataSourceV2Relation = { - new DataSourceV2Relation(reader.readSchema().toAttributes, source, reader, isStreaming = false) + def apply(reader: DataSourceReader): DataSourceV2Relation = { + new DataSourceV2Relation(reader.readSchema().toAttributes, reader) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index c99d535efcf81..7d9581be4db89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType @@ -37,14 +36,11 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( output: Seq[AttributeReference], - @transient source: DataSourceV2, @transient reader: DataSourceReader) - extends LeafExecNode with DataSourceV2QueryPlan with ColumnarBatchScan { + extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] - override def simpleString: String = s"Scan $metadataString" - override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => new DataSourcePartitioning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index fb61e6f32b1f4..df5b524485f54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.reader) :: Nil + case DataSourceV2Relation(output, reader) => + DataSourceV2ScanExec(output, reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 4cfdd50e8f46b..1ca6cbf061b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -39,11 +39,11 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: Ideally column pruning should be implemented via a plan property that is propagated // top-down, then we can simplify the logic here and only collect target operators. val filterPushed = plan transformUp { - case FilterAndProject(fields, condition, r: DataSourceV2Relation) => + case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => val (candidates, nonDeterministic) = splitConjunctivePredicates(condition).partition(_.deterministic) - val stayUpFilters: Seq[Expression] = r.reader match { + val stayUpFilters: Seq[Expression] = reader match { case r: SupportsPushDownCatalystFilters => r.pushCatalystFilters(candidates.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 84564b6639ac9..812533313332e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -52,8 +52,6 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty - private val readerToDataSourceMap = MutableMap.empty[MicroBatchReader, DataSourceV2] - private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -92,7 +90,6 @@ class MicroBatchExecution( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 - readerToDataSourceMap(reader) = source StreamingExecutionRelation(reader, output)(sparkSession) }) case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => @@ -408,15 +405,12 @@ class MicroBatchExecution( case v1: SerializedOffset => reader.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - reader.setOffsetRange(toJava(current), Optional.of(availableV2)) + reader.setOffsetRange( + toJava(current), + Optional.of(availableV2)) logDebug(s"Retrieving data from $reader: $current -> $availableV2") - Some(reader -> new DataSourceV2Relation( - reader.readSchema().toAttributes, - // Provide a fake value here just in case something went wrong, e.g. the reader gives - // a wrong `equals` implementation. - readerToDataSourceMap.getOrElse(reader, FakeDataSourceV2), - reader, - isStreaming = true)) + Some(reader -> + new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) case _ => None } } @@ -506,5 +500,3 @@ class MicroBatchExecution( Optional.ofNullable(scalaOption.orNull) } } - -object FakeDataSourceV2 extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index f87d57d0b3209..c3294d64b10cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} @@ -167,7 +167,7 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { - case ContinuousExecutionRelation(ds, _, output) => + case ContinuousExecutionRelation(_, _, output) => val reader = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = reader.readSchema().toAttributes @@ -180,7 +180,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - new DataSourceV2Relation(newOutput, ds, reader, isStreaming = true) + new StreamingDataSourceV2Relation(newOutput, reader) } // Rewire the plan to use the new attributes that were returned by the source. @@ -201,8 +201,7 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case r: DataSourceV2Relation if r.reader.isInstanceOf[ContinuousReader] => - r.reader.asInstanceOf[ContinuousReader] + case DataSourceV2Relation(_, r: ContinuousReader) => r }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 70eb9f0ac66d5..d1a04833390f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,16 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("Streaming Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("Streaming Relation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 254394685857b..37fe595529baf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -605,7 +605,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case StreamingExecutionRelation(s, _) => s - case d: DataSourceV2Relation => d.reader + case DataSourceV2Relation(_, r) => r } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 9ee9aaf87f87c..4b4ed82dc6520 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.streaming.continuous -import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import java.util.UUID + +import org.apache.spark.{SparkContext, SparkEnv, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -40,7 +43,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 357babde5a8eb9710de7016d7ae82dee21fa4ef3 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 14 Feb 2018 10:55:24 +0800 Subject: [PATCH 351/774] [SPARK-23399][SQL] Register a task completion listener first for OrcColumnarBatchReader ## What changes were proposed in this pull request? This PR aims to resolve an open file leakage issue reported at [SPARK-23390](https://issues.apache.org/jira/browse/SPARK-23390) by moving the listener registration position. Currently, the sequence is like the following. 1. Create `batchReader` 2. `batchReader.initialize` opens a ORC file. 3. `batchReader.initBatch` may take a long time to alloc memory in some environment and cause errors. 4. `Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))` This PR moves 4 before 2 and 3. To sum up, the new sequence is 1 -> 4 -> 2 -> 3. ## How was this patch tested? Manual. The following test case makes OOM intentionally to cause leaked filesystem connection in the current code base. With this patch, leakage doesn't occurs. ```scala // This should be tested manually because it raises OOM intentionally // in order to cause `Leaked filesystem connection`. test("SPARK-23399 Register a task completion listener first for OrcColumnarBatchReader") { withSQLConf(SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE.key -> s"${Int.MaxValue}") { withTempDir { dir => val basePath = dir.getCanonicalPath Seq(0).toDF("a").write.format("orc").save(new Path(basePath, "first").toString) Seq(1).toDF("a").write.format("orc").save(new Path(basePath, "second").toString) val df = spark.read.orc( new Path(basePath, "first").toString, new Path(basePath, "second").toString) val e = intercept[SparkException] { df.collect() } assert(e.getCause.isInstanceOf[OutOfMemoryError]) } } } ``` Author: Dongjoon Hyun Closes #20590 from dongjoon-hyun/SPARK-23399. --- .../sql/execution/datasources/orc/OrcFileFormat.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index dbf3bc6f0ee6c..1de2ca2914c44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -188,6 +188,12 @@ class OrcFileFormat if (enableVectorizedReader) { val batchReader = new OrcColumnarBatchReader( enableOffHeapColumnVector && taskContext.isDefined, copyToSpark, capacity) + // SPARK-23399 Register a task completion listener first to call `close()` in all cases. + // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) + // after opening a file. + val iter = new RecordReaderIterator(batchReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( reader.getSchema, @@ -196,8 +202,6 @@ class OrcFileFormat partitionSchema, file.partitionValues) - val iter = new RecordReaderIterator(batchReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) iter.asInstanceOf[Iterator[InternalRow]] } else { val orcRecordReader = new OrcInputFormat[OrcStruct] From 140f87533a468b1046504fc3ff01fbe1637e41cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Wed, 14 Feb 2018 06:45:54 -0800 Subject: [PATCH 352/774] [SPARK-23394][UI] In RDD storage page show the executor addresses instead of the IDs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Extending RDD storage page to show executor addresses in the block table. ## How was this patch tested? Manually: ![screen shot 2018-02-13 at 10 30 59](https://user-images.githubusercontent.com/2017933/36142668-0b3578f8-10a9-11e8-95ea-2f57703ee4af.png) Author: “attilapiros” Closes #20589 from attilapiros/SPARK-23394. --- .../org/apache/spark/ui/storage/RDDPage.scala | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 02cee7f8c5b33..2674b9291203a 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -23,7 +23,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} import org.apache.spark.status.AppStatusStore -import org.apache.spark.status.api.v1.{RDDDataDistribution, RDDPartitionInfo} +import org.apache.spark.status.api.v1.{ExecutorSummary, RDDDataDistribution, RDDPartitionInfo} import org.apache.spark.ui._ import org.apache.spark.util.Utils @@ -76,7 +76,8 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web rddStorageInfo.partitions.get, blockPageSize, blockSortColumn, - blockSortDesc) + blockSortDesc, + store.executorList(true)) _blockTable.table(page) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => @@ -182,7 +183,8 @@ private[ui] class BlockDataSource( rddPartitions: Seq[RDDPartitionInfo], pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedDataSource[BlockTableRowData](pageSize) { + desc: Boolean, + executorIdToAddress: Map[String, String]) extends PagedDataSource[BlockTableRowData](pageSize) { private val data = rddPartitions.map(blockRow).sorted(ordering(sortColumn, desc)) @@ -198,7 +200,10 @@ private[ui] class BlockDataSource( rddPartition.storageLevel, rddPartition.memoryUsed, rddPartition.diskUsed, - rddPartition.executors.mkString(" ")) + rddPartition.executors + .map { id => executorIdToAddress.get(id).getOrElse(id) } + .sorted + .mkString(" ")) } /** @@ -226,7 +231,8 @@ private[ui] class BlockPagedTable( rddPartitions: Seq[RDDPartitionInfo], pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedTable[BlockTableRowData] { + desc: Boolean, + executorSummaries: Seq[ExecutorSummary]) extends PagedTable[BlockTableRowData] { override def tableId: String = "rdd-storage-by-block-table" @@ -243,7 +249,8 @@ private[ui] class BlockPagedTable( rddPartitions, pageSize, sortColumn, - desc) + desc, + executorSummaries.map { ex => (ex.id, ex.hostPort) }.toMap) override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") From 400a1d9e25c1196f0be87323bd89fb3af0660166 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Feb 2018 10:57:12 -0800 Subject: [PATCH 353/774] Revert "[SPARK-23249][SQL] Improved block merging logic for partitions" This reverts commit 8c21170decfb9ca4d3233e1ea13bd1b6e3199ed9. --- .../sql/execution/DataSourceScanExec.scala | 29 +++++-------------- .../datasources/FileSourceStrategySuite.scala | 15 ++++++---- 2 files changed, 17 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index ba1157d5b6a49..08ff33afbba3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -444,29 +444,16 @@ case class FileSourceScanExec( currentSize = 0 } - def addFile(file: PartitionedFile): Unit = { - currentFiles += file - currentSize += file.length + openCostInBytes - } - - var frontIndex = 0 - var backIndex = splitFiles.length - 1 - - while (frontIndex <= backIndex) { - addFile(splitFiles(frontIndex)) - frontIndex += 1 - while (frontIndex <= backIndex && - currentSize + splitFiles(frontIndex).length <= maxSplitBytes) { - addFile(splitFiles(frontIndex)) - frontIndex += 1 - } - while (backIndex > frontIndex && - currentSize + splitFiles(backIndex).length <= maxSplitBytes) { - addFile(splitFiles(backIndex)) - backIndex -= 1 + // Assign files to partitions using "Next Fit Decreasing" + splitFiles.foreach { file => + if (currentSize + file.length > maxSplitBytes) { + closePartition() } - closePartition() + // Add the given file to the current partition. + currentSize += file.length + openCostInBytes + currentFiles += file } + closePartition() new FileScanRDD(fsRelation.sparkSession, readFile, partitions) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index bfccc9335b361..c1d61b843d899 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -141,17 +141,16 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "4", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { checkScan(table.select('c1)) { partitions => - // Files should be laid out [(file1, file6), (file2, file3), (file4, file5)] - assert(partitions.size == 3, "when checking partitions") - assert(partitions(0).files.size == 2, "when checking partition 1") + // Files should be laid out [(file1), (file2, file3), (file4, file5), (file6)] + assert(partitions.size == 4, "when checking partitions") + assert(partitions(0).files.size == 1, "when checking partition 1") assert(partitions(1).files.size == 2, "when checking partition 2") assert(partitions(2).files.size == 2, "when checking partition 3") + assert(partitions(3).files.size == 1, "when checking partition 4") - // First partition reads (file1, file6) + // First partition reads (file1) assert(partitions(0).files(0).start == 0) assert(partitions(0).files(0).length == 2) - assert(partitions(0).files(1).start == 0) - assert(partitions(0).files(1).length == 1) // Second partition reads (file2, file3) assert(partitions(1).files(0).start == 0) @@ -164,6 +163,10 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi assert(partitions(2).files(0).length == 1) assert(partitions(2).files(1).start == 0) assert(partitions(2).files(1).length == 1) + + // Final partition reads (file6) + assert(partitions(3).files(0).start == 0) + assert(partitions(3).files(0).length == 1) } checkPartitionSchema(StructType(Nil)) From 658d9d9d785a30857bf35d164e6cbbd9799d6959 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 14 Feb 2018 14:27:02 -0800 Subject: [PATCH 354/774] [SPARK-23406][SS] Enable stream-stream self-joins ## What changes were proposed in this pull request? Solved two bugs to enable stream-stream self joins. ### Incorrect analysis due to missing MultiInstanceRelation trait Streaming leaf nodes did not extend MultiInstanceRelation, which is necessary for the catalyst analyzer to convert the self-join logical plan DAG into a tree (by creating new instances of the leaf relations). This was causing the error `Failure when resolving conflicting references in Join:` (see JIRA for details). ### Incorrect attribute rewrite when splicing batch plans in MicroBatchExecution When splicing the source's batch plan into the streaming plan (by replacing the StreamingExecutionPlan), we were rewriting the attribute reference in the streaming plan with the new attribute references from the batch plan. This was incorrectly handling the scenario when multiple StreamingExecutionRelation point to the same source, and therefore eventually point to the same batch plan returned by the source. Here is an example query, and its corresponding plan transformations. ``` val df = input.toDF val join = df.select('value % 5 as "key", 'value).join( df.select('value % 5 as "key", 'value), "key") ``` Streaming logical plan before splicing the batch plan ``` Project [key#6, value#1, value#12] +- Join Inner, (key#6 = key#9) :- Project [(value#1 % 5) AS key#6, value#1] : +- StreamingExecutionRelation Memory[#1], value#1 +- Project [(value#12 % 5) AS key#9, value#12] +- StreamingExecutionRelation Memory[#1], value#12 // two different leaves pointing to same source ``` Batch logical plan after splicing the batch plan and before rewriting ``` Project [key#6, value#1, value#12] +- Join Inner, (key#6 = key#9) :- Project [(value#1 % 5) AS key#6, value#1] : +- LocalRelation [value#66] // replaces StreamingExecutionRelation Memory[#1], value#1 +- Project [(value#12 % 5) AS key#9, value#12] +- LocalRelation [value#66] // replaces StreamingExecutionRelation Memory[#1], value#12 ``` Batch logical plan after rewriting the attributes. Specifically, for spliced, the new output attributes (value#66) replace the earlier output attributes (value#12, and value#1, one for each StreamingExecutionRelation). ``` Project [key#6, value#66, value#66] // both value#1 and value#12 replaces by value#66 +- Join Inner, (key#6 = key#9) :- Project [(value#66 % 5) AS key#6, value#66] : +- LocalRelation [value#66] +- Project [(value#66 % 5) AS key#9, value#66] +- LocalRelation [value#66] ``` This causes the optimizer to eliminate value#66 from one side of the join. ``` Project [key#6, value#66, value#66] +- Join Inner, (key#6 = key#9) :- Project [(value#66 % 5) AS key#6, value#66] : +- LocalRelation [value#66] +- Project [(value#66 % 5) AS key#9] // this does not generate value, incorrect join results +- LocalRelation [value#66] ``` **Solution**: Instead of rewriting attributes, use a Project to introduce aliases between the output attribute references and the new reference generated by the spliced plans. The analyzer and optimizer will take care of the rest. ``` Project [key#6, value#1, value#12] +- Join Inner, (key#6 = key#9) :- Project [(value#1 % 5) AS key#6, value#1] : +- Project [value#66 AS value#1] // solution: project with aliases : +- LocalRelation [value#66] +- Project [(value#12 % 5) AS key#9, value#12] +- Project [value#66 AS value#12] // solution: project with aliases +- LocalRelation [value#66] ``` ## How was this patch tested? New unit test Author: Tathagata Das Closes #20598 from tdas/SPARK-23406. --- .../streaming/MicroBatchExecution.scala | 16 ++++++------ .../streaming/StreamingRelation.scala | 20 ++++++++++----- .../sql/streaming/StreamingJoinSuite.scala | 25 ++++++++++++++++++- 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 812533313332e..ac73ba3417904 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -24,8 +24,8 @@ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} @@ -415,8 +415,6 @@ class MicroBatchExecution( } } - // A list of attributes that will need to be updated. - val replacements = new ArrayBuffer[(Attribute, Attribute)] // Replace sources in the logical plan with data that has arrived since the last batch. val newBatchesPlan = logicalPlan transform { case StreamingExecutionRelation(source, output) => @@ -424,18 +422,18 @@ class MicroBatchExecution( assert(output.size == dataPlan.output.size, s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + s"${Utils.truncatedString(dataPlan.output, ",")}") - replacements ++= output.zip(dataPlan.output) - dataPlan + + val aliases = output.zip(dataPlan.output).map { case (to, from) => + Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata)) + } + Project(aliases, dataPlan) }.getOrElse { LocalRelation(output, isStreaming = true) } } // Rewire the plan to use the new attributes that were returned by the source. - val replacementMap = AttributeMap(replacements) val newAttributePlan = newBatchesPlan transformAllExpressions { - case a: Attribute if replacementMap.contains(a) => - replacementMap(a).withMetadata(a.metadata) case ct: CurrentTimestamp => CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs, ct.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 7146190645b37..f02d3a2c3733f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LeafNode -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} @@ -42,7 +42,7 @@ object StreamingRelation { * passing to [[StreamExecution]] to run a query. */ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute]) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = sourceName @@ -53,6 +53,8 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(dataSource.sparkSession.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance())) } /** @@ -62,7 +64,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: case class StreamingExecutionRelation( source: BaseStreamingSource, output: Seq[Attribute])(session: SparkSession) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = source.toString @@ -74,6 +76,8 @@ case class StreamingExecutionRelation( override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } // We have to pack in the V1 data source as a shim, for the case when a source implements @@ -92,13 +96,15 @@ case class StreamingRelationV2( extraOptions: Map[String, String], output: Seq[Attribute], v1Relation: Option[StreamingRelation])(session: SparkSession) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = sourceName override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } /** @@ -108,7 +114,7 @@ case class ContinuousExecutionRelation( source: ContinuousReadSupport, extraOptions: Map[String, String], output: Seq[Attribute])(session: SparkSession) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = source.toString @@ -120,6 +126,8 @@ case class ContinuousExecutionRelation( override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 54eb863dacc83..92087f68ad74a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -28,7 +28,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Literal} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter} -import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.execution.{FileSourceScanExec, LogicalRDD} +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId} import org.apache.spark.sql.functions._ @@ -323,6 +325,27 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) } + test("stream stream self join") { + val input = MemoryStream[Int] + val df = input.toDF + val join = + df.select('value % 5 as "key", 'value).join( + df.select('value % 5 as "key", 'value), "key") + + testStream(join)( + AddData(input, 1, 2), + CheckAnswer((1, 1, 1), (2, 2, 2)), + StopStream, + StartStream(), + AddData(input, 3, 6), + /* + (1, 1) (1, 1) + (2, 2) x (2, 2) = (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6) + (1, 6) (1, 6) + */ + CheckAnswer((3, 3, 3), (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6))) + } + test("locality preferences of StateStoreAwareZippedRDD") { import StreamingSymmetricHashJoinHelper._ From a77ebb0921e390cf4fc6279a8c0a92868ad7e69b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Feb 2018 23:52:59 -0800 Subject: [PATCH 355/774] [SPARK-23421][SPARK-22356][SQL] Document the behavior change in ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/19579 introduces a behavior change. We need to document it in the migration guide. ## How was this patch tested? Also update the HiveExternalCatalogVersionsSuite to verify it. Author: gatorsmile Closes #20606 from gatorsmile/addMigrationGuide. --- docs/sql-programming-guide.md | 2 ++ .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 0f9f01e18682f..cf9529a79f4f9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1963,6 +1963,8 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. + + - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty). ## Upgrading From Spark SQL 2.0 to 2.1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index ae4aeb7b4ce4a..c13a750dbb270 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1") protected var spark: SparkSession = _ @@ -249,7 +249,7 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // SPARK-22356: overlapped columns between data and partition schema in data source tables val tbl_with_col_overlap = s"tbl_with_col_overlap_$index" - // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0. + // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0, 2.2.1, 2.3+ if (testingVersions(index).startsWith("2.1") || testingVersions(index) == "2.2.0") { spark.sql("msck repair table " + tbl_with_col_overlap) assert(spark.table(tbl_with_col_overlap).columns === Array("i", "j", "p")) From 95e4b4916065e66a4f8dba57e98e725796f75e04 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Feb 2018 23:56:02 -0800 Subject: [PATCH 356/774] [SPARK-23094] Revert [] Fix invalid character handling in JsonDataSource ## What changes were proposed in this pull request? This PR is to revert the PR https://github.com/apache/spark/pull/20302, because it causes a regression. ## How was this patch tested? N/A Author: gatorsmile Closes #20614 from gatorsmile/revertJsonFix. --- .../catalyst/json/CreateJacksonParser.scala | 5 ++- .../sources/JsonHadoopFsRelationSuite.scala | 34 ------------------- 2 files changed, 2 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index b1672e7e2fca2..025a388aacaa5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -40,11 +40,10 @@ private[sql] object CreateJacksonParser extends Serializable { } def text(jsonFactory: JsonFactory, record: Text): JsonParser = { - val bain = new ByteArrayInputStream(record.getBytes, 0, record.getLength) - jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) + jsonFactory.createParser(record.getBytes, 0, record.getLength) } def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = { - jsonFactory.createParser(new InputStreamReader(record, "UTF-8")) + jsonFactory.createParser(record) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index 27f398ebf301a..49be30435ad2f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -28,8 +28,6 @@ import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = "json" - private val badJson = "\u0000\u0000\u0000A\u0001AAA" - // JSON does not write data of NullType and does not play well with BinaryType. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { case _: NullType => false @@ -107,36 +105,4 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { ) } } - - test("invalid json with leading nulls - from file (multiLine=true)") { - import testImplicits._ - withTempDir { tempDir => - val path = tempDir.getAbsolutePath - Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) - val expected = s"""$badJson\n{"a":1}\n""" - val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) - val df = - spark.read.format(dataSourceName).option("multiLine", true).schema(schema).load(path) - checkAnswer(df, Row(null, expected)) - } - } - - test("invalid json with leading nulls - from file (multiLine=false)") { - import testImplicits._ - withTempDir { tempDir => - val path = tempDir.getAbsolutePath - Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) - val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) - val df = - spark.read.format(dataSourceName).option("multiLine", false).schema(schema).load(path) - checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) - } - } - - test("invalid json with leading nulls - from dataset") { - import testImplicits._ - checkAnswer( - spark.read.json(Seq(badJson).toDS()), - Row(badJson)) - } } From f38c760638063f1fb45e9ee2c772090fb203a4a0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 15 Feb 2018 16:59:44 +0800 Subject: [PATCH 357/774] [SPARK-23419][SPARK-23416][SS] data source v2 write path should re-throw interruption exceptions directly ## What changes were proposed in this pull request? Streaming execution has a list of exceptions that means interruption, and handle them specially. `WriteToDataSourceV2Exec` should also respect this list and not wrap them with `SparkException`. ## How was this patch tested? existing test. Author: Wenchen Fan Closes #20605 from cloud-fan/write. --- .../datasources/v2/WriteToDataSourceV2.scala | 11 ++++- .../execution/streaming/StreamExecution.scala | 40 ++++++++++--------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 535e7962d7439..41cdfc80d8a19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.util.control.NonFatal + import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging @@ -27,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -107,7 +110,13 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e throw new SparkException("Writing job failed.", cause) } logError(s"Data source writer $writer aborted.") - throw new SparkException("Writing job aborted.", cause) + cause match { + // Do not wrap interruption exceptions that will be handled by streaming specially. + case _ if StreamExecution.isInterruptionException(cause) => throw cause + // Only wrap non fatal exceptions. + case NonFatal(e) => throw new SparkException("Writing job aborted.", e) + case _ => throw cause + } } sparkContext.emptyRDD diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index e7982d7880ceb..3fc8c7887896a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -356,25 +356,7 @@ abstract class StreamExecution( private def isInterruptedByStop(e: Throwable): Boolean = { if (state.get == TERMINATED) { - e match { - // InterruptedIOException - thrown when an I/O operation is interrupted - // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted - case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException => - true - // The cause of the following exceptions may be one of the above exceptions: - // - // UncheckedIOException - thrown by codes that cannot throw a checked IOException, such as - // BiFunction.apply - // ExecutionException - thrown by codes running in a thread pool and these codes throw an - // exception - // UncheckedExecutionException - thrown by codes that cannot throw a checked - // ExecutionException, such as BiFunction.apply - case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException) - if e2.getCause != null => - isInterruptedByStop(e2.getCause) - case _ => - false - } + StreamExecution.isInterruptionException(e) } else { false } @@ -565,6 +547,26 @@ abstract class StreamExecution( object StreamExecution { val QUERY_ID_KEY = "sql.streaming.queryId" + + def isInterruptionException(e: Throwable): Boolean = e match { + // InterruptedIOException - thrown when an I/O operation is interrupted + // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted + case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException => + true + // The cause of the following exceptions may be one of the above exceptions: + // + // UncheckedIOException - thrown by codes that cannot throw a checked IOException, such as + // BiFunction.apply + // ExecutionException - thrown by codes running in a thread pool and these codes throw an + // exception + // UncheckedExecutionException - thrown by codes that cannot throw a checked + // ExecutionException, such as BiFunction.apply + case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException) + if e2.getCause != null => + isInterruptionException(e2.getCause) + case _ => + false + } } /** From 7539ae59d6c354c95c50528abe9ddff6972e960f Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Thu, 15 Feb 2018 17:09:06 +0800 Subject: [PATCH 358/774] [SPARK-23366] Improve hot reading path in ReadAheadInputStream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? `ReadAheadInputStream` was introduced in https://github.com/apache/spark/pull/18317/ to optimize reading spill files from disk. However, from the profiles it seems that the hot path of reading small amounts of data (like readInt) is inefficient - it involves taking locks, and multiple checks. Optimize locking: Lock is not needed when simply accessing the active buffer. Only lock when needing to swap buffers or trigger async reading, or get information about the async state. Optimize short-path single byte reads, that are used e.g. by Java library DataInputStream.readInt. The asyncReader used to call "read" only once on the underlying stream, that never filled the underlying buffer when it was wrapping an LZ4BlockInputStream. If the buffer was returned unfilled, that would trigger the async reader to be triggered to fill the read ahead buffer on each call, because the reader would see that the active buffer is below the refill threshold all the time. However, filling the full buffer all the time could introduce increased latency, so also add an `AtomicBoolean` flag for the async reader to return earlier if there is a reader waiting for data. Remove `readAheadThresholdInBytes` and instead immediately trigger async read when switching the buffers. It allows to simplify code paths, especially the hot one that then only has to check if there is available data in the active buffer, without worrying if it needs to retrigger async read. It seems to have positive effect on perf. ## How was this patch tested? It was noticed as a regression in some workloads after upgrading to Spark 2.3.  It was particularly visible on TPCDS Q95 running on instances with fast disk (i3 AWS instances). Running with profiling: * Spark 2.2 - 5.2-5.3 minutes 9.5% in LZ4BlockInputStream.read * Spark 2.3 - 6.4-6.6 minutes 31.1% in ReadAheadInputStream.read * Spark 2.3 + fix - 5.3-5.4 minutes 13.3% in ReadAheadInputStream.read - very slightly slower, practically within noise. We didn't see other regressions, and many workloads in general seem to be faster with Spark 2.3 (not investigated if thanks to async readed, or unrelated). Author: Juliusz Sompolski Closes #20555 from juliuszsompolski/SPARK-23366. --- .../apache/spark/io/ReadAheadInputStream.java | 119 +++++++++--------- .../unsafe/sort/UnsafeSorterSpillReader.java | 10 +- .../spark/io/GenericFileInputStreamSuite.java | 98 ++++++++------- .../spark/io/NioBufferedInputStreamSuite.java | 6 +- .../spark/io/ReadAheadInputStreamSuite.java | 17 ++- 5 files changed, 133 insertions(+), 117 deletions(-) diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java index 5b45d268ace8d..0cced9e222952 100644 --- a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -27,6 +27,7 @@ import java.nio.ByteBuffer; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock; @@ -78,9 +79,8 @@ public class ReadAheadInputStream extends InputStream { // whether there is a read ahead task running, private boolean isReading; - // If the remaining data size in the current buffer is below this threshold, - // we issue an async read from the underlying input stream. - private final int readAheadThresholdInBytes; + // whether there is a reader waiting for data. + private AtomicBoolean isWaiting = new AtomicBoolean(false); private final InputStream underlyingInputStream; @@ -97,20 +97,13 @@ public class ReadAheadInputStream extends InputStream { * * @param inputStream The underlying input stream. * @param bufferSizeInBytes The buffer size. - * @param readAheadThresholdInBytes If the active buffer has less data than the read-ahead - * threshold, an async read is triggered. */ public ReadAheadInputStream( - InputStream inputStream, int bufferSizeInBytes, int readAheadThresholdInBytes) { + InputStream inputStream, int bufferSizeInBytes) { Preconditions.checkArgument(bufferSizeInBytes > 0, "bufferSizeInBytes should be greater than 0, but the value is " + bufferSizeInBytes); - Preconditions.checkArgument(readAheadThresholdInBytes > 0 && - readAheadThresholdInBytes < bufferSizeInBytes, - "readAheadThresholdInBytes should be greater than 0 and less than bufferSizeInBytes, " + - "but the value is " + readAheadThresholdInBytes); activeBuffer = ByteBuffer.allocate(bufferSizeInBytes); readAheadBuffer = ByteBuffer.allocate(bufferSizeInBytes); - this.readAheadThresholdInBytes = readAheadThresholdInBytes; this.underlyingInputStream = inputStream; activeBuffer.flip(); readAheadBuffer.flip(); @@ -166,12 +159,17 @@ public void run() { // in that case the reader waits for this async read to complete. // So there is no race condition in both the situations. int read = 0; + int off = 0, len = arr.length; Throwable exception = null; try { - while (true) { - read = underlyingInputStream.read(arr); - if (0 != read) break; - } + // try to fill the read ahead buffer. + // if a reader is waiting, possibly return early. + do { + read = underlyingInputStream.read(arr, off, len); + if (read <= 0) break; + off += read; + len -= read; + } while (len > 0 && !isWaiting.get()); } catch (Throwable ex) { exception = ex; if (ex instanceof Error) { @@ -181,13 +179,12 @@ public void run() { } } finally { stateChangeLock.lock(); + readAheadBuffer.limit(off); if (read < 0 || (exception instanceof EOFException)) { endOfStream = true; } else if (exception != null) { readAborted = true; readException = exception; - } else { - readAheadBuffer.limit(read); } readInProgress = false; signalAsyncReadComplete(); @@ -230,7 +227,10 @@ private void signalAsyncReadComplete() { private void waitForAsyncReadComplete() throws IOException { stateChangeLock.lock(); + isWaiting.set(true); try { + // There is only one reader, and one writer, so the writer should signal only once, + // but a while loop checking the wake up condition is still needed to avoid spurious wakeups. while (readInProgress) { asyncReadComplete.await(); } @@ -239,6 +239,7 @@ private void waitForAsyncReadComplete() throws IOException { iio.initCause(e); throw iio; } finally { + isWaiting.set(false); stateChangeLock.unlock(); } checkReadException(); @@ -246,8 +247,13 @@ private void waitForAsyncReadComplete() throws IOException { @Override public int read() throws IOException { - byte[] oneByteArray = oneByte.get(); - return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF; + if (activeBuffer.hasRemaining()) { + // short path - just get one byte. + return activeBuffer.get() & 0xFF; + } else { + byte[] oneByteArray = oneByte.get(); + return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF; + } } @Override @@ -258,54 +264,43 @@ public int read(byte[] b, int offset, int len) throws IOException { if (len == 0) { return 0; } - stateChangeLock.lock(); - try { - return readInternal(b, offset, len); - } finally { - stateChangeLock.unlock(); - } - } - /** - * flip the active and read ahead buffer - */ - private void swapBuffers() { - ByteBuffer temp = activeBuffer; - activeBuffer = readAheadBuffer; - readAheadBuffer = temp; - } - - /** - * Internal read function which should be called only from read() api. The assumption is that - * the stateChangeLock is already acquired in the caller before calling this function. - */ - private int readInternal(byte[] b, int offset, int len) throws IOException { - assert (stateChangeLock.isLocked()); if (!activeBuffer.hasRemaining()) { - waitForAsyncReadComplete(); - if (readAheadBuffer.hasRemaining()) { - swapBuffers(); - } else { - // The first read or activeBuffer is skipped. - readAsync(); + // No remaining in active buffer - lock and switch to write ahead buffer. + stateChangeLock.lock(); + try { waitForAsyncReadComplete(); - if (isEndOfStream()) { - return -1; + if (!readAheadBuffer.hasRemaining()) { + // The first read. + readAsync(); + waitForAsyncReadComplete(); + if (isEndOfStream()) { + return -1; + } } + // Swap the newly read read ahead buffer in place of empty active buffer. swapBuffers(); + // After swapping buffers, trigger another async read for read ahead buffer. + readAsync(); + } finally { + stateChangeLock.unlock(); } - } else { - checkReadException(); } len = Math.min(len, activeBuffer.remaining()); activeBuffer.get(b, offset, len); - if (activeBuffer.remaining() <= readAheadThresholdInBytes && !readAheadBuffer.hasRemaining()) { - readAsync(); - } return len; } + /** + * flip the active and read ahead buffer + */ + private void swapBuffers() { + ByteBuffer temp = activeBuffer; + activeBuffer = readAheadBuffer; + readAheadBuffer = temp; + } + @Override public int available() throws IOException { stateChangeLock.lock(); @@ -323,6 +318,11 @@ public long skip(long n) throws IOException { if (n <= 0L) { return 0L; } + if (n <= activeBuffer.remaining()) { + // Only skipping from active buffer is sufficient + activeBuffer.position((int) n + activeBuffer.position()); + return n; + } stateChangeLock.lock(); long skipped; try { @@ -346,21 +346,14 @@ private long skipInternal(long n) throws IOException { if (available() >= n) { // we can skip from the internal buffers int toSkip = (int) n; - if (toSkip <= activeBuffer.remaining()) { - // Only skipping from active buffer is sufficient - activeBuffer.position(toSkip + activeBuffer.position()); - if (activeBuffer.remaining() <= readAheadThresholdInBytes - && !readAheadBuffer.hasRemaining()) { - readAsync(); - } - return n; - } // We need to skip from both active buffer and read ahead buffer toSkip -= activeBuffer.remaining(); + assert(toSkip > 0); // skipping from activeBuffer already handled. activeBuffer.position(0); activeBuffer.flip(); readAheadBuffer.position(toSkip + readAheadBuffer.position()); swapBuffers(); + // Trigger async read to emptied read ahead buffer. readAsync(); return n; } else { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 2c53c8d809d2e..fb179d07edebc 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -72,21 +72,15 @@ public UnsafeSorterSpillReader( bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES; } - final double readAheadFraction = - SparkEnv.get() == null ? 0.5 : - SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5); - - // SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf - // regression for TPC-DS queries. final boolean readAheadEnabled = SparkEnv.get() != null && - SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", false); + SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true); final InputStream bs = new NioBufferedFileInputStream(file, (int) bufferSizeBytes); try { if (readAheadEnabled) { this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs), - (int) bufferSizeBytes, (int) (bufferSizeBytes * readAheadFraction)); + (int) bufferSizeBytes); } else { this.in = serializerManager.wrapStream(blockId, bs); } diff --git a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java index 3440e1aea2f46..22db3592ecc96 100644 --- a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java @@ -37,7 +37,7 @@ public abstract class GenericFileInputStreamSuite { protected File inputFile; - protected InputStream inputStream; + protected InputStream[] inputStreams; @Before public void setUp() throws IOException { @@ -54,77 +54,91 @@ public void tearDown() { @Test public void testReadOneByte() throws IOException { - for (int i = 0; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testReadMultipleBytes() throws IOException { - byte[] readBytes = new byte[8 * 1024]; - int i = 0; - while (i < randomBytes.length) { - int read = inputStream.read(readBytes, 0, 8 * 1024); - for (int j = 0; j < read; j++) { - assertEquals(randomBytes[i], readBytes[j]); - i++; + for (InputStream inputStream: inputStreams) { + byte[] readBytes = new byte[8 * 1024]; + int i = 0; + while (i < randomBytes.length) { + int read = inputStream.read(readBytes, 0, 8 * 1024); + for (int j = 0; j < read; j++) { + assertEquals(randomBytes[i], readBytes[j]); + i++; + } } } } @Test public void testBytesSkipped() throws IOException { - assertEquals(1024, inputStream.skip(1024)); - for (int i = 1024; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + assertEquals(1024, inputStream.skip(1024)); + for (int i = 1024; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testBytesSkippedAfterRead() throws IOException { - for (int i = 0; i < 1024; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - assertEquals(1024, inputStream.skip(1024)); - for (int i = 2048; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < 1024; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + assertEquals(1024, inputStream.skip(1024)); + for (int i = 2048; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testNegativeBytesSkippedAfterRead() throws IOException { - for (int i = 0; i < 1024; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - // Skipping negative bytes should essential be a no-op - assertEquals(0, inputStream.skip(-1)); - assertEquals(0, inputStream.skip(-1024)); - assertEquals(0, inputStream.skip(Long.MIN_VALUE)); - assertEquals(1024, inputStream.skip(1024)); - for (int i = 2048; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < 1024; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + // Skipping negative bytes should essential be a no-op + assertEquals(0, inputStream.skip(-1)); + assertEquals(0, inputStream.skip(-1024)); + assertEquals(0, inputStream.skip(Long.MIN_VALUE)); + assertEquals(1024, inputStream.skip(1024)); + for (int i = 2048; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testSkipFromFileChannel() throws IOException { - // Since the buffer is smaller than the skipped bytes, this will guarantee - // we skip from underlying file channel. - assertEquals(1024, inputStream.skip(1024)); - for (int i = 1024; i < 2048; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - assertEquals(256, inputStream.skip(256)); - assertEquals(256, inputStream.skip(256)); - assertEquals(512, inputStream.skip(512)); - for (int i = 3072; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + // Since the buffer is smaller than the skipped bytes, this will guarantee + // we skip from underlying file channel. + assertEquals(1024, inputStream.skip(1024)); + for (int i = 1024; i < 2048; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + assertEquals(256, inputStream.skip(256)); + assertEquals(256, inputStream.skip(256)); + assertEquals(512, inputStream.skip(512)); + for (int i = 3072; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testBytesSkippedAfterEOF() throws IOException { - assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1)); - assertEquals(-1, inputStream.read()); + for (InputStream inputStream: inputStreams) { + assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1)); + assertEquals(-1, inputStream.read()); + } } } diff --git a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java index 211b33a1a9fb0..a320f8662f707 100644 --- a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java @@ -18,6 +18,7 @@ import org.junit.Before; +import java.io.InputStream; import java.io.IOException; /** @@ -28,6 +29,9 @@ public class NioBufferedInputStreamSuite extends GenericFileInputStreamSuite { @Before public void setUp() throws IOException { super.setUp(); - inputStream = new NioBufferedFileInputStream(inputFile); + inputStreams = new InputStream[] { + new NioBufferedFileInputStream(inputFile), // default + new NioBufferedFileInputStream(inputFile, 123) // small, unaligned buffer + }; } } diff --git a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java index 918ddc4517ec4..bfa1e0b908824 100644 --- a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java @@ -19,16 +19,27 @@ import org.junit.Before; import java.io.IOException; +import java.io.InputStream; /** - * Tests functionality of {@link NioBufferedFileInputStream} + * Tests functionality of {@link ReadAheadInputStreamSuite} */ public class ReadAheadInputStreamSuite extends GenericFileInputStreamSuite { @Before public void setUp() throws IOException { super.setUp(); - inputStream = new ReadAheadInputStream( - new NioBufferedFileInputStream(inputFile), 8 * 1024, 4 * 1024); + inputStreams = new InputStream[] { + // Tests equal and aligned buffers of wrapped an outer stream. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 8 * 1024), 8 * 1024), + // Tests aligned buffers, wrapped bigger than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 3 * 1024), 2 * 1024), + // Tests aligned buffers, wrapped smaller than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 2 * 1024), 3 * 1024), + // Tests unaligned buffers, wrapped bigger than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 321), 123), + // Tests unaligned buffers, wrapped smaller than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 123), 321) + }; } } From ed8647609883fcef16be5d24c2cb4ebda25bd6f0 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 15 Feb 2018 17:13:05 +0800 Subject: [PATCH 359/774] [SPARK-23359][SQL] Adds an alias 'names' of 'fieldNames' in Scala's StructType ## What changes were proposed in this pull request? This PR proposes to add an alias 'names' of 'fieldNames' in Scala. Please see the discussion in [SPARK-20090](https://issues.apache.org/jira/browse/SPARK-20090). ## How was this patch tested? Unit tests added in `DataTypeSuite.scala`. Author: hyukjinkwon Closes #20545 from HyukjinKwon/SPARK-23359. --- .../scala/org/apache/spark/sql/types/StructType.scala | 7 +++++++ .../scala/org/apache/spark/sql/types/DataTypeSuite.scala | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index e3b0969283a84..d5011c3cb87e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -104,6 +104,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) + /** + * Returns all field names in an array. This is an alias of `fieldNames`. + * + * @since 2.4.0 + */ + def names: Array[String] = fieldNames + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 8e2b32c2b9a08..5a86f4055dce7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -134,6 +134,14 @@ class DataTypeSuite extends SparkFunSuite { assert(mapped === expected) } + test("fieldNames and names returns field names") { + val struct = StructType( + StructField("a", LongType) :: StructField("b", FloatType) :: Nil) + + assert(struct.fieldNames === Seq("a", "b")) + assert(struct.names === Seq("a", "b")) + } + test("merge where right contains type conflict") { val left = StructType( StructField("a", LongType) :: From 44e20c42254bc6591b594f54cd94ced5fcfadae3 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 15 Feb 2018 03:52:40 -0800 Subject: [PATCH 360/774] =?UTF-8?q?[SPARK-23422][CORE]=20YarnShuffleIntegr?= =?UTF-8?q?ationSuite=20fix=20when=20SPARK=5FPREPEN=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …D_CLASSES set to 1 ## What changes were proposed in this pull request? YarnShuffleIntegrationSuite fails when SPARK_PREPEND_CLASSES set to 1. Normally mllib built before yarn module. When SPARK_PREPEND_CLASSES used mllib classes are on yarn test classpath. Before 2.3 that did not cause issues. But 2.3 has SPARK-22450, which registered some mllib classes with the kryo serializer. Now it dies with the following error: ` 18/02/13 07:33:29 INFO SparkContext: Starting job: collect at YarnShuffleIntegrationSuite.scala:143 Exception in thread "dag-scheduler-event-loop" java.lang.NoClassDefFoundError: breeze/linalg/DenseMatrix ` In this PR NoClassDefFoundError caught only in case of testing and then do nothing. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #20608 from gaborgsomogyi/SPARK-23422. --- .../main/scala/org/apache/spark/serializer/KryoSerializer.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 538ae05e4eea1..72427dd6ce4d4 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -206,6 +206,7 @@ class KryoSerializer(conf: SparkConf) kryo.register(clazz) } catch { case NonFatal(_) => // do nothing + case _: NoClassDefFoundError if Utils.isTesting => // See SPARK-23422. } } From f217d7d9b22c4b9c947fc5467379af17f036ee61 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 15 Feb 2018 07:47:40 -0800 Subject: [PATCH 361/774] [INFRA] Close stale PRs. Closes #20587 Closes #20586 From 2f0498d1e85a53b60da6a47d20bbdf56b42b7dcb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 15 Feb 2018 08:55:39 -0800 Subject: [PATCH 362/774] [SPARK-23426][SQL] Use `hive` ORC impl and disable PPD for Spark 2.3.0 ## What changes were proposed in this pull request? To prevent any regressions, this PR changes ORC implementation to `hive` by default like Spark 2.2.X. Users can enable `native` ORC. Also, ORC PPD is also restored to `false` like Spark 2.2.X. ![orc_section](https://user-images.githubusercontent.com/9700541/36221575-57a1d702-1173-11e8-89fe-dca5842f4ca7.png) ## How was this patch tested? Pass all test cases. Author: Dongjoon Hyun Closes #20610 from dongjoon-hyun/SPARK-ORC-DISABLE. --- docs/sql-programming-guide.md | 52 ++++++++----------- .../apache/spark/sql/internal/SQLConf.scala | 6 +-- .../spark/sql/FileBasedDataSourceSuite.scala | 17 +++++- .../sql/streaming/FileStreamSinkSuite.scala | 13 +++++ .../sql/streaming/FileStreamSourceSuite.scala | 13 +++++ 5 files changed, 68 insertions(+), 33 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index cf9529a79f4f9..91e43678481d6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1004,6 +1004,29 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession +## ORC Files + +Since Spark 2.3, Spark supports a vectorized ORC reader with a new ORC file format for ORC files. +To do that, the following configurations are newly added. The vectorized reader is used for the +native ORC tables (e.g., the ones created using the clause `USING ORC`) when `spark.sql.orc.impl` +is set to `native` and `spark.sql.orc.enableVectorizedReader` is set to `true`. For the Hive ORC +serde tables (e.g., the ones created using the clause `USING HIVE OPTIONS (fileFormat 'ORC')`), +the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also set to `true`. + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.sql.orc.implhiveThe name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1.
    spark.sql.orc.enableVectorizedReadertrueEnables vectorized orc decoding in native implementation. If false, a new non-vectorized ORC reader is used in native implementation. For hive implementation, this is ignored.
    + ## JSON Datasets
    @@ -1776,35 +1799,6 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.2 to 2.3 - - Since Spark 2.3, Spark supports a vectorized ORC reader with a new ORC file format for ORC files. To do that, the following configurations are newly added or change their default values. The vectorized reader is used for the native ORC tables (e.g., the ones created using the clause `USING ORC`) when `spark.sql.orc.impl` is set to `native` and `spark.sql.orc.enableVectorizedReader` is set to `true`. For the Hive ORC serde table (e.g., the ones created using the clause `USING HIVE OPTIONS (fileFormat 'ORC')`), the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is set to `true`. - - - New configurations - - - - - - - - - - - - - -
    Property NameDefaultMeaning
    spark.sql.orc.implnativeThe name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1 which is used prior to Spark 2.3.
    spark.sql.orc.enableVectorizedReadertrueEnables vectorized orc decoding in native implementation. If false, a new non-vectorized ORC reader is used in native implementation. For hive implementation, this is ignored.
    - - - Changed configurations - - - - - - - - -
    Property NameDefaultMeaning
    spark.sql.orc.filterPushdowntrueEnables filter pushdown for ORC files. It is false by default prior to Spark 2.3.
    - - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. - Since Spark 2.3, the Join/Filter's deterministic predicates that are after the first non-deterministic predicates are also pushed down/through the child operators, if possible. In prior Spark versions, these filters are not eligible for predicate pushdown. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7835dbaa58439..f24fd7ff74d3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -399,11 +399,11 @@ object SQLConf { val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl") .doc("When native, use the native version of ORC support instead of the ORC library in Hive " + - "1.2.1. It is 'hive' by default prior to Spark 2.3.") + "1.2.1. It is 'hive' by default.") .internal() .stringConf .checkValues(Set("hive", "native")) - .createWithDefault("native") + .createWithDefault("hive") val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader") .doc("Enables vectorized orc decoding.") @@ -426,7 +426,7 @@ object SQLConf { val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 2e332362ea644..b5d4c558f0d3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -20,14 +20,29 @@ package org.apache.spark.sql import java.io.FileNotFoundException import org.apache.hadoop.fs.Path +import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { + +class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { import testImplicits._ + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + private val allFileBasedDataSources = Seq("orc", "parquet", "csv", "json", "text") private val nameWithSpecialChars = "sp&cial%c hars" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 8c4e1fd00b0a2..ba48bc1ce0c4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -33,6 +33,19 @@ import org.apache.spark.util.Utils class FileStreamSinkSuite extends StreamTest { import testImplicits._ + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + test("unpartitioned writing and batch reading") { val inputData = MemoryStream[Int] val df = inputData.toDF() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 5bb0f4d643bbe..d4bd9c7987f2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -207,6 +207,19 @@ class FileStreamSourceSuite extends FileStreamSourceTest { .collect { case s @ StreamingRelation(dataSource, _, _) => s.schema }.head } + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + // ============= Basic parameter exists tests ================ test("FileStreamSource schema: no path") { From 6968c3cfd70961c4e86daffd6a156d0a9c1d7a2a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 15 Feb 2018 09:40:08 -0800 Subject: [PATCH 363/774] [MINOR][SQL] Fix an error message about inserting into bucketed tables ## What changes were proposed in this pull request? This replaces `Sparkcurrently` to `Spark currently` in the following error message. ```scala scala> sql("insert into t2 select * from v1") org.apache.spark.sql.AnalysisException: Output Hive table `default`.`t2` is bucketed but Sparkcurrently does NOT populate bucketed ... ``` ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #20617 from dongjoon-hyun/SPARK-ERROR-MSG. --- .../apache/spark/sql/hive/execution/InsertIntoHiveTable.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3ce5b8469d6fc..02a60f16b3b3a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -172,7 +172,7 @@ case class InsertIntoHiveTable( val enforceBucketingConfig = "hive.enforce.bucketing" val enforceSortingConfig = "hive.enforce.sorting" - val message = s"Output Hive table ${table.identifier} is bucketed but Spark" + + val message = s"Output Hive table ${table.identifier} is bucketed but Spark " + "currently does NOT populate bucketed output which is compatible with Hive." if (hadoopConf.get(enforceBucketingConfig, "true").toBoolean || From db45daab90ede4c03c1abc9096f4eac584e9db17 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 15 Feb 2018 09:54:39 -0800 Subject: [PATCH 364/774] [SPARK-23377][ML] Fixes Bucketizer with multiple columns persistence bug ## What changes were proposed in this pull request? #### Problem: Since 2.3, `Bucketizer` supports multiple input/output columns. We will check if exclusive params are set during transformation. E.g., if `inputCols` and `outputCol` are both set, an error will be thrown. However, when we write `Bucketizer`, looks like the default params and user-supplied params are merged during writing. All saved params are loaded back and set to created model instance. So the default `outputCol` param in `HasOutputCol` trait will be set in `paramMap` and become an user-supplied param. That makes the check of exclusive params failed. #### Fix: This changes the saving logic of Bucketizer to handle this case. This is a quick fix to catch the time of 2.3. We should consider modify the persistence mechanism later. Please see the discussion in the JIRA. Note: The multi-column `QuantileDiscretizer` also has the same issue. ## How was this patch tested? Modified tests. Author: Liang-Chi Hsieh Closes #20594 from viirya/SPARK-23377-2. --- .../apache/spark/ml/feature/Bucketizer.scala | 28 +++++++++++++++++++ .../ml/feature/QuantileDiscretizer.scala | 28 +++++++++++++++++++ .../spark/ml/feature/BucketizerSuite.scala | 12 ++++++-- .../ml/feature/QuantileDiscretizerSuite.scala | 14 ++++++++-- 4 files changed, 78 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index c13bf47eacb94..f49c410cbcfe2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -19,6 +19,10 @@ package org.apache.spark.ml.feature import java.{util => ju} +import org.json4s.JsonDSL._ +import org.json4s.JValue +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.Model @@ -213,6 +217,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } + + override def write: MLWriter = new Bucketizer.BucketizerWriter(this) } @Since("1.6.0") @@ -290,6 +296,28 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { } } + + private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // SPARK-23377: The default params will be saved and loaded as user-supplied params. + // Once `inputCols` is set, the default value of `outputCol` param causes the error + // when checking exclusive params. As a temporary to fix it, we skip the default value + // of `outputCol` if `inputCols` is set when saving the metadata. + // TODO: If we modify the persistence mechanism later to better handle default params, + // we can get rid of this. + var paramWithoutOutputCol: Option[JValue] = None + if (instance.isSet(instance.inputCols)) { + val params = instance.extractParamMap().toSeq + val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList + paramWithoutOutputCol = Some(render(jsonParams)) + } + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) + } + } + @Since("1.6.0") override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 1ec5f8cb6139b..3b4c25478fb1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -17,6 +17,10 @@ package org.apache.spark.ml.feature +import org.json4s.JsonDSL._ +import org.json4s.JValue +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml._ @@ -249,11 +253,35 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) + + override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this) } @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { + private[QuantileDiscretizer] + class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // SPARK-23377: The default params will be saved and loaded as user-supplied params. + // Once `inputCols` is set, the default value of `outputCol` param causes the error + // when checking exclusive params. As a temporary to fix it, we skip the default value + // of `outputCol` if `inputCols` is set when saving the metadata. + // TODO: If we modify the persistence mechanism later to better handle default params, + // we can get rid of this. + var paramWithoutOutputCol: Option[JValue] = None + if (instance.isSet(instance.inputCols)) { + val params = instance.extractParamMap().toSeq + val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList + paramWithoutOutputCol = Some(render(jsonParams)) + } + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) + } + } + @Since("1.6.0") override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 7403680ae3fdc..41cf72fe3470a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -172,7 +172,10 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCol("myInputCol") .setOutputCol("myOutputCol") .setSplits(Array(0.1, 0.8, 0.9)) - testDefaultReadWrite(t) + + val bucketizer = testDefaultReadWrite(t) + val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2") + bucketizer.transform(data) } test("Bucket numeric features") { @@ -327,7 +330,12 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCols(Array("myInputCol")) .setOutputCols(Array("myOutputCol")) .setSplitsArray(Array(Array(0.1, 0.8, 0.9))) - testDefaultReadWrite(t) + + val bucketizer = testDefaultReadWrite(t) + val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2") + bucketizer.transform(data) + assert(t.hasDefault(t.outputCol)) + assert(bucketizer.hasDefault(bucketizer.outputCol)) } test("Bucketizer in a pipeline") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index e9a75e931e6a8..6c363799dd300 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.functions.udf class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("Test observed number of buckets and their sizes match expected values") { val spark = this.spark import spark.implicits._ @@ -132,7 +134,10 @@ class QuantileDiscretizerSuite .setInputCol("myInputCol") .setOutputCol("myOutputCol") .setNumBuckets(6) - testDefaultReadWrite(t) + + val readDiscretizer = testDefaultReadWrite(t) + val data = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("myInputCol") + readDiscretizer.fit(data) } test("Verify resulting model has parent") { @@ -379,7 +384,12 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBucketsArray(Array(5, 10)) - testDefaultReadWrite(discretizer) + + val readDiscretizer = testDefaultReadWrite(discretizer) + val data = Seq((1.0, 2.0), (2.0, 3.0), (3.0, 4.0)).toDF("input1", "input2") + readDiscretizer.fit(data) + assert(discretizer.hasDefault(discretizer.outputCol)) + assert(readDiscretizer.hasDefault(readDiscretizer.outputCol)) } test("Multiple Columns: Both inputCol and inputCols are set") { From 1dc2c1d5e85c5f404f470aeb44c1f3c22786bdea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Thu, 15 Feb 2018 13:51:24 -0600 Subject: [PATCH 365/774] [SPARK-23413][UI] Fix sorting tasks by Host / Executor ID at the Stage page MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Fixing exception got at sorting tasks by Host / Executor ID: ``` java.lang.IllegalArgumentException: Invalid sort column: Host at org.apache.spark.ui.jobs.ApiHelper$.indexName(StagePage.scala:1017) at org.apache.spark.ui.jobs.TaskDataSource.sliceData(StagePage.scala:694) at org.apache.spark.ui.PagedDataSource.pageData(PagedTable.scala:61) at org.apache.spark.ui.PagedTable$class.table(PagedTable.scala:96) at org.apache.spark.ui.jobs.TaskPagedTable.table(StagePage.scala:708) at org.apache.spark.ui.jobs.StagePage.liftedTree1$1(StagePage.scala:293) at org.apache.spark.ui.jobs.StagePage.render(StagePage.scala:282) at org.apache.spark.ui.WebUI$$anonfun$2.apply(WebUI.scala:82) at org.apache.spark.ui.WebUI$$anonfun$2.apply(WebUI.scala:82) at org.apache.spark.ui.JettyUtils$$anon$3.doGet(JettyUtils.scala:90) at javax.servlet.http.HttpServlet.service(HttpServlet.java:687) at javax.servlet.http.HttpServlet.service(HttpServlet.java:790) at org.spark_project.jetty.servlet.ServletHolder.handle(ServletHolder.java:848) at org.spark_project.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:584) ``` Moreover some refactoring to avoid similar problems by introducing constants for each header name and reusing them at the identification of the corresponding sorting index. ## How was this patch tested? Manually: ![screen shot 2018-02-13 at 18 57 10](https://user-images.githubusercontent.com/2017933/36166532-1cfdf3b8-10f3-11e8-8d32-5fcaad2af214.png) Author: “attilapiros” Closes #20601 from attilapiros/SPARK-23413. --- .../org/apache/spark/status/storeTypes.scala | 2 + .../org/apache/spark/ui/jobs/StagePage.scala | 121 +++++++++++------- .../org/apache/spark/ui/StagePageSuite.scala | 63 ++++++++- 3 files changed, 139 insertions(+), 47 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 412644d3657b5..646cf25880e37 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -109,6 +109,7 @@ private[spark] object TaskIndexNames { final val DURATION = "dur" final val ERROR = "err" final val EXECUTOR = "exe" + final val HOST = "hst" final val EXEC_CPU_TIME = "ect" final val EXEC_RUN_TIME = "ert" final val GC_TIME = "gc" @@ -165,6 +166,7 @@ private[spark] class TaskDataWrapper( val duration: Long, @KVIndexParam(value = TaskIndexNames.EXECUTOR, parent = TaskIndexNames.STAGE) val executorId: String, + @KVIndexParam(value = TaskIndexNames.HOST, parent = TaskIndexNames.STAGE) val host: String, @KVIndexParam(value = TaskIndexNames.STATUS, parent = TaskIndexNames.STAGE) val status: String, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 5c2b0c3a19996..a9265d4dbcdfb 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -750,37 +750,39 @@ private[ui] class TaskPagedTable( } def headers: Seq[Node] = { + import ApiHelper._ + val taskHeadersAndCssClasses: Seq[(String, String)] = Seq( - ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), - ("Executor ID", ""), ("Host", ""), ("Launch Time", ""), ("Duration", ""), - ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), - ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), - ("GC Time", ""), - ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), - ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME), - ("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ - {if (hasAccumulators(stage)) Seq(("Accumulators", "")) else Nil} ++ - {if (hasInput(stage)) Seq(("Input Size / Records", "")) else Nil} ++ - {if (hasOutput(stage)) Seq(("Output Size / Records", "")) else Nil} ++ + (HEADER_TASK_INDEX, ""), (HEADER_ID, ""), (HEADER_ATTEMPT, ""), (HEADER_STATUS, ""), + (HEADER_LOCALITY, ""), (HEADER_EXECUTOR, ""), (HEADER_HOST, ""), (HEADER_LAUNCH_TIME, ""), + (HEADER_DURATION, ""), (HEADER_SCHEDULER_DELAY, TaskDetailsClassNames.SCHEDULER_DELAY), + (HEADER_DESER_TIME, TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), + (HEADER_GC_TIME, ""), + (HEADER_SER_TIME, TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), + (HEADER_GETTING_RESULT_TIME, TaskDetailsClassNames.GETTING_RESULT_TIME), + (HEADER_PEAK_MEM, TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ + {if (hasAccumulators(stage)) Seq((HEADER_ACCUMULATORS, "")) else Nil} ++ + {if (hasInput(stage)) Seq((HEADER_INPUT_SIZE, "")) else Nil} ++ + {if (hasOutput(stage)) Seq((HEADER_OUTPUT_SIZE, "")) else Nil} ++ {if (hasShuffleRead(stage)) { - Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), - ("Shuffle Read Size / Records", ""), - ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) + Seq((HEADER_SHUFFLE_READ_TIME, TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + (HEADER_SHUFFLE_TOTAL_READS, ""), + (HEADER_SHUFFLE_REMOTE_READS, TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) } else { Nil }} ++ {if (hasShuffleWrite(stage)) { - Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) + Seq((HEADER_SHUFFLE_WRITE_TIME, ""), (HEADER_SHUFFLE_WRITE_SIZE, "")) } else { Nil }} ++ {if (hasBytesSpilled(stage)) { - Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) + Seq((HEADER_MEM_SPILL, ""), (HEADER_DISK_SPILL, "")) } else { Nil }} ++ - Seq(("Errors", "")) + Seq((HEADER_ERROR, "")) if (!taskHeadersAndCssClasses.map(_._1).contains(sortColumn)) { throw new IllegalArgumentException(s"Unknown column: $sortColumn") @@ -961,35 +963,62 @@ private[ui] class TaskPagedTable( } } -private object ApiHelper { - - - private val COLUMN_TO_INDEX = Map( - "ID" -> null.asInstanceOf[String], - "Index" -> TaskIndexNames.TASK_INDEX, - "Attempt" -> TaskIndexNames.ATTEMPT, - "Status" -> TaskIndexNames.STATUS, - "Locality Level" -> TaskIndexNames.LOCALITY, - "Executor ID / Host" -> TaskIndexNames.EXECUTOR, - "Launch Time" -> TaskIndexNames.LAUNCH_TIME, - "Duration" -> TaskIndexNames.DURATION, - "Scheduler Delay" -> TaskIndexNames.SCHEDULER_DELAY, - "Task Deserialization Time" -> TaskIndexNames.DESER_TIME, - "GC Time" -> TaskIndexNames.GC_TIME, - "Result Serialization Time" -> TaskIndexNames.SER_TIME, - "Getting Result Time" -> TaskIndexNames.GETTING_RESULT_TIME, - "Peak Execution Memory" -> TaskIndexNames.PEAK_MEM, - "Accumulators" -> TaskIndexNames.ACCUMULATORS, - "Input Size / Records" -> TaskIndexNames.INPUT_SIZE, - "Output Size / Records" -> TaskIndexNames.OUTPUT_SIZE, - "Shuffle Read Blocked Time" -> TaskIndexNames.SHUFFLE_READ_TIME, - "Shuffle Read Size / Records" -> TaskIndexNames.SHUFFLE_TOTAL_READS, - "Shuffle Remote Reads" -> TaskIndexNames.SHUFFLE_REMOTE_READS, - "Write Time" -> TaskIndexNames.SHUFFLE_WRITE_TIME, - "Shuffle Write Size / Records" -> TaskIndexNames.SHUFFLE_WRITE_SIZE, - "Shuffle Spill (Memory)" -> TaskIndexNames.MEM_SPILL, - "Shuffle Spill (Disk)" -> TaskIndexNames.DISK_SPILL, - "Errors" -> TaskIndexNames.ERROR) +private[ui] object ApiHelper { + + val HEADER_ID = "ID" + val HEADER_TASK_INDEX = "Index" + val HEADER_ATTEMPT = "Attempt" + val HEADER_STATUS = "Status" + val HEADER_LOCALITY = "Locality Level" + val HEADER_EXECUTOR = "Executor ID" + val HEADER_HOST = "Host" + val HEADER_LAUNCH_TIME = "Launch Time" + val HEADER_DURATION = "Duration" + val HEADER_SCHEDULER_DELAY = "Scheduler Delay" + val HEADER_DESER_TIME = "Task Deserialization Time" + val HEADER_GC_TIME = "GC Time" + val HEADER_SER_TIME = "Result Serialization Time" + val HEADER_GETTING_RESULT_TIME = "Getting Result Time" + val HEADER_PEAK_MEM = "Peak Execution Memory" + val HEADER_ACCUMULATORS = "Accumulators" + val HEADER_INPUT_SIZE = "Input Size / Records" + val HEADER_OUTPUT_SIZE = "Output Size / Records" + val HEADER_SHUFFLE_READ_TIME = "Shuffle Read Blocked Time" + val HEADER_SHUFFLE_TOTAL_READS = "Shuffle Read Size / Records" + val HEADER_SHUFFLE_REMOTE_READS = "Shuffle Remote Reads" + val HEADER_SHUFFLE_WRITE_TIME = "Write Time" + val HEADER_SHUFFLE_WRITE_SIZE = "Shuffle Write Size / Records" + val HEADER_MEM_SPILL = "Shuffle Spill (Memory)" + val HEADER_DISK_SPILL = "Shuffle Spill (Disk)" + val HEADER_ERROR = "Errors" + + private[ui] val COLUMN_TO_INDEX = Map( + HEADER_ID -> null.asInstanceOf[String], + HEADER_TASK_INDEX -> TaskIndexNames.TASK_INDEX, + HEADER_ATTEMPT -> TaskIndexNames.ATTEMPT, + HEADER_STATUS -> TaskIndexNames.STATUS, + HEADER_LOCALITY -> TaskIndexNames.LOCALITY, + HEADER_EXECUTOR -> TaskIndexNames.EXECUTOR, + HEADER_HOST -> TaskIndexNames.HOST, + HEADER_LAUNCH_TIME -> TaskIndexNames.LAUNCH_TIME, + HEADER_DURATION -> TaskIndexNames.DURATION, + HEADER_SCHEDULER_DELAY -> TaskIndexNames.SCHEDULER_DELAY, + HEADER_DESER_TIME -> TaskIndexNames.DESER_TIME, + HEADER_GC_TIME -> TaskIndexNames.GC_TIME, + HEADER_SER_TIME -> TaskIndexNames.SER_TIME, + HEADER_GETTING_RESULT_TIME -> TaskIndexNames.GETTING_RESULT_TIME, + HEADER_PEAK_MEM -> TaskIndexNames.PEAK_MEM, + HEADER_ACCUMULATORS -> TaskIndexNames.ACCUMULATORS, + HEADER_INPUT_SIZE -> TaskIndexNames.INPUT_SIZE, + HEADER_OUTPUT_SIZE -> TaskIndexNames.OUTPUT_SIZE, + HEADER_SHUFFLE_READ_TIME -> TaskIndexNames.SHUFFLE_READ_TIME, + HEADER_SHUFFLE_TOTAL_READS -> TaskIndexNames.SHUFFLE_TOTAL_READS, + HEADER_SHUFFLE_REMOTE_READS -> TaskIndexNames.SHUFFLE_REMOTE_READS, + HEADER_SHUFFLE_WRITE_TIME -> TaskIndexNames.SHUFFLE_WRITE_TIME, + HEADER_SHUFFLE_WRITE_SIZE -> TaskIndexNames.SHUFFLE_WRITE_SIZE, + HEADER_MEM_SPILL -> TaskIndexNames.MEM_SPILL, + HEADER_DISK_SPILL -> TaskIndexNames.DISK_SPILL, + HEADER_ERROR -> TaskIndexNames.ERROR) def hasAccumulators(stageData: StageData): Boolean = { stageData.accumulatorUpdates.exists { acc => acc.name != null && acc.value != null } diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 0aeddf730cd35..6044563f7dde7 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -28,13 +28,74 @@ import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.{AccumulableInfo => UIAccumulableInfo, StageData, StageStatus} import org.apache.spark.status.config._ -import org.apache.spark.ui.jobs.{StagePage, StagesTab} +import org.apache.spark.ui.jobs.{ApiHelper, StagePage, StagesTab, TaskPagedTable} class StagePageSuite extends SparkFunSuite with LocalSparkContext { private val peakExecutionMemory = 10 + test("ApiHelper.COLUMN_TO_INDEX should match headers of the task table") { + val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + val statusStore = AppStatusStore.createLiveStore(conf) + try { + val stageData = new StageData( + status = StageStatus.ACTIVE, + stageId = 1, + attemptId = 1, + numTasks = 1, + numActiveTasks = 1, + numCompleteTasks = 1, + numFailedTasks = 1, + numKilledTasks = 1, + numCompletedIndices = 1, + + executorRunTime = 1L, + executorCpuTime = 1L, + submissionTime = None, + firstTaskLaunchedTime = None, + completionTime = None, + failureReason = None, + + inputBytes = 1L, + inputRecords = 1L, + outputBytes = 1L, + outputRecords = 1L, + shuffleReadBytes = 1L, + shuffleReadRecords = 1L, + shuffleWriteBytes = 1L, + shuffleWriteRecords = 1L, + memoryBytesSpilled = 1L, + diskBytesSpilled = 1L, + + name = "stage1", + description = Some("description"), + details = "detail", + schedulingPool = "pool1", + + rddIds = Seq(1), + accumulatorUpdates = Seq(new UIAccumulableInfo(0L, "acc", None, "value")), + tasks = None, + executorSummary = None, + killedTasksSummary = Map.empty + ) + val taskTable = new TaskPagedTable( + stageData, + basePath = "/a/b/c", + currentTime = 0, + pageSize = 10, + sortColumn = "Index", + desc = false, + store = statusStore + ) + val columnNames = (taskTable.headers \ "th" \ "a").map(_.child(1).text).toSet + assert(columnNames === ApiHelper.COLUMN_TO_INDEX.keySet) + } finally { + statusStore.close() + } + } + test("peak execution memory should displayed") { val html = renderStagePage().toString().toLowerCase(Locale.ROOT) val targetString = "peak execution memory" From c5857e496ff0d170ed0339f14afc7d36b192da6d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 16 Feb 2018 09:41:17 -0800 Subject: [PATCH 366/774] [SPARK-23446][PYTHON] Explicitly check supported types in toPandas ## What changes were proposed in this pull request? This PR explicitly specifies and checks the types we supported in `toPandas`. This was a hole. For example, we haven't finished the binary type support in Python side yet but now it allows as below: ```python spark.conf.set("spark.sql.execution.arrow.enabled", "false") df = spark.createDataFrame([[bytearray("a")]]) df.toPandas() spark.conf.set("spark.sql.execution.arrow.enabled", "true") df.toPandas() ``` ``` _1 0 [97] _1 0 a ``` This should be disallowed. I think the same things also apply to nested timestamps too. I also added some nicer message about `spark.sql.execution.arrow.enabled` in the error message. ## How was this patch tested? Manually tested and tests added in `python/pyspark/sql/tests.py`. Author: hyukjinkwon Closes #20625 from HyukjinKwon/pandas_convertion_supported_type. --- python/pyspark/sql/dataframe.py | 15 +++++++++------ python/pyspark/sql/tests.py | 9 ++++++++- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5cc8b63cdfadf..f37777e13ee12 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1988,10 +1988,11 @@ def toPandas(self): if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: from pyspark.sql.types import _check_dataframe_convert_date, \ - _check_dataframe_localize_timestamps + _check_dataframe_localize_timestamps, to_arrow_schema from pyspark.sql.utils import require_minimum_pyarrow_version - import pyarrow require_minimum_pyarrow_version() + import pyarrow + to_arrow_schema(self.schema) tables = self._collectAsArrow() if tables: table = pyarrow.concat_tables(tables) @@ -2000,10 +2001,12 @@ def toPandas(self): return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) - except ImportError as e: - msg = "note: pyarrow must be installed and available on calling Python process " \ - "if using spark.sql.execution.arrow.enabled=true" - raise ImportError("%s\n%s" % (_exception_message(e), msg)) + except Exception as e: + msg = ( + "Note: toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true. Please set it to false " + "to disable this.") + raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) else: pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2af218a691026..19653072ea316 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3497,7 +3497,14 @@ def test_unsupported_datatype(self): schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + with self.assertRaisesRegexp(Exception, 'Unsupported type'): + df.toPandas() + + df = self.spark.createDataFrame([(None,)], schema="a binary") + with QuietTest(self.sc): + with self.assertRaisesRegexp( + Exception, + 'Unsupported type.*\nNote: toPandas attempted Arrow optimization because'): df.toPandas() def test_null_conversion(self): From 0a73aa31f41c83503d5d99eff3c9d7b406014ab3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 16 Feb 2018 14:30:19 -0800 Subject: [PATCH 367/774] [SPARK-23362][SS] Migrate Kafka Microbatch source to v2 ## What changes were proposed in this pull request? Migrating KafkaSource (with data source v1) to KafkaMicroBatchReader (with data source v2). Performance comparison: In a unit test with in-process Kafka broker, I tested the read throughput of V1 and V2 using 20M records in a single partition. They were comparable. ## How was this patch tested? Existing tests, few modified to be better tests than the existing ones. Author: Tathagata Das Closes #20554 from tdas/SPARK-23362. --- dev/.rat-excludes | 1 + .../sql/kafka010/CachedKafkaConsumer.scala | 2 +- .../sql/kafka010/KafkaContinuousReader.scala | 29 +- .../sql/kafka010/KafkaMicroBatchReader.scala | 403 ++++++++++++++++++ .../KafkaRecordToUnsafeRowConverter.scala | 52 +++ .../spark/sql/kafka010/KafkaSource.scala | 19 +- .../sql/kafka010/KafkaSourceProvider.scala | 70 ++- ...a-source-initial-offset-future-version.bin | 2 + ...ka-source-initial-offset-version-2.1.0.bin | 2 +- ...scala => KafkaMicroBatchSourceSuite.scala} | 254 +++++++---- .../apache/spark/sql/internal/SQLConf.scala | 15 +- .../streaming/MicroBatchExecution.scala | 20 +- 12 files changed, 741 insertions(+), 128 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala create mode 100644 external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin rename external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/{KafkaSourceSuite.scala => KafkaMicroBatchSourceSuite.scala} (85%) diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 243fbe3e1bc24..9552d001a079c 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -105,3 +105,4 @@ META-INF/* spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin +kafka-source-initial-offset-future-version.bin diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala index 90ed7b1fba2f8..e97881cb0a163 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -27,7 +27,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.util.UninterruptibleThread diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index b049a054cb40e..97a0f66e1880d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType @@ -187,13 +187,9 @@ class KafkaContinuousDataReader( kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { - private val topic = topicPartition.topic - private val kafkaPartition = topicPartition.partition - private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) - - private val sharedRow = new UnsafeRow(7) - private val bufferHolder = new BufferHolder(sharedRow) - private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + private val consumer = + CachedKafkaConsumer.createUncached(topicPartition.topic, topicPartition.partition, kafkaParams) + private val converter = new KafkaRecordToUnsafeRowConverter private var nextKafkaOffset = startOffset private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ @@ -232,22 +228,7 @@ class KafkaContinuousDataReader( } override def get(): UnsafeRow = { - bufferHolder.reset() - - if (currentRecord.key == null) { - rowWriter.setNullAt(0) - } else { - rowWriter.write(0, currentRecord.key) - } - rowWriter.write(1, currentRecord.value) - rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) - rowWriter.write(3, currentRecord.partition) - rowWriter.write(4, currentRecord.offset) - rowWriter.write(5, - DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) - rowWriter.write(6, currentRecord.timestampType.id) - sharedRow.setTotalSize(bufferHolder.totalSize) - sharedRow + converter.toUnsafeRow(currentRecord) } override def getOffset(): KafkaSourcePartitionOffset = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala new file mode 100644 index 0000000000000..fb647ca7e70dd --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -0,0 +1,403 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.io._ +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.UninterruptibleThread + +/** + * A [[MicroBatchReader]] that reads data from Kafka. + * + * The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains + * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For + * example if the last record in a Kafka topic "t", partition 2 is offset 5, then + * KafkaSourceOffset will contain TopicPartition("t", 2) -> 6. This is done keep it consistent + * with the semantics of `KafkaConsumer.position()`. + * + * Zero data lost is not guaranteed when topics are deleted. If zero data lost is critical, the user + * must make sure all messages in a topic have been processed when deleting a topic. + * + * There is a known issue caused by KAFKA-1894: the query using Kafka maybe cannot be stopped. + * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers + * and not use wrong broker addresses. + */ +private[kafka010] class KafkaMicroBatchReader( + kafkaOffsetReader: KafkaOffsetReader, + executorKafkaParams: ju.Map[String, Object], + options: DataSourceOptions, + metadataPath: String, + startingOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) + extends MicroBatchReader with SupportsScanUnsafeRow with Logging { + + type PartitionOffsetMap = Map[TopicPartition, Long] + + private var startPartitionOffsets: PartitionOffsetMap = _ + private var endPartitionOffsets: PartitionOffsetMap = _ + + private val pollTimeoutMs = options.getLong( + "kafkaConsumer.pollTimeoutMs", + SparkEnv.get.conf.getTimeAsMs("spark.network.timeout", "120s")) + + private val maxOffsetsPerTrigger = + Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) + + /** + * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only + * called in StreamExecutionThread. Otherwise, interrupting a thread while running + * `KafkaConsumer.poll` may hang forever (KAFKA-1894). + */ + private lazy val initialPartitionOffsets = getOrCreateInitialPartitionOffsets() + + override def setOffsetRange(start: ju.Optional[Offset], end: ju.Optional[Offset]): Unit = { + // Make sure initialPartitionOffsets is initialized + initialPartitionOffsets + + startPartitionOffsets = Option(start.orElse(null)) + .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) + .getOrElse(initialPartitionOffsets) + + endPartitionOffsets = Option(end.orElse(null)) + .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) + .getOrElse { + val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() + maxOffsetsPerTrigger.map { maxOffsets => + rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) + }.getOrElse { + latestPartitionOffsets + } + } + } + + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + // Find the new partitions, and get their earliest offsets + val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) + val newPartitionOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) + if (newPartitionOffsets.keySet != newPartitions) { + // We cannot get from offsets for some partitions. It means they got deleted. + val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet) + reportDataLoss( + s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed") + } + logInfo(s"Partitions added: $newPartitionOffsets") + newPartitionOffsets.filter(_._2 != 0).foreach { case (p, o) => + reportDataLoss( + s"Added partition $p starts from $o instead of 0. Some data may have been missed") + } + + // Find deleted partitions, and report data loss if required + val deletedPartitions = startPartitionOffsets.keySet.diff(endPartitionOffsets.keySet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed") + } + + // Use the until partitions to calculate offset ranges to ignore partitions that have + // been deleted + val topicPartitions = endPartitionOffsets.keySet.filter { tp => + // Ignore partitions that we don't know the from offsets. + newPartitionOffsets.contains(tp) || startPartitionOffsets.contains(tp) + }.toSeq + logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) + + val sortedExecutors = getSortedExecutorList() + val numExecutors = sortedExecutors.length + logDebug("Sorted executors: " + sortedExecutors.mkString(", ")) + + // Calculate offset ranges + val factories = topicPartitions.flatMap { tp => + val fromOffset = startPartitionOffsets.get(tp).getOrElse { + newPartitionOffsets.getOrElse( + tp, { + // This should not happen since newPartitionOffsets contains all partitions not in + // fromPartitionOffsets + throw new IllegalStateException(s"$tp doesn't have a from offset") + }) + } + val untilOffset = endPartitionOffsets(tp) + + if (untilOffset >= fromOffset) { + // This allows cached KafkaConsumers in the executors to be re-used to read the same + // partition in every batch. + val preferredLoc = if (numExecutors > 0) { + Some(sortedExecutors(Math.floorMod(tp.hashCode, numExecutors))) + } else None + val range = KafkaOffsetRange(tp, fromOffset, untilOffset) + Some( + new KafkaMicroBatchDataReaderFactory( + range, preferredLoc, executorKafkaParams, pollTimeoutMs, failOnDataLoss)) + } else { + reportDataLoss( + s"Partition $tp's offset was changed from " + + s"$fromOffset to $untilOffset, some data may have been missed") + None + } + } + factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava + } + + override def getStartOffset: Offset = { + KafkaSourceOffset(startPartitionOffsets) + } + + override def getEndOffset: Offset = { + KafkaSourceOffset(endPartitionOffsets) + } + + override def deserializeOffset(json: String): Offset = { + KafkaSourceOffset(JsonUtils.partitionOffsets(json)) + } + + override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = { + kafkaOffsetReader.close() + } + + override def toString(): String = s"Kafka[$kafkaOffsetReader]" + + /** + * Read initial partition offsets from the checkpoint, or decide the offsets and write them to + * the checkpoint. + */ + private def getOrCreateInitialPartitionOffsets(): PartitionOffsetMap = { + // Make sure that `KafkaConsumer.poll` is only called in StreamExecutionThread. + // Otherwise, interrupting a thread while running `KafkaConsumer.poll` may hang forever + // (KAFKA-1894). + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + + // SparkSession is required for getting Hadoop configuration for writing to checkpoints + assert(SparkSession.getActiveSession.nonEmpty) + + val metadataLog = + new KafkaSourceInitialOffsetWriter(SparkSession.getActiveSession.get, metadataPath) + metadataLog.get(0).getOrElse { + val offsets = startingOffsets match { + case EarliestOffsetRangeLimit => + KafkaSourceOffset(kafkaOffsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => + KafkaSourceOffset(kafkaOffsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => + kafkaOffsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + metadataLog.add(0, offsets) + logInfo(s"Initial offsets: $offsets") + offsets + }.partitionToOffsets + } + + /** Proportionally distribute limit number of offsets among topicpartitions */ + private def rateLimit( + limit: Long, + from: PartitionOffsetMap, + until: PartitionOffsetMap): PartitionOffsetMap = { + val fromNew = kafkaOffsetReader.fetchEarliestOffsets(until.keySet.diff(from.keySet).toSeq) + val sizes = until.flatMap { + case (tp, end) => + // If begin isn't defined, something's wrong, but let alert logic in getBatch handle it + from.get(tp).orElse(fromNew.get(tp)).flatMap { begin => + val size = end - begin + logDebug(s"rateLimit $tp size is $size") + if (size > 0) Some(tp -> size) else None + } + } + val total = sizes.values.sum.toDouble + if (total < 1) { + until + } else { + until.map { + case (tp, end) => + tp -> sizes.get(tp).map { size => + val begin = from.get(tp).getOrElse(fromNew(tp)) + val prorate = limit * (size / total) + // Don't completely starve small topicpartitions + val off = begin + (if (prorate < 1) Math.ceil(prorate) else Math.floor(prorate)).toLong + // Paranoia, make sure not to return an offset that's past end + Math.min(end, off) + }.getOrElse(end) + } + } + } + + private def getSortedExecutorList(): Array[String] = { + + def compare(a: ExecutorCacheTaskLocation, b: ExecutorCacheTaskLocation): Boolean = { + if (a.host == b.host) { + a.executorId > b.executorId + } else { + a.host > b.host + } + } + + val bm = SparkEnv.get.blockManager + bm.master.getPeers(bm.blockManagerId).toArray + .map(x => ExecutorCacheTaskLocation(x.host, x.executorId)) + .sortWith(compare) + .map(_.toString) + } + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") + } else { + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") + } + } + + /** A version of [[HDFSMetadataLog]] specialized for saving the initial offsets. */ + class KafkaSourceInitialOffsetWriter(sparkSession: SparkSession, metadataPath: String) + extends HDFSMetadataLog[KafkaSourceOffset](sparkSession, metadataPath) { + + val VERSION = 1 + + override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { + out.write(0) // A zero byte is written to support Spark 2.1.0 (SPARK-19517) + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): KafkaSourceOffset = { + in.read() // A zero byte is read to support Spark 2.1.0 (SPARK-19517) + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + KafkaSourceOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + // The log was generated by Spark 2.1.0 + KafkaSourceOffset(SerializedOffset(content)) + } + } + } +} + +/** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] class KafkaMicroBatchDataReaderFactory( + range: KafkaOffsetRange, + preferredLoc: Option[String], + executorKafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { + + override def preferredLocations(): Array[String] = preferredLoc.toArray + + override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss) +} + +/** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] class KafkaMicroBatchDataReader( + offsetRange: KafkaOffsetRange, + executorKafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends DataReader[UnsafeRow] with Logging { + + private val consumer = CachedKafkaConsumer.getOrCreate( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + private val rangeToRead = resolveRange(offsetRange) + private val converter = new KafkaRecordToUnsafeRowConverter + + private var nextOffset = rangeToRead.fromOffset + private var nextRow: UnsafeRow = _ + + override def next(): Boolean = { + if (nextOffset < rangeToRead.untilOffset) { + val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss) + if (record != null) { + nextRow = converter.toUnsafeRow(record) + true + } else { + false + } + } else { + false + } + } + + override def get(): UnsafeRow = { + assert(nextRow != null) + nextOffset += 1 + nextRow + } + + override def close(): Unit = { + // Indicate that we're no longer using this consumer + CachedKafkaConsumer.releaseKafkaConsumer( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } + + private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = { + if (range.fromOffset < 0 || range.untilOffset < 0) { + // Late bind the offset range + val availableOffsetRange = consumer.getAvailableOffsetRange() + val fromOffset = if (range.fromOffset < 0) { + assert(range.fromOffset == KafkaOffsetRangeLimit.EARLIEST, + s"earliest offset ${range.fromOffset} does not equal ${KafkaOffsetRangeLimit.EARLIEST}") + availableOffsetRange.earliest + } else { + range.fromOffset + } + val untilOffset = if (range.untilOffset < 0) { + assert(range.untilOffset == KafkaOffsetRangeLimit.LATEST, + s"latest offset ${range.untilOffset} does not equal ${KafkaOffsetRangeLimit.LATEST}") + availableOffsetRange.latest + } else { + range.untilOffset + } + KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset) + } else { + range + } + } +} + +private[kafka010] case class KafkaOffsetRange( + topicPartition: TopicPartition, fromOffset: Long, untilOffset: Long) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala new file mode 100644 index 0000000000000..1acdd56125741 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.clients.consumer.ConsumerRecord + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.types.UTF8String + +/** A simple class for converting Kafka ConsumerRecord to UnsafeRow */ +private[kafka010] class KafkaRecordToUnsafeRowConverter { + private val sharedRow = new UnsafeRow(7) + private val bufferHolder = new BufferHolder(sharedRow) + private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + + def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = { + bufferHolder.reset() + + if (record.key == null) { + rowWriter.setNullAt(0) + } else { + rowWriter.write(0, record.key) + } + rowWriter.write(1, record.value) + rowWriter.write(2, UTF8String.fromString(record.topic)) + rowWriter.write(3, record.partition) + rowWriter.write(4, record.offset) + rowWriter.write( + 5, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) + rowWriter.write(6, record.timestampType.id) + sharedRow.setTotalSize(bufferHolder.totalSize) + sharedRow + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 169a5d006fb04..1c7b3a29a861f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -306,7 +307,7 @@ private[kafka010] class KafkaSource( kafkaReader.close() } - override def toString(): String = s"KafkaSource[$kafkaReader]" + override def toString(): String = s"KafkaSourceV1[$kafkaReader]" /** * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. @@ -323,22 +324,6 @@ private[kafka010] class KafkaSource( /** Companion object for the [[KafkaSource]]. */ private[kafka010] object KafkaSource { - val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE = - """ - |Some data may have been lost because they are not available in Kafka any more; either the - | data was aged out by Kafka or the topic may have been deleted before all the data in the - | topic was processed. If you want your streaming query to fail on such cases, set the source - | option "failOnDataLoss" to "true". - """.stripMargin - - val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE = - """ - |Some data may have been lost because they are not available in Kafka any more; either the - | data was aged out by Kafka or the topic may have been deleted before all the data in the - | topic was processed. If you don't want your streaming query to fail on such cases, set the - | source option "failOnDataLoss" to "false". - """.stripMargin - private[kafka010] val VERSION = 1 def getSortedExecutorList(sc: SparkContext): Array[String] = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index d4fa0359c12d6..0aa64a6a9cf90 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,13 +30,13 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** - * The provider class for the [[KafkaSource]]. This provider is designed such that it throws + * The provider class for all Kafka readers and writers. It is designed such that it throws * IllegalArgumentException when the Kafka Dataset is created, so that it can catch * missing options even before the query is started. */ @@ -47,6 +47,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with CreatableRelationProvider with StreamWriteSupport with ContinuousReadSupport + with MicroBatchReadSupport with Logging { import KafkaSourceProvider._ @@ -105,6 +106,52 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } + /** + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader]] to read batches + * of Kafka data in a micro-batch streaming query. + */ + override def createMicroBatchReader( + schema: Optional[StructType], + metadataPath: String, + options: DataSourceOptions): KafkaMicroBatchReader = { + + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaMicroBatchReader( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + options, + metadataPath, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + + /** + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader]] to read + * Kafka data in a continuous streaming query. + */ override def createContinuousReader( schema: Optional[StructType], metadataPath: String, @@ -408,8 +455,27 @@ private[kafka010] object KafkaSourceProvider extends Logging { private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" + val TOPIC_OPTION_KEY = "topic" + val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE = + """ + |Some data may have been lost because they are not available in Kafka any more; either the + | data was aged out by Kafka or the topic may have been deleted before all the data in the + | topic was processed. If you want your streaming query to fail on such cases, set the source + | option "failOnDataLoss" to "true". + """.stripMargin + + val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE = + """ + |Some data may have been lost because they are not available in Kafka any more; either the + | data was aged out by Kafka or the topic may have been deleted before all the data in the + | topic was processed. If you don't want your streaming query to fail on such cases, set the + | source option "failOnDataLoss" to "false". + """.stripMargin + + + private val deserClassName = classOf[ByteArrayDeserializer].getName def getKafkaOffsetRangeLimit( diff --git a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin new file mode 100644 index 0000000000000..d530773f57327 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin @@ -0,0 +1,2 @@ +0v99999 +{"kafka-initial-offset-future-version":{"2":2,"1":1,"0":0}} \ No newline at end of file diff --git a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin index ae928e724967d..8c78d9e390a0e 100644 --- a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin +++ b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin @@ -1 +1 @@ -2{"kafka-initial-offset-2-1-0":{"2":0,"1":0,"0":0}} \ No newline at end of file +2{"kafka-initial-offset-2-1-0":{"2":2,"1":1,"0":0}} \ No newline at end of file diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala similarity index 85% rename from external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala rename to external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 02c87643568bd..ed4ecfeafa972 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -25,6 +25,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable +import scala.io.Source import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata @@ -42,7 +43,6 @@ import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} -import org.apache.spark.util.Utils abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { @@ -112,14 +112,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") - val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: KafkaSource, _) => source - } ++ (query.get.lastExecution match { - case null => Seq() - case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader - } - }) + val sources = { + query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: KafkaSource, _) => source + case StreamingExecutionRelation(source: KafkaMicroBatchReader, _) => source + } ++ (query.get.lastExecution match { + case null => Seq() + case e => e.logical.collect { + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + } + }) + }.distinct + if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -155,7 +159,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" } -class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { +abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { import testImplicits._ @@ -303,94 +307,105 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { ) } - testWithUninterruptibleThread( - "deserialization of initial offset with Spark 2.1.0") { + test("ensure that initial offset are written with an extra byte in the beginning (SPARK-19517)") { withTempDir { metadataPath => - val topic = newTopic - testUtils.createTopic(topic, partitions = 3) + val topic = "kafka-initial-offset-current" + testUtils.createTopic(topic, partitions = 1) - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None, - "", parameters) - source.getOffset.get // Write initial offset - - // Make sure Spark 2.1.0 will throw an exception when reading the new log - intercept[java.lang.IllegalArgumentException] { - // Simulate how Spark 2.1.0 reads the log - Utils.tryWithResource(new FileInputStream(metadataPath.getAbsolutePath + "/0")) { in => - val length = in.read() - val bytes = new Array[Byte](length) - in.read(bytes) - KafkaSourceOffset(SerializedOffset(new String(bytes, UTF_8))) - } + val initialOffsetFile = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0").toFile + + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", s"earliest") + .load() + + // Test the written initial offset file has 0 byte in the beginning, so that + // Spark 2.1.0 can read the offsets (see SPARK-19517) + testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + makeSureGetOffsetCalled) + + val binarySource = Source.fromFile(initialOffsetFile) + try { + assert(binarySource.next().toInt == 0) // first byte is binary 0 + } finally { + binarySource.close() } } } - testWithUninterruptibleThread("deserialization of initial offset written by Spark 2.1.0") { + test("deserialization of initial offset written by Spark 2.1.0 (SPARK-19517)") { withTempDir { metadataPath => val topic = "kafka-initial-offset-2-1-0" testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, Array("0", "1", "2"), Some(0)) + testUtils.sendMessages(topic, Array("0", "10", "20"), Some(1)) + testUtils.sendMessages(topic, Array("0", "100", "200"), Some(2)) - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - + // Copy the initial offset file into the right location inside the checkpoint root directory + // such that the Kafka source can read it for initial offsets. val from = new File( getClass.getResource("/kafka-source-initial-offset-version-2.1.0.bin").toURI).toPath - val to = Paths.get(s"${metadataPath.getAbsolutePath}/0") + val to = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0") + Files.createDirectories(to.getParent) Files.copy(from, to) - val source = provider.createSource( - spark.sqlContext, metadataPath.toURI.toString, None, "", parameters) - val deserializedOffset = source.getOffset.get - val referenceOffset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L)) - assert(referenceOffset == deserializedOffset) + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", s"earliest") + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + + // Test that the query starts from the expected initial offset (i.e. read older offsets, + // even though startingOffsets is latest). + testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + AddKafkaData(Set(topic), 1000), + CheckAnswer(0, 1, 2, 10, 20, 200, 1000)) } } - testWithUninterruptibleThread("deserialization of initial offset written by future version") { + test("deserialization of initial offset written by future version") { withTempDir { metadataPath => - val futureMetadataLog = - new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, - metadataPath.getAbsolutePath) { - override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { - out.write(0) - val writer = new BufferedWriter(new OutputStreamWriter(out, UTF_8)) - writer.write(s"v99999\n${metadata.json}") - writer.flush - } - } - - val topic = newTopic + val topic = "kafka-initial-offset-future-version" testUtils.createTopic(topic, partitions = 3) - val offset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L)) - futureMetadataLog.add(0, offset) - - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None, - "", parameters) - val e = intercept[java.lang.IllegalStateException] { - source.getOffset.get // Read initial offset - } + // Copy the initial offset file into the right location inside the checkpoint root directory + // such that the Kafka source can read it for initial offsets. + val from = new File( + getClass.getResource("/kafka-source-initial-offset-future-version.bin").toURI).toPath + val to = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0") + Files.createDirectories(to.getParent) + Files.copy(from, to) - Seq( - s"maximum supported log version is v${KafkaSource.VERSION}, but encountered v99999", - "produced by a newer version of Spark and cannot be read by this version" - ).foreach { message => - assert(e.getMessage.contains(message)) - } + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + + testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + ExpectFailure[IllegalStateException](e => { + Seq( + s"maximum supported log version is v1, but encountered v99999", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.toString.contains(message)) + } + })) } } @@ -542,6 +557,91 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { CheckLastBatch(120 to 124: _*) ) } + + test("ensure stream-stream self-join generates only one offset in offset log") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + require(testUtils.getLatestOffsets(Set(topic)).size === 2) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .load() + + val values = kafka + .selectExpr("CAST(CAST(value AS STRING) AS INT) AS value", + "CAST(CAST(value AS STRING) AS INT) % 5 AS key") + + val join = values.join(values, "key") + + testStream(join)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2), + CheckAnswer((1, 1, 1), (2, 2, 2)), + AddKafkaData(Set(topic), 6, 3), + CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)) + ) + } +} + + +class KafkaMicroBatchV1SourceSuite extends KafkaMicroBatchSourceSuiteBase { + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set( + "spark.sql.streaming.disabledV2MicroBatchReaders", + classOf[KafkaSourceProvider].getCanonicalName) + } + + test("V1 Source is used when disabled through SQLConf") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topic.*") + .load() + + testStream(kafka)( + makeSureGetOffsetCalled, + AssertOnQuery { query => + query.logicalPlan.collect { + case StreamingExecutionRelation(_: KafkaSource, _) => true + }.nonEmpty + } + ) + } +} + +class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { + + test("V2 Source is used by default") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topic.*") + .load() + + testStream(kafka)( + makeSureGetOffsetCalled, + AssertOnQuery { query => + query.logicalPlan.collect { + case StreamingExecutionRelation(_: KafkaMicroBatchReader, _) => true + }.nonEmpty + } + ) + } } abstract class KafkaSourceSuiteBase extends KafkaSourceTest { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f24fd7ff74d3f..e75e1d66ebcf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1146,10 +1146,20 @@ object SQLConf { val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") .internal() .doc("A comma-separated list of fully qualified data source register class names for which" + - " StreamWriteSupport is disabled. Writes to these sources will fail back to the V1 Sink.") + " StreamWriteSupport is disabled. Writes to these sources will fall back to the V1 Sinks.") .stringConf .createWithDefault("") + val DISABLED_V2_STREAMING_MICROBATCH_READERS = + buildConf("spark.sql.streaming.disabledV2MicroBatchReaders") + .internal() + .doc( + "A comma-separated list of fully qualified data source register class names for which " + + "MicroBatchReadSupport is disabled. Reads from these sources will fall back to the " + + "V1 Sources.") + .stringConf + .createWithDefault("") + object PartitionOverwriteMode extends Enumeration { val STATIC, DYNAMIC = Value } @@ -1525,6 +1535,9 @@ class SQLConf extends Serializable with Logging { def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS) + def disabledV2StreamingMicroBatchReaders: String = + getConf(DISABLED_V2_STREAMING_MICROBATCH_READERS) + def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index ac73ba3417904..84655013ba957 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -72,27 +72,36 @@ class MicroBatchExecution( // Note that we have to use the previous `output` as attributes in StreamingExecutionRelation, // since the existing logical plan has already used those attributes. The per-microbatch // transformation is responsible for replacing attributes with their final values. + + val disabledSources = + sparkSession.sqlContext.conf.disabledV2StreamingMicroBatchReaders.split(",") + val _logicalPlan = analyzedPlan.transform { - case streamingRelation@StreamingRelation(dataSource, _, output) => + case streamingRelation@StreamingRelation(dataSourceV1, sourceName, output) => toExecutionRelationMap.getOrElseUpdate(streamingRelation, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val source = dataSource.createSource(metadataPath) + val source = dataSourceV1.createSource(metadataPath) nextSourceId += 1 + logInfo(s"Using Source [$source] from DataSourceV1 named '$sourceName' [$dataSourceV1]") StreamingExecutionRelation(source, output)(sparkSession) }) - case s @ StreamingRelationV2(source: MicroBatchReadSupport, _, options, output, _) => + case s @ StreamingRelationV2( + dataSourceV2: MicroBatchReadSupport, sourceName, options, output, _) if + !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val reader = source.createMicroBatchReader( + val reader = dataSourceV2.createMicroBatchReader( Optional.empty(), // user specified schema metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 + logInfo(s"Using MicroBatchReader [$reader] from " + + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") StreamingExecutionRelation(reader, output)(sparkSession) }) - case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => + case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" @@ -102,6 +111,7 @@ class MicroBatchExecution( } val source = v1Relation.get.dataSource.createSource(metadataPath) nextSourceId += 1 + logInfo(s"Using Source [$source] from DataSourceV2 named '$sourceName' [$dataSourceV2]") StreamingExecutionRelation(source, output)(sparkSession) }) } From d5ed2108d32e1d95b26ee7fed39e8a733e935e2c Mon Sep 17 00:00:00 2001 From: Shintaro Murakami Date: Fri, 16 Feb 2018 17:17:55 -0800 Subject: [PATCH 368/774] [SPARK-23381][CORE] Murmur3 hash generates a different value from other implementations ## What changes were proposed in this pull request? Murmur3 hash generates a different value from the original and other implementations (like Scala standard library and Guava or so) when the length of a bytes array is not multiple of 4. ## How was this patch tested? Added a unit test. **Note: When we merge this PR, please give all the credits to Shintaro Murakami.** Author: Shintaro Murakami Author: gatorsmile Author: Shintaro Murakami Closes #20630 from gatorsmile/pr-20568. --- .../spark/util/sketch/Murmur3_x86_32.java | 16 +++++++++ .../spark/unsafe/hash/Murmur3_x86_32.java | 16 +++++++++ .../unsafe/hash/Murmur3_x86_32Suite.java | 19 +++++++++++ .../spark/ml/feature/FeatureHasher.scala | 33 ++++++++++++++++++- .../spark/mllib/feature/HashingTF.scala | 2 +- .../spark/ml/feature/FeatureHasherSuite.scala | 11 ++++++- python/pyspark/ml/feature.py | 4 +-- 7 files changed, 96 insertions(+), 5 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java index a61ce4fb7241d..e83b331391e39 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java @@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i } public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + // This is not compatible with original and another implementations. + // But remain it for backward compatibility for the components existing before 2.3. assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; int h1 = hashBytesByInt(base, offset, lengthAligned, seed); @@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, i return fmix(h1, lengthInBytes); } + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + // This is compatible with original and another implementations. + // Use this method for new components after Spark 2.3. + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int k1 = 0; + for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { + k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; + } + h1 ^= mixK1(k1); + return fmix(h1, lengthInBytes); + } + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { assert (lengthInBytes % 4 == 0); int h1 = seed; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 5e7ee480cafd1..d239de6083ad0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i } public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + // This is not compatible with original and another implementations. + // But remain it for backward compatibility for the components existing before 2.3. assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; int h1 = hashBytesByInt(base, offset, lengthAligned, seed); @@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, i return fmix(h1, lengthInBytes); } + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + // This is compatible with original and another implementations. + // Use this method for new components after Spark 2.3. + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int k1 = 0; + for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { + k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; + } + h1 ^= mixK1(k1); + return fmix(h1, lengthInBytes); + } + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { assert (lengthInBytes % 4 == 0); int h1 = seed; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index e759cb33b3e6a..6348a73bf3895 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -22,6 +22,8 @@ import java.util.Random; import java.util.Set; +import scala.util.hashing.MurmurHash3$; + import org.apache.spark.unsafe.Platform; import org.junit.Assert; import org.junit.Test; @@ -51,6 +53,23 @@ public void testKnownLongInputs() { Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE)); } + // SPARK-23381 Check whether the hash of the byte array is the same as another implementations + @Test + public void testKnownBytesInputs() { + byte[] test = "test".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test, 0), + Murmur3_x86_32.hashUnsafeBytes2(test, Platform.BYTE_ARRAY_OFFSET, test.length, 0)); + byte[] test1 = "test1".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test1, 0), + Murmur3_x86_32.hashUnsafeBytes2(test1, Platform.BYTE_ARRAY_OFFSET, test1.length, 0)); + byte[] te = "te".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(te, 0), + Murmur3_x86_32.hashUnsafeBytes2(te, Platform.BYTE_ARRAY_OFFSET, te.length, 0)); + byte[] tes = "tes".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(tes, 0), + Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0)); + } + @Test public void randomizedStressTest() { int size = 65536; diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index a918dd4c075da..c78f61ac3ef71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup @@ -28,6 +29,8 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashMap @@ -138,7 +141,7 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme @Since("2.3.0") override def transform(dataset: Dataset[_]): DataFrame = { - val hashFunc: Any => Int = OldHashingTF.murmur3Hash + val hashFunc: Any => Int = FeatureHasher.murmur3Hash val n = $(numFeatures) val localInputCols = $(inputCols) val catCols = if (isSet(categoricalCols)) { @@ -218,4 +221,32 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] { @Since("2.3.0") override def load(path: String): FeatureHasher = super.load(path) + + private val seed = OldHashingTF.seed + + /** + * Calculate a hash code value for the term object using + * Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32). + * This is the default hash algorithm used from Spark 2.0 onwards. + * Use hashUnsafeBytes2 to match the original algorithm with the value. + * See SPARK-23381. + */ + @Since("2.3.0") + private[feature] def murmur3Hash(term: Any): Int = { + term match { + case null => seed + case b: Boolean => hashInt(if (b) 1 else 0, seed) + case b: Byte => hashInt(b, seed) + case s: Short => hashInt(s, seed) + case i: Int => hashInt(i, seed) + case l: Long => hashLong(l, seed) + case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) + case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) + case s: String => + val utf8 = UTF8String.fromString(s) + hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) + case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " + + s"support type ${term.getClass.getCanonicalName} of input data.") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 9abdd44a635d1..8935c8496cdbb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -135,7 +135,7 @@ object HashingTF { private[HashingTF] val Murmur3: String = "murmur3" - private val seed = 42 + private[spark] val seed = 42 /** * Calculate a hash code value for the term object using the native Scala implementation. diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala index 3fc3cbb62d5b5..7bc1825b69c43 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class FeatureHasherSuite extends SparkFunSuite with MLlibTestSparkContext @@ -34,7 +35,7 @@ class FeatureHasherSuite extends SparkFunSuite import testImplicits._ - import HashingTFSuite.murmur3FeatureIdx + import FeatureHasherSuite.murmur3FeatureIdx implicit private val vectorEncoder = ExpressionEncoder[Vector]() @@ -216,3 +217,11 @@ class FeatureHasherSuite extends SparkFunSuite testDefaultReadWrite(t) } } + +object FeatureHasherSuite { + + private[feature] def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = { + Utils.nonNegativeMod(FeatureHasher.murmur3Hash(term), numFeatures) + } + +} diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index da85ba761a145..04b07e6a05481 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -741,9 +741,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, >>> df = spark.createDataFrame(data, cols) >>> hasher = FeatureHasher(inputCols=cols, outputCol="features") >>> hasher.transform(df).head().features - SparseVector(262144, {51871: 1.0, 63643: 1.0, 174475: 2.0, 253195: 1.0}) + SparseVector(262144, {174475: 2.0, 247670: 1.0, 257907: 1.0, 262126: 1.0}) >>> hasher.setCategoricalCols(["real"]).transform(df).head().features - SparseVector(262144, {51871: 1.0, 63643: 1.0, 171257: 1.0, 253195: 1.0}) + SparseVector(262144, {171257: 1.0, 247670: 1.0, 257907: 1.0, 262126: 1.0}) >>> hasherPath = temp_path + "/hasher" >>> hasher.save(hasherPath) >>> loadedHasher = FeatureHasher.load(hasherPath) From 15ad4a7f1000c83cefbecd41e315c964caa3c39f Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Sat, 17 Feb 2018 10:54:14 +0800 Subject: [PATCH 369/774] [SPARK-23447][SQL] Cleanup codegen template for Literal ## What changes were proposed in this pull request? Cleaned up the codegen templates for `Literal`s, to make sure that the `ExprCode` returned from `Literal.doGenCode()` has: 1. an empty `code` field; 2. an `isNull` field of either literal `true` or `false`; 3. a `value` field that is just a simple literal/constant. Before this PR, there are a couple of paths that would return a non-trivial `code` and all of them are actually unnecessary. The `NaN` and `Infinity` constants for `double` and `float` can be accessed through constants directly available so there's no need to add a reference for them. Also took the opportunity to add a new util method for ease of creating `ExprCode` for inline-able non-null values. ## How was this patch tested? Existing tests. Author: Kris Mok Closes #20626 from rednaxelafx/codegen-literal. --- .../expressions/codegen/CodeGenerator.scala | 6 +++ .../sql/catalyst/expressions/literals.scala | 51 ++++++++++--------- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4dcbb702893da..31ba29ae8d8ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -58,6 +58,12 @@ import org.apache.spark.util.{ParentClassLoader, Utils} */ case class ExprCode(var code: String, var isNull: String, var value: String) +object ExprCode { + def forNonNullValue(value: String): ExprCode = { + ExprCode(code = "", isNull = "false", value = value) + } +} + /** * State used for subexpression elimination. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index cd176d941819f..c1e65e34c2ea6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -278,40 +278,45 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - // change the isNull and primitive to consts, to inline them if (value == null) { - ev.isNull = "true" - ev.copy(s"final $javaType ${ev.value} = ${ctx.defaultValue(dataType)};") + val defaultValueLiteral = ctx.defaultValue(javaType) match { + case "null" => s"(($javaType)null)" + case lit => lit + } + ExprCode(code = "", isNull = "true", value = defaultValueLiteral) } else { - ev.isNull = "false" dataType match { case BooleanType | IntegerType | DateType => - ev.copy(code = "", value = value.toString) + ExprCode.forNonNullValue(value.toString) case FloatType => - val v = value.asInstanceOf[Float] - if (v.isNaN || v.isInfinite) { - val boxedValue = ctx.addReferenceObj("boxedValue", v) - val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" - ev.copy(code = code) - } else { - ev.copy(code = "", value = s"${value}f") + value.asInstanceOf[Float] match { + case v if v.isNaN => + ExprCode.forNonNullValue("Float.NaN") + case Float.PositiveInfinity => + ExprCode.forNonNullValue("Float.POSITIVE_INFINITY") + case Float.NegativeInfinity => + ExprCode.forNonNullValue("Float.NEGATIVE_INFINITY") + case _ => + ExprCode.forNonNullValue(s"${value}F") } case DoubleType => - val v = value.asInstanceOf[Double] - if (v.isNaN || v.isInfinite) { - val boxedValue = ctx.addReferenceObj("boxedValue", v) - val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" - ev.copy(code = code) - } else { - ev.copy(code = "", value = s"${value}D") + value.asInstanceOf[Double] match { + case v if v.isNaN => + ExprCode.forNonNullValue("Double.NaN") + case Double.PositiveInfinity => + ExprCode.forNonNullValue("Double.POSITIVE_INFINITY") + case Double.NegativeInfinity => + ExprCode.forNonNullValue("Double.NEGATIVE_INFINITY") + case _ => + ExprCode.forNonNullValue(s"${value}D") } case ByteType | ShortType => - ev.copy(code = "", value = s"($javaType)$value") + ExprCode.forNonNullValue(s"($javaType)$value") case TimestampType | LongType => - ev.copy(code = "", value = s"${value}L") + ExprCode.forNonNullValue(s"${value}L") case _ => - ev.copy(code = "", value = ctx.addReferenceObj("literal", value, - ctx.javaType(dataType))) + val constRef = ctx.addReferenceObj("literal", value, javaType) + ExprCode.forNonNullValue(constRef) } } } From 3ee3b2ae1ff8fbeb43a08becef43a9bd763b06bb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 17 Feb 2018 00:25:36 -0800 Subject: [PATCH 370/774] [SPARK-23340][SQL] Upgrade Apache ORC to 1.4.3 ## What changes were proposed in this pull request? This PR updates Apache ORC dependencies to 1.4.3 released on February 9th. Apache ORC 1.4.2 release removes unnecessary dependencies and 1.4.3 has 5 more patches (https://s.apache.org/Fll8). Especially, the following ORC-285 is fixed at 1.4.3. ```scala scala> val df = Seq(Array.empty[Float]).toDF() scala> df.write.format("orc").save("/tmp/floatarray") scala> spark.read.orc("/tmp/floatarray") res1: org.apache.spark.sql.DataFrame = [value: array] scala> spark.read.orc("/tmp/floatarray").show() 18/02/12 22:09:10 ERROR Executor: Exception in task 0.0 in stage 1.0 (TID 1) java.io.IOException: Error reading file: file:/tmp/floatarray/part-00000-9c0b461b-4df1-4c23-aac1-3e4f349ac7d6-c000.snappy.orc at org.apache.orc.impl.RecordReaderImpl.nextBatch(RecordReaderImpl.java:1191) at org.apache.orc.mapreduce.OrcMapreduceRecordReader.ensureBatch(OrcMapreduceRecordReader.java:78) ... Caused by: java.io.EOFException: Read past EOF for compressed stream Stream for column 2 kind DATA position: 0 length: 0 range: 0 offset: 0 limit: 0 ``` ## How was this patch tested? Pass the Jenkins test. Author: Dongjoon Hyun Closes #20511 from dongjoon-hyun/SPARK-23340. --- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- pom.xml | 6 +----- .../sql/execution/datasources/orc/OrcSourceSuite.scala | 9 +++++++++ .../apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala | 10 ++++++++++ 5 files changed, 24 insertions(+), 9 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 99031384aa22e..ed310507d14ed 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -157,8 +157,8 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.1-nohive.jar -orc-mapreduce-1.4.1-nohive.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index cf8d2789b7ee9..04dec04796af4 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -158,8 +158,8 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.1-nohive.jar -orc-mapreduce-1.4.1-nohive.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index de949b94d676c..ac30107066389 100644 --- a/pom.xml +++ b/pom.xml @@ -130,7 +130,7 @@ 1.2.1 10.12.1.1 1.8.2 - 1.4.1 + 1.4.3 nohive 1.6.0 9.3.20.v20170531 @@ -1740,10 +1740,6 @@ org.apache.hive hive-storage-api - - io.airlift - slice - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 6f5f2fd795f74..523f7cf77e103 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -160,6 +160,15 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { } } } + + test("SPARK-23340 Empty float/double array columns raise EOFException") { + Seq(Seq(Array.empty[Float]).toDF(), Seq(Array.empty[Double]).toDF()).foreach { df => + withTempPath { path => + df.write.format("orc").save(path.getCanonicalPath) + checkAnswer(spark.read.orc(path.getCanonicalPath), df) + } + } + } } class OrcSourceSuite extends OrcSuite with SharedSQLContext { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala index 92b2f069cacd6..597b0f56a55e4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala @@ -208,4 +208,14 @@ class HiveOrcQuerySuite extends OrcQueryTest with TestHiveSingleton { } } } + + test("SPARK-23340 Empty float/double array columns raise EOFException") { + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "false") { + withTable("spark_23340") { + sql("CREATE TABLE spark_23340(a array, b array) STORED AS ORC") + sql("INSERT INTO spark_23340 VALUES (array(), array())") + checkAnswer(spark.table("spark_23340"), Seq(Row(Array.empty[Float], Array.empty[Double]))) + } + } + } } From f5850e78924d03448ad243cdd32b24c3fe0ea8af Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 20 Feb 2018 13:33:03 +0800 Subject: [PATCH 371/774] [SPARK-23457][SQL] Register task completion listeners first in ParquetFileFormat ## What changes were proposed in this pull request? ParquetFileFormat leaks opened files in some cases. This PR prevents that by registering task completion listers first before initialization. - [spark-branch-2.3-test-sbt-hadoop-2.7](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-branch-2.3-test-sbt-hadoop-2.7/205/testReport/org.apache.spark.sql/FileBasedDataSourceSuite/_It_is_not_a_test_it_is_a_sbt_testing_SuiteSelector_/) - [spark-master-test-sbt-hadoop-2.6](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/4228/testReport/junit/org.apache.spark.sql.execution.datasources.parquet/ParquetQuerySuite/_It_is_not_a_test_it_is_a_sbt_testing_SuiteSelector_/) ``` Caused by: sbt.ForkMain$ForkError: java.lang.Throwable: null at org.apache.spark.DebugFilesystem$.addOpenStream(DebugFilesystem.scala:36) at org.apache.spark.DebugFilesystem.open(DebugFilesystem.scala:70) at org.apache.hadoop.fs.FileSystem.open(FileSystem.java:769) at org.apache.parquet.hadoop.ParquetFileReader.(ParquetFileReader.java:538) at org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.initialize(SpecificParquetRecordReaderBase.java:149) at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.initialize(VectorizedParquetRecordReader.java:133) at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anonfun$buildReaderWithPartitionValues$1.apply(ParquetFileFormat.scala:400) at ``` ## How was this patch tested? Manual. The following test case generates the same leakage. ```scala test("SPARK-23457 Register task completion listeners first in ParquetFileFormat") { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE.key -> s"${Int.MaxValue}") { withTempDir { dir => val basePath = dir.getCanonicalPath Seq(0).toDF("a").write.format("parquet").save(new Path(basePath, "first").toString) Seq(1).toDF("a").write.format("parquet").save(new Path(basePath, "second").toString) val df = spark.read.parquet( new Path(basePath, "first").toString, new Path(basePath, "second").toString) val e = intercept[SparkException] { df.collect() } assert(e.getCause.isInstanceOf[OutOfMemoryError]) } } } ``` Author: Dongjoon Hyun Closes #20619 from dongjoon-hyun/SPARK-23390. --- .../parquet/ParquetFileFormat.scala | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index ba69f9a26c968..476bd02374364 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -395,16 +395,21 @@ class ParquetFileFormat ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) } val taskContext = Option(TaskContext.get()) - val parquetReader = if (enableVectorizedReader) { + if (enableVectorizedReader) { val vectorizedReader = new VectorizedParquetRecordReader( convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity) + val iter = new RecordReaderIterator(vectorizedReader) + // SPARK-23457 Register a task completion lister before `initialization`. + taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) if (returningBatch) { vectorizedReader.enableReturningBatches() } - vectorizedReader + + // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. + iter.asInstanceOf[Iterator[InternalRow]] } else { logDebug(s"Falling back to parquet-mr") // ParquetRecordReader returns UnsafeRow @@ -414,18 +419,11 @@ class ParquetFileFormat } else { new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz)) } + val iter = new RecordReaderIterator(reader) + // SPARK-23457 Register a task completion lister before `initialization`. + taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) reader.initialize(split, hadoopAttemptContext) - reader - } - val iter = new RecordReaderIterator(parquetReader) - taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) - - // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. - if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] && - enableVectorizedReader) { - iter.asInstanceOf[Iterator[InternalRow]] - } else { val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes val joinedRow = new JoinedRow() val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema) From 651b0277fe989119932d5ae1ef729c9768aa018d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 20 Feb 2018 13:56:38 +0800 Subject: [PATCH 372/774] [SPARK-23436][SQL] Infer partition as Date only if it can be casted to Date ## What changes were proposed in this pull request? Before the patch, Spark could infer as Date a partition value which cannot be casted to Date (this can happen when there are extra characters after a valid date, like `2018-02-15AAA`). When this happens and the input format has metadata which define the schema of the table, then `null` is returned as a value for the partition column, because the `cast` operator used in (`PartitioningAwareFileIndex.inferPartitioning`) is unable to convert the value. The PR checks in the partition inference that values can be casted to Date and Timestamp, in order to infer that datatype to them. ## How was this patch tested? added UT Author: Marco Gaido Closes #20621 from mgaido91/SPARK-23436. --- .../datasources/PartitioningUtils.scala | 40 ++++++++++++++----- .../ParquetPartitionDiscoverySuite.scala | 14 +++++++ 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 472bf82d3604d..379acb67f7c71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -407,6 +407,34 @@ object PartitioningUtils { Literal(bigDecimal) } + val dateTry = Try { + // try and parse the date, if no exception occurs this is a candidate to be resolved as + // DateType + DateTimeUtils.getThreadLocalDateFormat.parse(raw) + // SPARK-23436: Casting the string to date may still return null if a bad Date is provided. + // This can happen since DateFormat.parse may not use the entire text of the given string: + // so if there are extra-characters after the date, it returns correctly. + // We need to check that we can cast the raw string since we later can use Cast to get + // the partition values with the right DataType (see + // org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex.inferPartitioning) + val dateValue = Cast(Literal(raw), DateType).eval() + // Disallow DateType if the cast returned null + require(dateValue != null) + Literal.create(dateValue, DateType) + } + + val timestampTry = Try { + val unescapedRaw = unescapePathName(raw) + // try and parse the date, if no exception occurs this is a candidate to be resolved as + // TimestampType + DateTimeUtils.getThreadLocalTimestampFormat(timeZone).parse(unescapedRaw) + // SPARK-23436: see comment for date + val timestampValue = Cast(Literal(unescapedRaw), TimestampType, Some(timeZone.getID)).eval() + // Disallow TimestampType if the cast returned null + require(timestampValue != null) + Literal.create(timestampValue, TimestampType) + } + if (typeInference) { // First tries integral types Try(Literal.create(Integer.parseInt(raw), IntegerType)) @@ -415,16 +443,8 @@ object PartitioningUtils { // Then falls back to fractional types .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) // Then falls back to date/timestamp types - .orElse(Try( - Literal.create( - DateTimeUtils.getThreadLocalTimestampFormat(timeZone) - .parse(unescapePathName(raw)).getTime * 1000L, - TimestampType))) - .orElse(Try( - Literal.create( - DateTimeUtils.millisToDays( - DateTimeUtils.getThreadLocalDateFormat.parse(raw).getTime), - DateType))) + .orElse(timestampTry) + .orElse(dateTry) // Then falls back to string .getOrElse { if (raw == DEFAULT_PARTITION_NAME) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index d4902641e335f..edb3da904d10d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -1120,4 +1120,18 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Row(3, BigDecimal("2" * 30)) :: Nil) } } + + test("SPARK-23436: invalid Dates should be inferred as String in partition inference") { + withTempPath { path => + val data = Seq(("1", "2018-01", "2018-01-01-04", "test")) + .toDF("id", "date_month", "date_hour", "data") + + data.write.partitionBy("date_month", "date_hour").parquet(path.getAbsolutePath) + val input = spark.read.parquet(path.getAbsolutePath).select("id", + "date_month", "date_hour", "data") + + assert(input.schema.sameType(input.schema)) + checkAnswer(input, data) + } + } } From aadf9535b4a11b42fd9d72f636576d2da0766199 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 20 Feb 2018 16:04:22 +0800 Subject: [PATCH 373/774] [SPARK-23203][SQL] DataSourceV2: Use immutable logical plans. ## What changes were proposed in this pull request? SPARK-23203: DataSourceV2 should use immutable catalyst trees instead of wrapping a mutable DataSourceV2Reader. This commit updates DataSourceV2Relation and consolidates much of the DataSourceV2 API requirements for the read path in it. Instead of wrapping a reader that changes, the relation lazily produces a reader from its configuration. This commit also updates the predicate and projection push-down. Instead of the implementation from SPARK-22197, this reuses the rule matching from the Hive and DataSource read paths (using `PhysicalOperation`) and copies most of the implementation of `SparkPlanner.pruneFilterProject`, with updates for DataSourceV2. By reusing the implementation from other read paths, this should have fewer regressions from other read paths and is less code to maintain. The new push-down rules also supports the following edge cases: * The output of DataSourceV2Relation should be what is returned by the reader, in case the reader can only partially satisfy the requested schema projection * The requested projection passed to the DataSourceV2Reader should include filter columns * The push-down rule may be run more than once if filters are not pushed through projections ## How was this patch tested? Existing push-down and read tests. Author: Ryan Blue Closes #20387 from rdblue/SPARK-22386-push-down-immutable-trees. --- .../kafka010/KafkaContinuousSourceSuite.scala | 19 +- .../sql/kafka010/KafkaContinuousTest.scala | 4 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 4 +- .../apache/spark/sql/DataFrameReader.scala | 41 +--- .../datasources/v2/DataSourceV2Relation.scala | 212 +++++++++++++++++- .../datasources/v2/DataSourceV2Strategy.scala | 7 +- .../v2/PushDownOperatorsToDataSource.scala | 159 ++++--------- .../continuous/ContinuousExecution.scala | 2 +- .../sql/sources/v2/DataSourceV2Suite.scala | 2 +- .../spark/sql/streaming/StreamTest.scala | 6 +- 10 files changed, 269 insertions(+), 187 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index a7083fa4e3417..f679e9bfc0450 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,20 +17,9 @@ package org.apache.spark.sql.kafka010 -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.time.SpanSugar._ -import scala.collection.mutable -import scala.util.Random - -import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.streaming.Trigger // Run tests in KafkaSourceSuiteBase in continuous execution mode. class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest @@ -71,7 +60,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index 5a1a14f7a307a..48ac3fc1e8f9d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.streaming.Trigger @@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index ed4ecfeafa972..89c9ef4cc73b5 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -35,7 +35,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext import org.apache.spark.sql.{Dataset, ForeachWriter} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} @@ -119,7 +119,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + case StreamingDataSourceV2Relation(_, reader: KafkaContinuousReader) => reader } }) }.distinct diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index fcaf8d618c168..4274f120a375a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.{DataSourceV2, ReadSupport, ReadSupportWithSchema} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -189,39 +189,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() - val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = sparkSession.sessionState.conf)).asJava) - - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading - // the dataframe as a v1 source. - val reader = (ds, userSpecifiedSchema) match { - case (ds: ReadSupportWithSchema, Some(schema)) => - ds.createReader(schema, options) - - case (ds: ReadSupport, None) => - ds.createReader(options) - - case (ds: ReadSupportWithSchema, None) => - throw new AnalysisException(s"A schema needs to be specified when using $ds.") - - case (ds: ReadSupport, Some(schema)) => - val reader = ds.createReader(options) - if (reader.readSchema() != schema) { - throw new AnalysisException(s"$ds does not allow user-specified schemas.") - } - reader - - case _ => null // fall back to v1 - } + val ds = cls.newInstance().asInstanceOf[DataSourceV2] + if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + ds = ds, conf = sparkSession.sessionState.conf) + Dataset.ofRows(sparkSession, DataSourceV2Relation.create( + ds, extraOptions.toMap ++ sessionOptions, + userSpecifiedSchema = userSpecifiedSchema)) - if (reader == null) { - loadV1Source(paths: _*) } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + loadV1Source(paths: _*) } } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 38f6b15224788..a98dd4866f82a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,17 +17,80 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics} +import org.apache.spark.sql.types.StructType case class DataSourceV2Relation( - output: Seq[AttributeReference], - reader: DataSourceReader) - extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { + source: DataSourceV2, + options: Map[String, String], + projection: Seq[AttributeReference], + filters: Option[Seq[Expression]] = None, + userSpecifiedSchema: Option[StructType] = None) extends LeafNode with MultiInstanceRelation { + + import DataSourceV2Relation._ + + override def simpleString: String = { + s"DataSourceV2Relation(source=${source.name}, " + + s"schema=[${output.map(a => s"$a ${a.dataType.simpleString}").mkString(", ")}], " + + s"filters=[${pushedFilters.mkString(", ")}], options=$options)" + } + + override lazy val schema: StructType = reader.readSchema() + + override lazy val output: Seq[AttributeReference] = { + // use the projection attributes to avoid assigning new ids. fields that are not projected + // will be assigned new ids, which is okay because they are not projected. + val attrMap = projection.map(a => a.name -> a).toMap + schema.map(f => attrMap.getOrElse(f.name, + AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())) + } + + private lazy val v2Options: DataSourceOptions = makeV2Options(options) + + lazy val ( + reader: DataSourceReader, + unsupportedFilters: Seq[Expression], + pushedFilters: Seq[Expression]) = { + val newReader = userSpecifiedSchema match { + case Some(s) => + source.asReadSupportWithSchema.createReader(s, v2Options) + case _ => + source.asReadSupport.createReader(v2Options) + } + + DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] + val (remainingFilters, pushedFilters) = filters match { + case Some(filterSeq) => + DataSourceV2Relation.pushFilters(newReader, filterSeq) + case _ => + (Nil, Nil) + } + + (newReader, remainingFilters, pushedFilters) + } + + override def doCanonicalize(): LogicalPlan = { + val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation] + + // override output with canonicalized output to avoid attempting to configure a reader + val canonicalOutput: Seq[AttributeReference] = this.output + .map(a => QueryPlan.normalizeExprId(a, projection)) + + new DataSourceV2Relation(c.source, c.options, c.projection) { + override lazy val output: Seq[AttributeReference] = canonicalOutput + } + } override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => @@ -37,7 +100,9 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - copy(output = output.map(_.newInstance())) + // projection is used to maintain id assignment. + // if projection is not set, use output so the copy is not equal to the original + copy(projection = projection.map(_.newInstance())) } } @@ -45,14 +110,137 @@ case class DataSourceV2Relation( * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical * to the non-streaming relation. */ -class StreamingDataSourceV2Relation( +case class StreamingDataSourceV2Relation( output: Seq[AttributeReference], - reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { + reader: DataSourceReader) + extends LeafNode with DataSourceReaderHolder with MultiInstanceRelation { override def isStreaming: Boolean = true + + override def canEqual(other: Any): Boolean = other.isInstanceOf[StreamingDataSourceV2Relation] + + override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) + + override def computeStats(): Statistics = reader match { + case r: SupportsReportStatistics => + Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + case _ => + Statistics(sizeInBytes = conf.defaultSizeInBytes) + } } object DataSourceV2Relation { - def apply(reader: DataSourceReader): DataSourceV2Relation = { - new DataSourceV2Relation(reader.readSchema().toAttributes, reader) + private implicit class SourceHelpers(source: DataSourceV2) { + def asReadSupport: ReadSupport = { + source match { + case support: ReadSupport => + support + case _: ReadSupportWithSchema => + // this method is only called if there is no user-supplied schema. if there is no + // user-supplied schema and ReadSupport was not implemented, throw a helpful exception. + throw new AnalysisException(s"Data source requires a user-supplied schema: $name") + case _ => + throw new AnalysisException(s"Data source is not readable: $name") + } + } + + def asReadSupportWithSchema: ReadSupportWithSchema = { + source match { + case support: ReadSupportWithSchema => + support + case _: ReadSupport => + throw new AnalysisException( + s"Data source does not support user-supplied schema: $name") + case _ => + throw new AnalysisException(s"Data source is not readable: $name") + } + } + + def name: String = { + source match { + case registered: DataSourceRegister => + registered.shortName() + case _ => + source.getClass.getSimpleName + } + } + } + + private def makeV2Options(options: Map[String, String]): DataSourceOptions = { + new DataSourceOptions(options.asJava) + } + + private def schema( + source: DataSourceV2, + v2Options: DataSourceOptions, + userSchema: Option[StructType]): StructType = { + val reader = userSchema match { + // TODO: remove this case because it is confusing for users + case Some(s) if !source.isInstanceOf[ReadSupportWithSchema] => + val reader = source.asReadSupport.createReader(v2Options) + if (reader.readSchema() != s) { + throw new AnalysisException(s"${source.name} does not allow user-specified schemas.") + } + reader + case Some(s) => + source.asReadSupportWithSchema.createReader(s, v2Options) + case _ => + source.asReadSupport.createReader(v2Options) + } + reader.readSchema() + } + + def create( + source: DataSourceV2, + options: Map[String, String], + filters: Option[Seq[Expression]] = None, + userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { + val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes + DataSourceV2Relation(source, options, projection, filters, + // if the source does not implement ReadSupportWithSchema, then the userSpecifiedSchema must + // be equal to the reader's schema. the schema method enforces this. because the user schema + // and the reader's schema are identical, drop the user schema. + if (source.isInstanceOf[ReadSupportWithSchema]) userSpecifiedSchema else None) + } + + private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { + reader match { + case projectionSupport: SupportsPushDownRequiredColumns => + projectionSupport.pruneColumns(struct) + case _ => + } + } + + private def pushFilters( + reader: DataSourceReader, + filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + reader match { + case catalystFilterSupport: SupportsPushDownCatalystFilters => + ( + catalystFilterSupport.pushCatalystFilters(filters.toArray), + catalystFilterSupport.pushedCatalystFilters() + ) + + case filterSupport: SupportsPushDownFilters => + // A map from original Catalyst expressions to corresponding translated data source + // filters. If a predicate is not in this map, it means it cannot be pushed down. + val translatedMap: Map[Expression, Filter] = filters.flatMap { p => + DataSourceStrategy.translateFilter(p).map(f => p -> f) + }.toMap + + // Catalyst predicate expressions that cannot be converted to data source filters. + val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) + + // Data source filters that cannot be pushed down. An unhandled filter means + // the data source cannot guarantee the rows returned can pass the filter. + // As a result we must return it so Spark can plan an extra filter operator. + val unhandledFilters = filterSupport.pushFilters(translatedMap.values.toArray).toSet + val (unhandledPredicates, pushedPredicates) = translatedMap.partition { case (_, f) => + unhandledFilters.contains(f) + } + + (nonConvertiblePredicates ++ unhandledPredicates.keys, pushedPredicates.keys.toSeq) + + case _ => (filters, Nil) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index df5b524485f54..c4e7644683c36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,8 +23,11 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case DataSourceV2Relation(output, reader) => - DataSourceV2ScanExec(output, reader) :: Nil + case relation: DataSourceV2Relation => + DataSourceV2ScanExec(relation.output, relation.reader) :: Nil + + case relation: StreamingDataSourceV2Relation => + DataSourceV2ScanExec(relation.output, relation.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 1ca6cbf061b4e..f23d228567241 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -17,130 +17,55 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeSet, Expression, NamedExpression, PredicateHelper} -import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources -import org.apache.spark.sql.sources.v2.reader._ -/** - * Pushes down various operators to the underlying data source for better performance. Operators are - * being pushed down with a specific order. As an example, given a LIMIT has a FILTER child, you - * can't push down LIMIT if FILTER is not completely pushed down. When both are pushed down, the - * data source should execute FILTER before LIMIT. And required columns are calculated at the end, - * because when more operators are pushed down, we may need less columns at Spark side. - */ -object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHelper { - override def apply(plan: LogicalPlan): LogicalPlan = { - // Note that, we need to collect the target operator along with PROJECT node, as PROJECT may - // appear in many places for column pruning. - // TODO: Ideally column pruning should be implemented via a plan property that is propagated - // top-down, then we can simplify the logic here and only collect target operators. - val filterPushed = plan transformUp { - case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => - val (candidates, nonDeterministic) = - splitConjunctivePredicates(condition).partition(_.deterministic) - - val stayUpFilters: Seq[Expression] = reader match { - case r: SupportsPushDownCatalystFilters => - r.pushCatalystFilters(candidates.toArray) - - case r: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, sources.Filter] = candidates.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = candidates.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet - val unhandledPredicates = translatedMap.filter { case (_, f) => - unhandledFilters.contains(f) - }.keys - - nonConvertiblePredicates ++ unhandledPredicates - - case _ => candidates - } - - val filterCondition = (stayUpFilters ++ nonDeterministic).reduceLeftOption(And) - val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r) - if (withFilter.output == fields) { - withFilter - } else { - Project(fields, withFilter) - } - } - - // TODO: add more push down rules. - - val columnPruned = pushDownRequiredColumns(filterPushed, filterPushed.outputSet) - // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. - RemoveRedundantProject(columnPruned) - } - - // TODO: nested fields pruning - private def pushDownRequiredColumns( - plan: LogicalPlan, requiredByParent: AttributeSet): LogicalPlan = { - plan match { - case p @ Project(projectList, child) => - val required = projectList.flatMap(_.references) - p.copy(child = pushDownRequiredColumns(child, AttributeSet(required))) - - case f @ Filter(condition, child) => - val required = requiredByParent ++ condition.references - f.copy(child = pushDownRequiredColumns(child, required)) - - case relation: DataSourceV2Relation => relation.reader match { - case reader: SupportsPushDownRequiredColumns => - // TODO: Enable the below assert after we make `DataSourceV2Relation` immutable. Fow now - // it's possible that the mutable reader being updated by someone else, and we need to - // always call `reader.pruneColumns` here to correct it. - // assert(relation.output.toStructType == reader.readSchema(), - // "Schema of data source reader does not match the relation plan.") - - val requiredColumns = relation.output.filter(requiredByParent.contains) - reader.pruneColumns(requiredColumns.toStructType) - - val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap - val newOutput = reader.readSchema().map(_.name).map(nameToAttr) - relation.copy(output = newOutput) - - case _ => relation +object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { + override def apply( + plan: LogicalPlan): LogicalPlan = plan transformUp { + // PhysicalOperation guarantees that filters are deterministic; no need to check + case PhysicalOperation(project, newFilters, relation : DataSourceV2Relation) => + // merge the filters + val filters = relation.filters match { + case Some(existing) => + existing ++ newFilters + case _ => + newFilters } - // TODO: there may be more operators that can be used to calculate the required columns. We - // can add more and more in the future. - case _ => plan.mapChildren(c => pushDownRequiredColumns(c, c.outputSet)) - } - } - - /** - * Finds a Filter node(with an optional Project child) above data source relation. - */ - object FilterAndProject { - // returns the project list, the filter condition and the data source relation. - def unapply(plan: LogicalPlan) - : Option[(Seq[NamedExpression], Expression, DataSourceV2Relation)] = plan match { + val projectAttrs = project.map(_.toAttribute) + val projectSet = AttributeSet(project.flatMap(_.references)) + val filterSet = AttributeSet(filters.flatMap(_.references)) + + val projection = if (filterSet.subsetOf(projectSet) && + AttributeSet(projectAttrs) == projectSet) { + // When the required projection contains all of the filter columns and column pruning alone + // can produce the required projection, push the required projection. + // A final projection may still be needed if the data source produces a different column + // order or if it cannot prune all of the nested columns. + projectAttrs + } else { + // When there are filter columns not already in the required projection or when the required + // projection is more complicated than column pruning, base column pruning on the set of + // all columns needed by both. + (projectSet ++ filterSet).toSeq + } - case Filter(condition, r: DataSourceV2Relation) => Some((r.output, condition, r)) + val newRelation = relation.copy( + projection = projection.asInstanceOf[Seq[AttributeReference]], + filters = Some(filters)) - case Filter(condition, Project(fields, r: DataSourceV2Relation)) - if fields.forall(_.deterministic) => - val attributeMap = AttributeMap(fields.map(e => e.toAttribute -> e)) - val substituted = condition.transform { - case a: Attribute => attributeMap.getOrElse(a, a) - } - Some((fields, substituted, r)) + // Add a Filter for any filters that could not be pushed + val unpushedFilter = newRelation.unsupportedFilters.reduceLeftOption(And) + val filtered = unpushedFilter.map(Filter(_, newRelation)).getOrElse(newRelation) - case _ => None - } + // Add a Project to ensure the output matches the required projection + if (newRelation.output != projectAttrs) { + Project(project, filtered) + } else { + filtered + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index c3294d64b10cd..2c1d6c509d21b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -201,7 +201,7 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case DataSourceV2Relation(_, r: ContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: ContinuousReader) => r }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index a1c87fb15542c..1157a350461d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -146,7 +146,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => withClue(cls.getName) { val e = intercept[AnalysisException](spark.read.format(cls.getName).load()) - assert(e.message.contains("A schema needs to be specified")) + assert(e.message.contains("requires a user-supplied schema")) val schema = new StructType().add("i", "int").add("s", "string") val df = spark.read.format(cls.getName).schema(schema).load() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 37fe595529baf..159dd0ecb5902 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,9 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -605,7 +605,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case StreamingExecutionRelation(s, _) => s - case DataSourceV2Relation(_, r) => r + case StreamingDataSourceV2Relation(_, r) => r } .zipWithIndex .find(_._1 == source) From 862fa697d829cdddf0f25e5613c91b040f9d9652 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 20 Feb 2018 20:26:26 +0900 Subject: [PATCH 374/774] [SPARK-23240][PYTHON] Better error message when extraneous data in pyspark.daemon's stdout ## What changes were proposed in this pull request? Print more helpful message when daemon module's stdout is empty or contains a bad port number. ## How was this patch tested? Manually recreated the environmental issues that caused the mysterious exceptions at one site. Tested that the expected messages are logged. Also, ran all scala unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bruce Robbins Closes #20424 from bersprockets/SPARK-23240_prop2. --- .../api/python/PythonWorkerFactory.scala | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 30976ac752a8a..2340580b54f67 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.python -import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStreamWriter} +import java.io.{DataInputStream, DataOutputStream, EOFException, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} import java.nio.charset.StandardCharsets import java.util.Arrays @@ -182,7 +182,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", daemonModule)) + val command = Arrays.asList(pythonExec, "-m", daemonModule) + val pb = new ProcessBuilder(command) val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) @@ -191,7 +192,29 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String daemon = pb.start() val in = new DataInputStream(daemon.getInputStream) - daemonPort = in.readInt() + try { + daemonPort = in.readInt() + } catch { + case _: EOFException => + throw new SparkException(s"No port number in $daemonModule's stdout") + } + + // test that the returned port number is within a valid range. + // note: this does not cover the case where the port number + // is arbitrary data but is also coincidentally within range + if (daemonPort < 1 || daemonPort > 0xffff) { + val exceptionMessage = f""" + |Bad data in $daemonModule's standard output. Invalid port number: + | $daemonPort (0x$daemonPort%08x) + |Python command to execute the daemon was: + | ${command.asScala.mkString(" ")} + |Check that you don't have any unexpected modules or libraries in + |your PYTHONPATH: + | $pythonPath + |Also, check if you have a sitecustomize.py module in your python path, + |or in your python installation, that is printing to standard output""" + throw new SparkException(exceptionMessage.stripMargin) + } // Redirect daemon stdout and stderr redirectStreamsToStderr(in, daemon.getErrorStream) From 189f56f3dcdad4d997248c01aa5490617f018bd0 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 20 Feb 2018 07:51:30 -0600 Subject: [PATCH 375/774] [SPARK-23383][BUILD][MINOR] Make a distribution should exit with usage while detecting wrong options ## What changes were proposed in this pull request? ```shell ./dev/make-distribution.sh --name ne-1.0.0-SNAPSHOT xyz --tgz -Phadoop-2.7 +++ dirname ./dev/make-distribution.sh ++ cd ./dev/.. ++ pwd + SPARK_HOME=/Users/Kent/Documents/spark + DISTDIR=/Users/Kent/Documents/spark/dist + MAKE_TGZ=false + MAKE_PIP=false + MAKE_R=false + NAME=none + MVN=/Users/Kent/Documents/spark/build/mvn + (( 5 )) + case $1 in + NAME=ne-1.0.0-SNAPSHOT + shift + shift + (( 3 )) + case $1 in + break + '[' -z /Users/Kent/.jenv/candidates/java/current ']' + '[' -z /Users/Kent/.jenv/candidates/java/current ']' ++ command -v git + '[' /usr/local/bin/git ']' ++ git rev-parse --short HEAD + GITREV=98ea6a7 + '[' '!' -z 98ea6a7 ']' + GITREVSTRING=' (git revision 98ea6a7)' + unset GITREV ++ command -v /Users/Kent/Documents/spark/build/mvn + '[' '!' /Users/Kent/Documents/spark/build/mvn ']' ++ /Users/Kent/Documents/spark/build/mvn help:evaluate -Dexpression=project.version xyz --tgz -Phadoop-2.7 ++ grep -v INFO ++ tail -n 1 + VERSION=' -X,--debug Produce execution debug output' ``` It is better to declare the mistakes and exit with usage than `break` ## How was this patch tested? manually cc srowen Author: Kent Yao Closes #20571 from yaooqinn/SPARK-23383. --- dev/make-distribution.sh | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 8b02446b2f15f..84233c64caa9c 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -72,9 +72,17 @@ while (( "$#" )); do --help) exit_with_usage ;; - *) + --*) + echo "Error: $1 is not supported" + exit_with_usage + ;; + -*) break ;; + *) + echo "Error: $1 is not supported" + exit_with_usage + ;; esac shift done From 83c008762af444eef73d835eb6f506ecf5aebc17 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 20 Feb 2018 09:14:56 -0800 Subject: [PATCH 376/774] [SPARK-23456][SPARK-21783] Turn on `native` ORC impl and PPD by default ## What changes were proposed in this pull request? Apache Spark 2.3 introduced `native` ORC supports with vectorization and many fixes. However, it's shipped as a not-default option. This PR enables `native` ORC implementation and predicate-pushdown by default for Apache Spark 2.4. We will improve and stabilize ORC data source before Apache Spark 2.4. And, eventually, Apache Spark will drop old Hive-based ORC code. ## How was this patch tested? Pass the Jenkins with existing tests. Author: Dongjoon Hyun Closes #20634 from dongjoon-hyun/SPARK-23456. --- docs/sql-programming-guide.md | 6 +++++- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 91e43678481d6..c37c338a134f3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1018,7 +1018,7 @@ the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also spark.sql.orc.impl hive - The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1. + The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4. `hive` means the ORC library in Hive 1.2.1. spark.sql.orc.enableVectorizedReader @@ -1797,6 +1797,10 @@ working with timestamps in `pandas_udf`s to get the best performance, see # Migration Guide +## Upgrading From Spark SQL 2.3 to 2.4 + + - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. + ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e75e1d66ebcf8..ce3f94618edeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -399,11 +399,11 @@ object SQLConf { val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl") .doc("When native, use the native version of ORC support instead of the ORC library in Hive " + - "1.2.1. It is 'hive' by default.") + "1.2.1. It is 'hive' by default prior to Spark 2.4.") .internal() .stringConf .checkValues(Set("hive", "native")) - .createWithDefault("hive") + .createWithDefault("native") val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader") .doc("Enables vectorized orc decoding.") @@ -426,7 +426,7 @@ object SQLConf { val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + From 3e48f3b9ee7645e4218ad3ff7559e578d4bd9667 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 20 Feb 2018 16:02:44 -0800 Subject: [PATCH 377/774] [SPARK-23434][SQL] Spark should not warn `metadata directory` for a HDFS file path ## What changes were proposed in this pull request? In a kerberized cluster, when Spark reads a file path (e.g. `people.json`), it warns with a wrong warning message during looking up `people.json/_spark_metadata`. The root cause of this situation is the difference between `LocalFileSystem` and `DistributedFileSystem`. `LocalFileSystem.exists()` returns `false`, but `DistributedFileSystem.exists` raises `org.apache.hadoop.security.AccessControlException`. ```scala scala> spark.version res0: String = 2.4.0-SNAPSHOT scala> spark.read.json("file:///usr/hdp/current/spark-client/examples/src/main/resources/people.json").show +----+-------+ | age| name| +----+-------+ |null|Michael| | 30| Andy| | 19| Justin| +----+-------+ scala> spark.read.json("hdfs:///tmp/people.json") 18/02/15 05:00:48 WARN streaming.FileStreamSink: Error while looking for metadata directory. 18/02/15 05:00:48 WARN streaming.FileStreamSink: Error while looking for metadata directory. ``` After this PR, ```scala scala> spark.read.json("hdfs:///tmp/people.json").show +----+-------+ | age| name| +----+-------+ |null|Michael| | 30| Andy| | 19| Justin| +----+-------+ ``` ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #20616 from dongjoon-hyun/SPARK-23434. --- .../spark/sql/execution/streaming/FileStreamSink.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 2715fa93d0e98..87a17cebdc10c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -42,9 +42,11 @@ object FileStreamSink extends Logging { try { val hdfsPath = new Path(singlePath) val fs = hdfsPath.getFileSystem(hadoopConf) - val metadataPath = new Path(hdfsPath, metadataDir) - val res = fs.exists(metadataPath) - res + if (fs.isDirectory(hdfsPath)) { + fs.exists(new Path(hdfsPath, metadataDir)) + } else { + false + } } catch { case NonFatal(e) => logWarning(s"Error while looking for metadata directory.") From 2ba77ed9e51922303e3c3533e368b95788bd7de5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 20 Feb 2018 17:54:06 -0800 Subject: [PATCH 378/774] [SPARK-23470][UI] Use first attempt of last stage to define job description. This is much faster than finding out what the last attempt is, and the data should be the same. There's room for improvement in this page (like only loading data for the jobs being shown, instead of loading all available jobs and sorting them), but this should bring performance on par with the 2.2 version. Author: Marcelo Vanzin Closes #20644 from vanzin/SPARK-23470. --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index a9265d4dbcdfb..ac83de10f9237 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -1048,7 +1048,7 @@ private[ui] object ApiHelper { } def lastStageNameAndDescription(store: AppStatusStore, job: JobData): (String, String) = { - val stage = store.asOption(store.lastStageAttempt(job.stageIds.max)) + val stage = store.asOption(store.stageAttempt(job.stageIds.max, 0)) (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) } From 6d398c05cbad69aa9093429e04ae44c73b81cd5a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 20 Feb 2018 18:06:21 -0800 Subject: [PATCH 379/774] [SPARK-23468][CORE] Stringify auth secret before storing it in credentials. The secret is used as a string in many parts of the code, so it has to be turned into a hex string to avoid issues such as the random byte sequence not containing a valid UTF8 sequence. Author: Marcelo Vanzin Closes #20643 from vanzin/SPARK-23468. --- core/src/main/scala/org/apache/spark/SecurityManager.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 4c1dbe3ffb4ad..5b15a1c57779d 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -541,7 +541,8 @@ private[spark] class SecurityManager( rnd.nextBytes(secretBytes) val creds = new Credentials() - creds.addSecretKey(SECRET_LOOKUP_KEY, secretBytes) + val secretStr = HashCodes.fromBytes(secretBytes).toString() + creds.addSecretKey(SECRET_LOOKUP_KEY, secretStr.getBytes(UTF_8)) UserGroupInformation.getCurrentUser().addCredentials(creds) } From 601d653bff9160db8477f86d961e609fc2190237 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 20 Feb 2018 18:16:10 -0800 Subject: [PATCH 380/774] [SPARK-23454][SS][DOCS] Added trigger information to the Structured Streaming programming guide ## What changes were proposed in this pull request? - Added clear information about triggers - Made the semantics guarantees of watermarks more clear for streaming aggregations and stream-stream joins. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Tathagata Das Closes #20631 from tdas/SPARK-23454. --- .../structured-streaming-programming-guide.md | 214 +++++++++++++++++- 1 file changed, 207 insertions(+), 7 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 48d6d0b542cc0..9a83f157452ad 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -904,7 +904,7 @@ windowedCounts <- count(
    -### Handling Late Data and Watermarking +#### Handling Late Data and Watermarking Now consider what happens if one of the events arrives late to the application. For example, say, a word generated at 12:04 (i.e. event time) could be received by the application at 12:11. The application should use the time 12:04 instead of 12:11 @@ -925,7 +925,9 @@ specifying the event time column and the threshold on how late the data is expec event time. For a specific window starting at time `T`, the engine will maintain state and allow late data to update the state until `(max event time seen by the engine - late threshold > T)`. In other words, late data within the threshold will be aggregated, -but data later than the threshold will be dropped. Let's understand this with an example. We can +but data later than the threshold will start getting dropped +(see [later]((#semantic-guarantees-of-aggregation-with-watermarking)) +in the section for the exact guarantees). Let's understand this with an example. We can easily define watermarking on the previous example using `withWatermark()` as shown below.
    @@ -1031,7 +1033,9 @@ then drops intermediate state of a window < watermark, and appends the final counts to the Result Table/sink. For example, the final counts of window `12:00 - 12:10` is appended to the Result Table only after the watermark is updated to `12:11`. -**Conditions for watermarking to clean aggregation state** +##### Conditions for watermarking to clean aggregation state +{:.no_toc} + It is important to note that the following conditions must be satisfied for the watermarking to clean the state in aggregation queries *(as of Spark 2.1.1, subject to change in the future)*. @@ -1051,6 +1055,16 @@ from the aggregation column. For example, `df.groupBy("time").count().withWatermark("time", "1 min")` is invalid in Append output mode. +##### Semantic Guarantees of Aggregation with Watermarking +{:.no_toc} + +- A watermark delay (set with `withWatermark`) of "2 hours" guarantees that the engine will never +drop any data that is less than 2 hours delayed. In other words, any data less than 2 hours behind +(in terms of event-time) the latest data processed till then is guaranteed to be aggregated. + +- However, the guarantee is strict only in one direction. Data delayed by more than 2 hours is +not guaranteed to be dropped; it may or may not get aggregated. More delayed is the data, less +likely is the engine going to process it. ### Join Operations Structured Streaming supports joining a streaming Dataset/DataFrame with a static Dataset/DataFrame @@ -1062,7 +1076,7 @@ Dataset/DataFrame will be the exactly the same as if it was with a static Datase containing the same data in the stream. -#### Stream-static joins +#### Stream-static Joins Since the introduction in Spark 2.0, Structured Streaming has supported joins (inner join and some type of outer joins) between a streaming and a static DataFrame/Dataset. Here is a simple example. @@ -1269,6 +1283,12 @@ joined <- join(
+###### Semantic Guarantees of Stream-stream Inner Joins with Watermarking +{:.no_toc} +This is similar to the [guarantees provided by watermarking on aggregations](#semantic-guarantees-of-aggregation-with-watermarking). +A watermark delay of "2 hours" guarantees that the engine will never drop any data that is less than + 2 hours delayed. But data delayed by more than 2 hours may or may not get processed. + ##### Outer Joins with Watermarking While the watermark + event-time constraints is optional for inner joins, for left and right outer joins they must be specified. This is because for generating the NULL results in outer join, the @@ -1347,7 +1367,14 @@ joined <- join(
-There are a few points to note regarding outer joins. +###### Semantic Guarantees of Stream-stream Outer Joins with Watermarking +{:.no_toc} +Outer joins have the same guarantees as [inner joins](#semantic-guarantees-of-stream-stream-inner-joins-with-watermarking) +regarding watermark delays and whether data will be dropped or not. + +###### Caveats +{:.no_toc} +There are a few important characteristics to note regarding how the outer results are generated. - *The outer NULL results will be generated with a delay that depends on the specified watermark delay and the time range condition.* This is because the engine has to wait for that long to ensure @@ -1962,7 +1989,7 @@ head(sql("select * from aggregates")) -#### Using Foreach +##### Using Foreach The `foreach` operation allows arbitrary operations to be computed on the output data. As of Spark 2.1, this is available only for Scala and Java. To use this, you will have to implement the interface `ForeachWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.ForeachWriter)/[Java](api/java/org/apache/spark/sql/ForeachWriter.html) docs), which has methods that get called whenever there is a sequence of rows generated as output after a trigger. Note the following important points. @@ -1979,6 +2006,172 @@ which has methods that get called whenever there is a sequence of rows generated - Whenever `open` is called, `close` will also be called (unless the JVM exits due to some error). This is true even if `open` returns false. If there is any error in processing and writing the data, `close` will be called with the error. It is your responsibility to clean up state (e.g. connections, transactions, etc.) that have been created in `open` such that there are no resource leaks. +#### Triggers +The trigger settings of a streaming query defines the timing of streaming data processing, whether +the query is going to executed as micro-batch query with a fixed batch interval or as a continuous processing query. +Here are the different kinds of triggers that are supported. + + + + + + + + + + + + + + + + + + + + + + +
Trigger TypeDescription
unspecified (default) + If no trigger setting is explicitly specified, then by default, the query will be + executed in micro-batch mode, where micro-batches will be generated as soon as + the previous micro-batch has completed processing. +
Fixed interval micro-batches + The query will be executed with micro-batches mode, where micro-batches will be kicked off + at the user-specified intervals. +
    +
  • If the previous micro-batch completes within the interval, then the engine will wait until + the interval is over before kicking off the next micro-batch.
  • + +
  • If the previous micro-batch takes longer than the interval to complete (i.e. if an + interval boundary is missed), then the next micro-batch will start as soon as the + previous one completes (i.e., it will not wait for the next interval boundary).
  • + +
  • If no new data is available, then no micro-batch will be kicked off.
  • +
+
One-time micro-batch + The query will execute *only one* micro-batch to process all the available data and then + stop on its own. This is useful in scenarios you want to periodically spin up a cluster, + process everything that is available since the last period, and then shutdown the + cluster. In some case, this may lead to significant cost savings. +
Continuous with fixed checkpoint interval
(experimental)
+ The query will be executed in the new low-latency, continuous processing mode. Read more + about this in the Continuous Processing section below. +
+ +Here are a few code examples. + +
+
+ +{% highlight scala %} +import org.apache.spark.sql.streaming.Trigger + +// Default trigger (runs micro-batch as soon as it can) +df.writeStream + .format("console") + .start() + +// ProcessingTime trigger with two-seconds micro-batch interval +df.writeStream + .format("console") + .trigger(Trigger.ProcessingTime("2 seconds")) + .start() + +// One-time trigger +df.writeStream + .format("console") + .trigger(Trigger.Once()) + .start() + +// Continuous trigger with one-second checkpointing interval +df.writeStream + .format("console") + .trigger(Trigger.Continuous("1 second")) + .start() + +{% endhighlight %} + + +
+
+ +{% highlight java %} +import org.apache.spark.sql.streaming.Trigger + +// Default trigger (runs micro-batch as soon as it can) +df.writeStream + .format("console") + .start(); + +// ProcessingTime trigger with two-seconds micro-batch interval +df.writeStream + .format("console") + .trigger(Trigger.ProcessingTime("2 seconds")) + .start(); + +// One-time trigger +df.writeStream + .format("console") + .trigger(Trigger.Once()) + .start(); + +// Continuous trigger with one-second checkpointing interval +df.writeStream + .format("console") + .trigger(Trigger.Continuous("1 second")) + .start(); + +{% endhighlight %} + +
+
+ +{% highlight python %} + +# Default trigger (runs micro-batch as soon as it can) +df.writeStream \ + .format("console") \ + .start() + +# ProcessingTime trigger with two-seconds micro-batch interval +df.writeStream \ + .format("console") \ + .trigger(processingTime='2 seconds') \ + .start() + +# One-time trigger +df.writeStream \ + .format("console") \ + .trigger(once=True) \ + .start() + +# Continuous trigger with one-second checkpointing interval +df.writeStream + .format("console") + .trigger(continuous='1 second') + .start() + +{% endhighlight %} +
+
+ +{% highlight r %} +# Default trigger (runs micro-batch as soon as it can) +write.stream(df, "console") + +# ProcessingTime trigger with two-seconds micro-batch interval +write.stream(df, "console", trigger.processingTime = "2 seconds") + +# One-time trigger +write.stream(df, "console", trigger.once = TRUE) + +# Continuous trigger is not yet supported +{% endhighlight %} +
+
+ + ## Managing Streaming Queries The `StreamingQuery` object created when a query is started can be used to monitor and manage the query. @@ -2516,7 +2709,10 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat -# Continuous Processing [Experimental] +# Continuous Processing +## [Experimental] +{:.no_toc} + **Continuous processing** is a new, experimental streaming execution mode introduced in Spark 2.3 that enables low (~1 ms) end-to-end latency with at-least-once fault-tolerance guarantees. Compare this with the default *micro-batch processing* engine which can achieve exactly-once guarantees but achieve latencies of ~100ms at best. For some types of queries (discussed below), you can choose which mode to execute them in without modifying the application logic (i.e. without changing the DataFrame/Dataset operations). To run a supported query in continuous processing mode, all you need to do is specify a **continuous trigger** with the desired checkpoint interval as a parameter. For example, @@ -2589,6 +2785,8 @@ spark \ A checkpoint interval of 1 second means that the continuous processing engine will records the progress of the query every second. The resulting checkpoints are in a format compatible with the micro-batch engine, hence any query can be restarted with any trigger. For example, a supported query started with the micro-batch mode can be restarted in continuous mode, and vice versa. Note that any time you switch to continuous mode, you will get at-least-once fault-tolerance guarantees. ## Supported Queries +{:.no_toc} + As of Spark 2.3, only the following type of queries are supported in the continuous processing mode. - *Operations*: Only map-like Dataset/DataFrame operations are supported in continuous mode, that is, only projections (`select`, `map`, `flatMap`, `mapPartitions`, etc.) and selections (`where`, `filter`, etc.). @@ -2606,6 +2804,8 @@ As of Spark 2.3, only the following type of queries are supported in the continu See [Input Sources](#input-sources) and [Output Sinks](#output-sinks) sections for more details on them. While the console sink is good for testing, the end-to-end low-latency processing can be best observed with Kafka as the source and sink, as this allows the engine to process the data and make the results available in the output topic within milliseconds of the input data being available in the input topic. ## Caveats +{:.no_toc} + - Continuous processing engine launches multiple long-running tasks that continuously read data from sources, process it and continuously write to sinks. The number of tasks required by the query depends on how many partitions the query can read from the sources in parallel. Therefore, before starting a continuous processing query, you must ensure there are enough cores in the cluster to all the tasks in parallel. For example, if you are reading from a Kafka topic that has 10 partitions, then the cluster must have at least 10 cores for the query to make progress. - Stopping a continuous processing stream may produce spurious task termination warnings. These can be safely ignored. - There are currently no automatic retries of failed tasks. Any failure will lead to the query being stopped and it needs to be manually restarted from the checkpoint. From 95e25ed1a8b56937345eff637c0032aea85a503d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 21 Feb 2018 11:26:06 +0800 Subject: [PATCH 381/774] [SPARK-23424][SQL] Add codegenStageId in comment ## What changes were proposed in this pull request? This PR always adds `codegenStageId` in comment of the generated class. This is a replication of #20419 for post-Spark 2.3. Closes #20419 ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage1(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=1 /* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; ... ``` ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #20612 from kiszk/SPARK-23424. --- .../expressions/codegen/CodeGenerator.scala | 21 ++++++++++++++++--- .../sql/execution/WholeStageCodegenExec.scala | 4 +++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 31ba29ae8d8ce..60a6f50472504 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1232,14 +1232,29 @@ class CodegenContext { /** * Register a comment and return the corresponding place holder + * + * @param placeholderId an optionally specified identifier for the comment's placeholder. + * The caller should make sure this identifier is unique within the + * compilation unit. If this argument is not specified, a fresh identifier + * will be automatically created and used as the placeholder. + * @param force whether to force registering the comments */ - def registerComment(text: => String): String = { + def registerComment( + text: => String, + placeholderId: String = "", + force: Boolean = false): String = { // By default, disable comments in generated code because computing the comments themselves can // be extremely expensive in certain cases, such as deeply-nested expressions which operate over // inputs with wide schemas. For more details on the performance issues that motivated this // flat, see SPARK-15680. - if (SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { - val name = freshName("c") + if (force || + SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { + val name = if (placeholderId != "") { + assert(!placeHolderToComments.contains(placeholderId)) + placeholderId + } else { + freshName("c") + } val comment = if (text.contains("\n") || text.contains("\r")) { text.split("(\r\n)|\r|\n").mkString("/**\n * ", "\n * ", "\n */") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 0e525b1e22eb9..deb0a044c2fb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -540,7 +540,9 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) ${ctx.registerComment( s"""Codegend pipeline for stage (id=$codegenStageId) - |${this.treeString.trim}""".stripMargin)} + |${this.treeString.trim}""".stripMargin, + "wsc_codegenPipeline")} + ${ctx.registerComment(s"codegenStageId=$codegenStageId", "wsc_codegenStageId", true)} final class $className extends ${classOf[BufferedRowIterator].getName} { private Object[] references; From c8c4441dfdfeda22f8d92e25aee1b6a6269752f9 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 21 Feb 2018 15:10:08 +0800 Subject: [PATCH 382/774] [SPARK-23418][SQL] Fail DataSourceV2 reads when user schema is passed, but not supported. ## What changes were proposed in this pull request? DataSourceV2 initially allowed user-supplied schemas when a source doesn't implement `ReadSupportWithSchema`, as long as the schema was identical to the source's schema. This is confusing behavior because changes to an underlying table can cause a previously working job to fail with an exception that user-supplied schemas are not allowed. This reverts commit adcb25a0624, which was added to #20387 so that it could be removed in a separate JIRA issue and PR. ## How was this patch tested? Existing tests. Author: Ryan Blue Closes #20603 from rdblue/SPARK-23418-revert-adcb25a0624. --- .../datasources/v2/DataSourceV2Relation.scala | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index a98dd4866f82a..cc6cb631e3f06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -174,13 +174,6 @@ object DataSourceV2Relation { v2Options: DataSourceOptions, userSchema: Option[StructType]): StructType = { val reader = userSchema match { - // TODO: remove this case because it is confusing for users - case Some(s) if !source.isInstanceOf[ReadSupportWithSchema] => - val reader = source.asReadSupport.createReader(v2Options) - if (reader.readSchema() != s) { - throw new AnalysisException(s"${source.name} does not allow user-specified schemas.") - } - reader case Some(s) => source.asReadSupportWithSchema.createReader(s, v2Options) case _ => @@ -195,11 +188,7 @@ object DataSourceV2Relation { filters: Option[Seq[Expression]] = None, userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes - DataSourceV2Relation(source, options, projection, filters, - // if the source does not implement ReadSupportWithSchema, then the userSpecifiedSchema must - // be equal to the reader's schema. the schema method enforces this. because the user schema - // and the reader's schema are identical, drop the user schema. - if (source.isInstanceOf[ReadSupportWithSchema]) userSpecifiedSchema else None) + DataSourceV2Relation(source, options, projection, filters, userSpecifiedSchema) } private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { From e836c27ce011ca9aef822bef6320b4a7059ec343 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 21 Feb 2018 12:39:36 -0600 Subject: [PATCH 383/774] [SPARK-23217][ML][PYTHON] Add distanceMeasure param to ClusteringEvaluator Python API ## What changes were proposed in this pull request? The PR adds the `distanceMeasure` param to ClusteringEvaluator in the Python API. This allows the user to specify `cosine` as distance measure in addition to the default `squaredEuclidean`. ## How was this patch tested? added UT Author: Marco Gaido Closes #20627 from mgaido91/SPARK-23217_python. --- python/pyspark/ml/evaluation.py | 28 +++++++++++++++++++++++----- python/pyspark/ml/tests.py | 16 ++++++++++++++-- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 0cbce9b40048f..695d8ab27cc96 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -362,18 +362,21 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (silhouette)", typeConverter=TypeConverters.toString) + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " + + "Supported options: 'squaredEuclidean' and 'cosine'.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, predictionCol="prediction", featuresCol="features", - metricName="silhouette"): + metricName="silhouette", distanceMeasure="squaredEuclidean"): """ __init__(self, predictionCol="prediction", featuresCol="features", \ - metricName="silhouette") + metricName="silhouette", distanceMeasure="squaredEuclidean") """ super(ClusteringEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid) - self._setDefault(metricName="silhouette") + self._setDefault(metricName="silhouette", distanceMeasure="squaredEuclidean") kwargs = self._input_kwargs self._set(**kwargs) @@ -394,15 +397,30 @@ def getMetricName(self): @keyword_only @since("2.3.0") def setParams(self, predictionCol="prediction", featuresCol="features", - metricName="silhouette"): + metricName="silhouette", distanceMeasure="squaredEuclidean"): """ setParams(self, predictionCol="prediction", featuresCol="features", \ - metricName="silhouette") + metricName="silhouette", distanceMeasure="squaredEuclidean") Sets params for clustering evaluator. """ kwargs = self._input_kwargs return self._set(**kwargs) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` + """ + return self.getOrDefault(self.distanceMeasure) + + if __name__ == "__main__": import doctest import tempfile diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6d6737241e06e..116885969345c 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -51,7 +51,7 @@ from pyspark.ml.classification import * from pyspark.ml.clustering import * from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.evaluation import BinaryClassificationEvaluator, \ +from pyspark.ml.evaluation import BinaryClassificationEvaluator, ClusteringEvaluator, \ MulticlassClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.fpm import FPGrowth, FPGrowthModel @@ -541,6 +541,15 @@ def test_java_params(self): self.assertEqual(evaluator._java_obj.getMetricName(), "r2") self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae") + def test_clustering_evaluator_with_cosine_distance(self): + featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), + [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0), + ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)]) + dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"]) + evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine") + self.assertEqual(evaluator.getDistanceMeasure(), "cosine") + self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5)) + class FeatureTests(SparkSessionTestCase): @@ -1961,11 +1970,14 @@ def test_java_params(self): import pyspark.ml.feature import pyspark.ml.classification import pyspark.ml.clustering + import pyspark.ml.evaluation import pyspark.ml.pipeline import pyspark.ml.recommendation import pyspark.ml.regression + modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering, - pyspark.ml.pipeline, pyspark.ml.recommendation, pyspark.ml.regression] + pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation, + pyspark.ml.regression] for module in modules: for name, cls in inspect.getmembers(module, inspect.isclass): if not name.endswith('Model') and issubclass(cls, JavaParams)\ From 3fd0ccb13fea44727d970479af1682ef00592147 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 21 Feb 2018 14:56:13 -0800 Subject: [PATCH 384/774] [SPARK-23484][SS] Fix possible race condition in KafkaContinuousReader ## What changes were proposed in this pull request? var `KafkaContinuousReader.knownPartitions` should be threadsafe as it is accessed from multiple threads - the query thread at the time of reader factory creation, and the epoch tracking thread at the time of `needsReconfiguration`. ## How was this patch tested? Existing tests. Author: Tathagata Das Closes #20655 from tdas/SPARK-23484. --- .../org/apache/spark/sql/kafka010/KafkaContinuousReader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 97a0f66e1880d..ecd1170321f3f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -66,7 +66,7 @@ class KafkaContinuousReader( // Initialized when creating reader factories. If this diverges from the partitions at the latest // offsets, we need to reconfigure. // Exposed outside this object only for unit tests. - private[sql] var knownPartitions: Set[TopicPartition] = _ + @volatile private[sql] var knownPartitions: Set[TopicPartition] = _ override def readSchema: StructType = KafkaOffsetReader.kafkaSchema From 744d5af652ee8cece361cbca31e5201134e0fb42 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 21 Feb 2018 15:37:28 -0800 Subject: [PATCH 385/774] [SPARK-23481][WEBUI] lastStageAttempt should fail when a stage doesn't exist ## What changes were proposed in this pull request? The issue here is `AppStatusStore.lastStageAttempt` will return the next available stage in the store when a stage doesn't exist. This PR adds `last(stageId)` to ensure it returns a correct `StageData` ## How was this patch tested? The new unit test. Author: Shixiong Zhu Closes #20654 from zsxwing/SPARK-23481. --- .../apache/spark/status/AppStatusStore.scala | 6 +++- .../spark/status/AppStatusListenerSuite.scala | 33 +++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index efc28538a33db..688f25a9fdea1 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -95,7 +95,11 @@ private[spark] class AppStatusStore( } def lastStageAttempt(stageId: Int): v1.StageData = { - val it = store.view(classOf[StageDataWrapper]).index("stageId").reverse().first(stageId) + val it = store.view(classOf[StageDataWrapper]) + .index("stageId") + .reverse() + .first(stageId) + .last(stageId) .closeableIterator() try { if (it.hasNext()) { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 749502709b5c8..673d191b5a4db 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1121,6 +1121,39 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("lastStageAttempt should fail when the stage doesn't exist") { + val testConf = conf.clone().set(MAX_RETAINED_STAGES, 1) + val listener = new AppStatusListener(store, testConf, true) + val appStore = new AppStatusStore(store) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + + time += 1 + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + stage1.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Make stage 3 complete before stage 2 so that stage 3 will be evicted + time += 1 + stage3.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) + stage3.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage3)) + + time += 1 + stage2.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties())) + stage2.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage2)) + + assert(appStore.asOption(appStore.lastStageAttempt(1)) === None) + assert(appStore.asOption(appStore.lastStageAttempt(2)).map(_.stageId) === Some(2)) + assert(appStore.asOption(appStore.lastStageAttempt(3)) === None) + } + test("driver logs") { val listener = new AppStatusListener(store, conf, true) From 45cf714ee6d4eead2fe00794a0d754fa6d33d4a6 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 21 Feb 2018 19:43:11 -0800 Subject: [PATCH 386/774] [SPARK-23475][WEBUI] Skipped stages should be evicted before completed stages ## What changes were proposed in this pull request? The root cause of missing completed stages is because `cleanupStages` will never remove skipped stages. This PR changes the logic to always remove skipped stage first. This is safe since the job itself contains enough information to render skipped stages in the UI. ## How was this patch tested? The new unit tests. Author: Shixiong Zhu Closes #20656 from zsxwing/SPARK-23475. --- .../spark/status/AppStatusListener.scala | 5 ++- .../spark/status/AppStatusListenerSuite.scala | 36 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 79a17e26665fd..5ea161cd0d151 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -915,7 +915,10 @@ private[spark] class AppStatusListener( return } - val view = kvstore.view(classOf[StageDataWrapper]).index("completionTime").first(0L) + // As the completion time of a skipped stage is always -1, we will remove skipped stages first. + // This is safe since the job itself contains enough information to render skipped stages in the + // UI. + val view = kvstore.view(classOf[StageDataWrapper]).index("completionTime") val stages = KVUtils.viewToSeq(view, countToDelete.toInt) { s => s.info.status != v1.StageStatus.ACTIVE && s.info.status != v1.StageStatus.PENDING } diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 673d191b5a4db..1cd71955ad4d9 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1089,6 +1089,42 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("skipped stages should be evicted before completed stages") { + val testConf = conf.clone().set(MAX_RETAINED_STAGES, 2) + val listener = new AppStatusListener(store, testConf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + + // Sart job 1 + time += 1 + listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage1, stage2), null)) + + // Start and stop stage 1 + time += 1 + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + + time += 1 + stage1.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Stop job 1 and stage 2 will become SKIPPED + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + // Submit stage 3 and verify stage 2 is evicted + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + time += 1 + stage3.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) + + assert(store.count(classOf[StageDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[StageDataWrapper], Array(2, 0)) + } + } + test("eviction should respect task completion time") { val testConf = conf.clone().set(MAX_RETAINED_TASKS_PER_STAGE, 2) val listener = new AppStatusListener(store, testConf, true) From 87293c746e19d66f475d506d0adb43421f496843 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 22 Feb 2018 11:00:12 -0800 Subject: [PATCH 387/774] [SPARK-23475][UI] Show also skipped stages ## What changes were proposed in this pull request? SPARK-20648 introduced the status `SKIPPED` for the stages. On the UI, previously, skipped stages were shown as `PENDING`; after this change, they are not shown on the UI. The PR introduce a new section in order to show also `SKIPPED` stages in a proper table. ## How was this patch tested? manual tests Author: Marco Gaido Closes #20651 from mgaido91/SPARK-23475. --- .../org/apache/spark/ui/static/webui.js | 1 + .../apache/spark/ui/jobs/AllStagesPage.scala | 27 +++++++++++++++++++ .../org/apache/spark/ui/UISeleniumSuite.scala | 17 ++++++++++++ 3 files changed, 45 insertions(+) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index 83009df91d30a..f01c567ba58ad 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -72,6 +72,7 @@ $(function() { collapseTablePageLoad('collapse-aggregated-allActiveStages','aggregated-allActiveStages'); collapseTablePageLoad('collapse-aggregated-allPendingStages','aggregated-allPendingStages'); collapseTablePageLoad('collapse-aggregated-allCompletedStages','aggregated-allCompletedStages'); + collapseTablePageLoad('collapse-aggregated-allSkippedStages','aggregated-allSkippedStages'); collapseTablePageLoad('collapse-aggregated-allFailedStages','aggregated-allFailedStages'); collapseTablePageLoad('collapse-aggregated-activeStages','aggregated-activeStages'); collapseTablePageLoad('collapse-aggregated-pendingOrSkippedStages','aggregated-pendingOrSkippedStages'); diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 606dc1e180e5b..38450b9126ff0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -36,6 +36,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val activeStages = allStages.filter(_.status == StageStatus.ACTIVE) val pendingStages = allStages.filter(_.status == StageStatus.PENDING) + val skippedStages = allStages.filter(_.status == StageStatus.SKIPPED) val completedStages = allStages.filter(_.status == StageStatus.COMPLETE) val failedStages = allStages.filter(_.status == StageStatus.FAILED).reverse @@ -51,6 +52,9 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val completedStagesTable = new StageTableBase(parent.store, request, completedStages, "completed", "completedStage", parent.basePath, subPath, parent.isFairScheduler, false, false) + val skippedStagesTable = + new StageTableBase(parent.store, request, skippedStages, "skipped", "skippedStage", + parent.basePath, subPath, parent.isFairScheduler, false, false) val failedStagesTable = new StageTableBase(parent.store, request, failedStages, "failed", "failedStage", parent.basePath, subPath, parent.isFairScheduler, false, true) @@ -66,6 +70,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val shouldShowActiveStages = activeStages.nonEmpty val shouldShowPendingStages = pendingStages.nonEmpty val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowSkippedStages = skippedStages.nonEmpty val shouldShowFailedStages = failedStages.nonEmpty val appSummary = parent.store.appSummary() @@ -102,6 +107,14 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { } } + { + if (shouldShowSkippedStages) { +
  • + Skipped Stages: + {skippedStages.size} +
  • + } + } { if (shouldShowFailedStages) {
  • @@ -172,6 +185,20 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { {completedStagesTable.toNodeSeq} } + if (shouldShowSkippedStages) { + content ++= + +

    + + Skipped Stages ({skippedStages.size}) +

    +
    ++ +
    + {skippedStagesTable.toNodeSeq} +
    + } if (shouldShowFailedStages) { content ++= + val rdd = sc.parallelize(0 to 100, 100).repartition(10).cache() + rdd.count() + rdd.count() + + eventually(timeout(5 seconds), interval(50 milliseconds)) { + goToUi(sc, "/stages") + find(id("skipped")).get.text should be("Skipped Stages (1)") + } + val stagesJson = getJson(sc.ui.get, "stages") + stagesJson.children.size should be (4) + val stagesStatus = stagesJson.children.map(_ \ "status") + stagesStatus.count(_ == JString(StageStatus.SKIPPED.name())) should be (1) + } + } + def goToUi(sc: SparkContext, path: String): Unit = { goToUi(sc.ui.get, path) } From c5abb3c2d16f601d507bee3c53663d4e117eb8b5 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 22 Feb 2018 12:07:51 -0800 Subject: [PATCH 388/774] [SPARK-23476][CORE] Generate secret in local mode when authentication on ## What changes were proposed in this pull request? If spark is run with "spark.authenticate=true", then it will fail to start in local mode. This PR generates secret in local mode when authentication on. ## How was this patch tested? Modified existing unit test. Manually started spark-shell. Author: Gabor Somogyi Closes #20652 from gaborgsomogyi/SPARK-23476. --- .../org/apache/spark/SecurityManager.scala | 16 ++++-- .../apache/spark/SecurityManagerSuite.scala | 50 +++++++++++++------ docs/security.md | 2 +- 3 files changed, 46 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 5b15a1c57779d..2519d266879aa 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -520,19 +520,25 @@ private[spark] class SecurityManager( * * If authentication is disabled, do nothing. * - * In YARN mode, generate a new secret and store it in the current user's credentials. + * In YARN and local mode, generate a new secret and store it in the current user's credentials. * * In other modes, assert that the auth secret is set in the configuration. */ def initializeAuth(): Unit = { + import SparkMasterRegex._ + if (!sparkConf.get(NETWORK_AUTH_ENABLED)) { return } - if (sparkConf.get(SparkLauncher.SPARK_MASTER, null) != "yarn") { - require(sparkConf.contains(SPARK_AUTH_SECRET_CONF), - s"A secret key must be specified via the $SPARK_AUTH_SECRET_CONF config.") - return + val master = sparkConf.get(SparkLauncher.SPARK_MASTER, "") + master match { + case "yarn" | "local" | LOCAL_N_REGEX(_) | LOCAL_N_FAILURES_REGEX(_, _) => + // Secret generation allowed here + case _ => + require(sparkConf.contains(SPARK_AUTH_SECRET_CONF), + s"A secret key must be specified via the $SPARK_AUTH_SECRET_CONF config.") + return } val rnd = new SecureRandom() diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index cf59265dd646d..106ece7aed0a4 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -440,23 +440,41 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(keyFromEnv === new SecurityManager(conf2).getSecretKey()) } - test("secret key generation in yarn mode") { - val conf = new SparkConf() - .set(NETWORK_AUTH_ENABLED, true) - .set(SparkLauncher.SPARK_MASTER, "yarn") - val mgr = new SecurityManager(conf) - - UserGroupInformation.createUserForTesting("authTest", Array()).doAs( - new PrivilegedExceptionAction[Unit]() { - override def run(): Unit = { - mgr.initializeAuth() - val creds = UserGroupInformation.getCurrentUser().getCredentials() - val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY) - assert(secret != null) - assert(new String(secret, UTF_8) === mgr.getSecretKey()) + test("secret key generation") { + Seq( + ("yarn", true), + ("local", true), + ("local[*]", true), + ("local[1, 2]", true), + ("local-cluster[2, 1, 1024]", false), + ("invalid", false) + ).foreach { case (master, shouldGenerateSecret) => + val conf = new SparkConf() + .set(NETWORK_AUTH_ENABLED, true) + .set(SparkLauncher.SPARK_MASTER, master) + val mgr = new SecurityManager(conf) + + UserGroupInformation.createUserForTesting("authTest", Array()).doAs( + new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + if (shouldGenerateSecret) { + mgr.initializeAuth() + val creds = UserGroupInformation.getCurrentUser().getCredentials() + val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY) + assert(secret != null) + assert(new String(secret, UTF_8) === mgr.getSecretKey()) + } else { + intercept[IllegalArgumentException] { + mgr.initializeAuth() + } + intercept[IllegalArgumentException] { + mgr.getSecretKey() + } + } + } } - } - ) + ) + } } } diff --git a/docs/security.md b/docs/security.md index bebc28ddbfb0e..0f384b411812a 100644 --- a/docs/security.md +++ b/docs/security.md @@ -6,7 +6,7 @@ title: Security Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: -* For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. +* For Spark on [YARN](running-on-yarn.html) and local deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. ## Web UI From 049f243c59737699fee54fdc9d65cbd7c788032a Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 22 Feb 2018 21:49:25 -0800 Subject: [PATCH 389/774] [SPARK-23490][SQL] Check storage.locationUri with existing table in CreateTable ## What changes were proposed in this pull request? For CreateTable with Append mode, we should check if `storage.locationUri` is the same with existing table in `PreprocessTableCreation` In the current code, there is only a simple exception if the `storage.locationUri` is different with existing table: `org.apache.spark.sql.AnalysisException: Table or view not found:` which can be improved. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20660 from gengliangwang/locationUri. --- .../sql/execution/datasources/rules.scala | 8 +++++ .../sql/execution/command/DDLSuite.scala | 29 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 5cc21eeaeaa94..0dea767840ed3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -118,6 +118,14 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi s"`${existingProvider.getSimpleName}`. It doesn't match the specified format " + s"`${specifiedProvider.getSimpleName}`.") } + tableDesc.storage.locationUri match { + case Some(location) if location.getPath != existingTable.location.getPath => + throw new AnalysisException( + s"The location of the existing table ${tableIdentWithDB.quotedString} is " + + s"`${existingTable.location}`. It doesn't match the specified location " + + s"`${tableDesc.location}`.") + case _ => + } if (query.schema.length != existingTable.schema.length) { throw new AnalysisException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index f76bfd2fda2b9..b800e6ff5b0ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -536,6 +536,35 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("create table - append to a non-partitioned table created with different paths") { + import testImplicits._ + withTempDir { dir1 => + withTempDir { dir2 => + withTable("path_test") { + Seq(1L -> "a").toDF("v1", "v2") + .write + .mode(SaveMode.Append) + .format("json") + .option("path", dir1.getCanonicalPath) + .saveAsTable("path_test") + + val ex = intercept[AnalysisException] { + Seq((3L, "c")).toDF("v1", "v2") + .write + .mode(SaveMode.Append) + .format("json") + .option("path", dir2.getCanonicalPath) + .saveAsTable("path_test") + }.getMessage + assert(ex.contains("The location of the existing table `default`.`path_test`")) + + checkAnswer( + spark.table("path_test"), Row(1L, "a") :: Nil) + } + } + } + } + test("Refresh table after changing the data source table partitioning") { import testImplicits._ From 855ce13d045569b7b16fdc7eee9c981f4ff3a545 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 23 Feb 2018 12:40:58 -0800 Subject: [PATCH 390/774] [SPARK-23408][SS] Synchronize successive AddData actions in Streaming*JoinSuite **The best way to review this PR is to ignore whitespace/indent changes. Use this link - https://github.com/apache/spark/pull/20650/files?w=1** ## What changes were proposed in this pull request? The stream-stream join tests add data to multiple sources and expect it all to show up in the next batch. But there's a race condition; the new batch might trigger when only one of the AddData actions has been reached. Prior attempt to solve this issue by jose-torres in #20646 attempted to simultaneously synchronize on all memory sources together when consecutive AddData was found in the actions. However, this carries the risk of deadlock as well as unintended modification of stress tests (see the above PR for a detailed explanation). Instead, this PR attempts the following. - A new action called `StreamProgressBlockedActions` that allows multiple actions to be executed while the streaming query is blocked from making progress. This allows data to be added to multiple sources that are made visible simultaneously in the next batch. - An alias of `StreamProgressBlockedActions` called `MultiAddData` is explicitly used in the `Streaming*JoinSuites` to add data to two memory sources simultaneously. This should avoid unintentional modification of the stress tests (or any other test for that matter) while making sure that the flaky tests are deterministic. ## How was this patch tested? Modified test cases in `Streaming*JoinSuites` where there are consecutive `AddData` actions. Author: Tathagata Das Closes #20650 from tdas/SPARK-23408. --- .../streaming/MicroBatchExecution.scala | 10 + .../spark/sql/streaming/StreamTest.scala | 472 ++++++++++-------- .../sql/streaming/StreamingJoinSuite.scala | 54 +- 3 files changed, 284 insertions(+), 252 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 84655013ba957..6bd03972c301d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -504,6 +504,16 @@ class MicroBatchExecution( } } + /** Execute a function while locking the stream from making an progress */ + private[sql] def withProgressLocked(f: => Unit): Unit = { + awaitProgressLock.lock() + try { + f + } finally { + awaitProgressLock.unlock() + } + } + private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { Optional.ofNullable(scalaOption.orNull) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 159dd0ecb5902..08f722ecb10e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -102,6 +102,19 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be AddDataMemory(source, data) } + /** + * Adds data to multiple memory streams such that all the data will be made visible in the + * same batch. This is applicable only to MicroBatchExecution, as this coordination cannot be + * performed at the driver in ContinuousExecutions. + */ + object MultiAddData { + def apply[A] + (source1: MemoryStream[A], data1: A*)(source2: MemoryStream[A], data2: A*): StreamAction = { + val actions = Seq(AddDataMemory(source1, data1), AddDataMemory(source2, data2)) + StreamProgressLockedActions(actions, desc = actions.mkString("[ ", " | ", " ]")) + } + } + /** A trait that can be extended when testing a source. */ trait AddData extends StreamAction { /** @@ -217,6 +230,19 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be s"ExpectFailure[${causeClass.getName}, isFatalError: $isFatalError]" } + /** + * Performs multiple actions while locking the stream from progressing. + * This is applicable only to MicroBatchExecution, as progress of ContinuousExecution + * cannot be controlled from the driver. + */ + case class StreamProgressLockedActions(actions: Seq[StreamAction], desc: String = null) + extends StreamAction { + + override def toString(): String = { + if (desc != null) desc else super.toString + } + } + /** Assert that a body is true */ class Assert(condition: => Boolean, val message: String = "") extends StreamAction { def run(): Unit = { Assertions.assert(condition) } @@ -295,6 +321,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) val resetConfValues = mutable.Map[String, Option[String]]() + val defaultCheckpointLocation = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath + var manualClockExpectedTime = -1L @volatile var streamThreadDeathCause: Throwable = null @@ -425,243 +454,254 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } - var manualClockExpectedTime = -1L - val defaultCheckpointLocation = - Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - try { - startedTest.foreach { action => - logInfo(s"Processing test stream action: $action") - action match { - case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => - verify(currentStream == null, "stream already running") - verify(triggerClock.isInstanceOf[SystemClock] - || triggerClock.isInstanceOf[StreamManualClock], - "Use either SystemClock or StreamManualClock to start the stream") - if (triggerClock.isInstanceOf[StreamManualClock]) { - manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() + def executeAction(action: StreamAction): Unit = { + logInfo(s"Processing test stream action: $action") + action match { + case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => + verify(currentStream == null, "stream already running") + verify(triggerClock.isInstanceOf[SystemClock] + || triggerClock.isInstanceOf[StreamManualClock], + "Use either SystemClock or StreamManualClock to start the stream") + if (triggerClock.isInstanceOf[StreamManualClock]) { + manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() + } + val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) + + additionalConfs.foreach(pair => { + val value = + if (sparkSession.conf.contains(pair._1)) { + Some(sparkSession.conf.get(pair._1)) + } else None + resetConfValues(pair._1) = value + sparkSession.conf.set(pair._1, pair._2) + }) + + lastStream = currentStream + currentStream = + sparkSession + .streams + .startQuery( + None, + Some(metadataRoot), + stream, + Map(), + sink, + outputMode, + trigger = trigger, + triggerClock = triggerClock) + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + // Wait until the initialization finishes, because some tests need to use `logicalPlan` + // after starting the query. + try { + currentStream.awaitInitialization(streamingTimeout.toMillis) + currentStream match { + case s: ContinuousExecution => eventually("IncrementalExecution was not created") { + assert(s.lastExecution != null) + } + case _ => } - val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) + } catch { + case _: StreamingQueryException => + // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. + } - additionalConfs.foreach(pair => { - val value = - if (sparkSession.conf.contains(pair._1)) { - Some(sparkSession.conf.get(pair._1)) - } else None - resetConfValues(pair._1) = value - sparkSession.conf.set(pair._1, pair._2) - }) + case AdvanceManualClock(timeToAdd) => + verify(currentStream != null, + "can not advance manual clock when a stream is not running") + verify(currentStream.triggerClock.isInstanceOf[StreamManualClock], + s"can not advance clock of type ${currentStream.triggerClock.getClass}") + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + assert(manualClockExpectedTime >= 0) + + // Make sure we don't advance ManualClock too early. See SPARK-16002. + eventually("StreamManualClock has not yet entered the waiting state") { + assert(clock.isStreamWaitingAt(manualClockExpectedTime)) + } + clock.advance(timeToAdd) + manualClockExpectedTime += timeToAdd + verify(clock.getTimeMillis() === manualClockExpectedTime, + s"Unexpected clock time after updating: " + + s"expecting $manualClockExpectedTime, current ${clock.getTimeMillis()}") + + case StopStream => + verify(currentStream != null, "can not stop a stream that is not running") + try failAfter(streamingTimeout) { + currentStream.stop() + verify(!currentStream.queryExecutionThread.isAlive, + s"microbatch thread not stopped") + verify(!currentStream.isActive, + "query.isActive() is false even after stopping") + verify(currentStream.exception.isEmpty, + s"query.exception() is not empty after clean stop: " + + currentStream.exception.map(_.toString()).getOrElse("")) + } catch { + case _: InterruptedException => + case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest( + "Timed out while stopping and waiting for microbatchthread to terminate.", e) + case t: Throwable => + failTest("Error while stopping stream", t) + } finally { lastStream = currentStream - currentStream = - sparkSession - .streams - .startQuery( - None, - Some(metadataRoot), - stream, - Map(), - sink, - outputMode, - trigger = trigger, - triggerClock = triggerClock) - .asInstanceOf[StreamingQueryWrapper] - .streamingQuery - // Wait until the initialization finishes, because some tests need to use `logicalPlan` - // after starting the query. - try { - currentStream.awaitInitialization(streamingTimeout.toMillis) - currentStream match { - case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - assert(s.lastExecution != null) - } - case _ => - } - } catch { - case _: StreamingQueryException => - // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. - } + currentStream = null + } - case AdvanceManualClock(timeToAdd) => - verify(currentStream != null, - "can not advance manual clock when a stream is not running") - verify(currentStream.triggerClock.isInstanceOf[StreamManualClock], - s"can not advance clock of type ${currentStream.triggerClock.getClass}") - val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] - assert(manualClockExpectedTime >= 0) - - // Make sure we don't advance ManualClock too early. See SPARK-16002. - eventually("StreamManualClock has not yet entered the waiting state") { - assert(clock.isStreamWaitingAt(manualClockExpectedTime)) + case ef: ExpectFailure[_] => + verify(currentStream != null, "can not expect failure when stream is not running") + try failAfter(streamingTimeout) { + val thrownException = intercept[StreamingQueryException] { + currentStream.awaitTermination() } - - clock.advance(timeToAdd) - manualClockExpectedTime += timeToAdd - verify(clock.getTimeMillis() === manualClockExpectedTime, - s"Unexpected clock time after updating: " + - s"expecting $manualClockExpectedTime, current ${clock.getTimeMillis()}") - - case StopStream => - verify(currentStream != null, "can not stop a stream that is not running") - try failAfter(streamingTimeout) { - currentStream.stop() - verify(!currentStream.queryExecutionThread.isAlive, - s"microbatch thread not stopped") - verify(!currentStream.isActive, - "query.isActive() is false even after stopping") - verify(currentStream.exception.isEmpty, - s"query.exception() is not empty after clean stop: " + - currentStream.exception.map(_.toString()).getOrElse("")) - } catch { - case _: InterruptedException => - case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => - failTest( - "Timed out while stopping and waiting for microbatchthread to terminate.", e) - case t: Throwable => - failTest("Error while stopping stream", t) - } finally { - lastStream = currentStream - currentStream = null + eventually("microbatch thread not stopped after termination with failure") { + assert(!currentStream.queryExecutionThread.isAlive) } + verify(currentStream.exception === Some(thrownException), + s"incorrect exception returned by query.exception()") + + val exception = currentStream.exception.get + verify(exception.cause.getClass === ef.causeClass, + "incorrect cause in exception returned by query.exception()\n" + + s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}") + if (ef.isFatalError) { + // This is a fatal error, `streamThreadDeathCause` should be set to this error in + // UncaughtExceptionHandler. + verify(streamThreadDeathCause != null && + streamThreadDeathCause.getClass === ef.causeClass, + "UncaughtExceptionHandler didn't receive the correct error\n" + + s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") + streamThreadDeathCause = null + } + ef.assertFailure(exception.getCause) + } catch { + case _: InterruptedException => + case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest("Timed out while waiting for failure", e) + case t: Throwable => + failTest("Error while checking stream failure", t) + } finally { + lastStream = currentStream + currentStream = null + } - case ef: ExpectFailure[_] => - verify(currentStream != null, "can not expect failure when stream is not running") - try failAfter(streamingTimeout) { - val thrownException = intercept[StreamingQueryException] { - currentStream.awaitTermination() - } - eventually("microbatch thread not stopped after termination with failure") { - assert(!currentStream.queryExecutionThread.isAlive) + case a: AssertOnQuery => + verify(currentStream != null || lastStream != null, + "cannot assert when no stream has been started") + val streamToAssert = Option(currentStream).getOrElse(lastStream) + try { + verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + } catch { + case NonFatal(e) => + failTest(s"Assert on query failed: ${a.message}", e) + } + + case a: Assert => + val streamToAssert = Option(currentStream).getOrElse(lastStream) + verify({ a.run(); true }, s"Assert failed: ${a.message}") + + case a: AddData => + try { + + // If the query is running with manual clock, then wait for the stream execution + // thread to start waiting for the clock to increment. This is needed so that we + // are adding data when there is no trigger that is active. This would ensure that + // the data gets deterministically added to the next batch triggered after the manual + // clock is incremented in following AdvanceManualClock. This avoid race conditions + // between the test thread and the stream execution thread in tests using manual + // clock. + if (currentStream != null && + currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + eventually("Error while synchronizing with manual clock before adding data") { + if (currentStream.isActive) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } } - verify(currentStream.exception === Some(thrownException), - s"incorrect exception returned by query.exception()") - - val exception = currentStream.exception.get - verify(exception.cause.getClass === ef.causeClass, - "incorrect cause in exception returned by query.exception()\n" + - s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}") - if (ef.isFatalError) { - // This is a fatal error, `streamThreadDeathCause` should be set to this error in - // UncaughtExceptionHandler. - verify(streamThreadDeathCause != null && - streamThreadDeathCause.getClass === ef.causeClass, - "UncaughtExceptionHandler didn't receive the correct error\n" + - s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") - streamThreadDeathCause = null + if (!currentStream.isActive) { + failTest("Query terminated while synchronizing with manual clock") } - ef.assertFailure(exception.getCause) - } catch { - case _: InterruptedException => - case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => - failTest("Timed out while waiting for failure", e) - case t: Throwable => - failTest("Error while checking stream failure", t) - } finally { - lastStream = currentStream - currentStream = null } - - case a: AssertOnQuery => - verify(currentStream != null || lastStream != null, - "cannot assert when no stream has been started") - val streamToAssert = Option(currentStream).getOrElse(lastStream) - try { - verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") - } catch { - case NonFatal(e) => - failTest(s"Assert on query failed: ${a.message}", e) + // Add data + val queryToUse = Option(currentStream).orElse(Option(lastStream)) + val (source, offset) = a.addData(queryToUse) + + def findSourceIndex(plan: LogicalPlan): Option[Int] = { + plan + .collect { + case StreamingExecutionRelation(s, _) => s + case StreamingDataSourceV2Relation(_, r) => r + } + .zipWithIndex + .find(_._1 == source) + .map(_._2) } - case a: Assert => - val streamToAssert = Option(currentStream).getOrElse(lastStream) - verify({ a.run(); true }, s"Assert failed: ${a.message}") - - case a: AddData => - try { - - // If the query is running with manual clock, then wait for the stream execution - // thread to start waiting for the clock to increment. This is needed so that we - // are adding data when there is no trigger that is active. This would ensure that - // the data gets deterministically added to the next batch triggered after the manual - // clock is incremented in following AdvanceManualClock. This avoid race conditions - // between the test thread and the stream execution thread in tests using manual - // clock. - if (currentStream != null && - currentStream.triggerClock.isInstanceOf[StreamManualClock]) { - val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] - eventually("Error while synchronizing with manual clock before adding data") { - if (currentStream.isActive) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } + // Try to find the index of the source to which data was added. Either get the index + // from the current active query or the original input logical plan. + val sourceIndex = + queryToUse.flatMap { query => + findSourceIndex(query.logicalPlan) + }.orElse { + findSourceIndex(stream.logicalPlan) + }.orElse { + queryToUse.flatMap { q => + findSourceIndex(q.lastExecution.logical) } - if (!currentStream.isActive) { - failTest("Query terminated while synchronizing with manual clock") - } - } - // Add data - val queryToUse = Option(currentStream).orElse(Option(lastStream)) - val (source, offset) = a.addData(queryToUse) - - def findSourceIndex(plan: LogicalPlan): Option[Int] = { - plan - .collect { - case StreamingExecutionRelation(s, _) => s - case StreamingDataSourceV2Relation(_, r) => r - } - .zipWithIndex - .find(_._1 == source) - .map(_._2) + }.getOrElse { + throw new IllegalArgumentException( + "Could not find index of the source to which data was added") } - // Try to find the index of the source to which data was added. Either get the index - // from the current active query or the original input logical plan. - val sourceIndex = - queryToUse.flatMap { query => - findSourceIndex(query.logicalPlan) - }.orElse { - findSourceIndex(stream.logicalPlan) - }.orElse { - queryToUse.flatMap { q => - findSourceIndex(q.lastExecution.logical) - } - }.getOrElse { - throw new IllegalArgumentException( - "Could not find index of the source to which data was added") - } + // Store the expected offset of added data to wait for it later + awaiting.put(sourceIndex, offset) + } catch { + case NonFatal(e) => + failTest("Error adding data", e) + } - // Store the expected offset of added data to wait for it later - awaiting.put(sourceIndex, offset) - } catch { - case NonFatal(e) => - failTest("Error adding data", e) - } + case e: ExternalAction => + e.runAction() - case e: ExternalAction => - e.runAction() + case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach { + error => failTest(error) + } - case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) - QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach { - error => failTest(error) - } + case CheckAnswerRowsContains(expectedAnswer, lastOnly) => + val sparkAnswer = currentStream match { + case null => fetchStreamAnswer(lastStream, lastOnly) + case s => fetchStreamAnswer(s, lastOnly) + } + QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } - case CheckAnswerRowsContains(expectedAnswer, lastOnly) => - val sparkAnswer = currentStream match { - case null => fetchStreamAnswer(lastStream, lastOnly) - case s => fetchStreamAnswer(s, lastOnly) - } - QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { - error => failTest(error) - } + case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + try { + globalCheckFunction(sparkAnswer) + } catch { + case e: Throwable => failTest(e.toString) + } + } + pos += 1 + } - case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) - try { - globalCheckFunction(sparkAnswer) - } catch { - case e: Throwable => failTest(e.toString) - } - } - pos += 1 + try { + startedTest.foreach { + case StreamProgressLockedActions(actns, _) => + // Perform actions while holding the stream from progressing + assert(currentStream != null, + s"Cannot perform stream-progress-locked actions $actns when query is not active") + assert(currentStream.isInstanceOf[MicroBatchExecution], + s"Cannot perform stream-progress-locked actions on non-microbatch queries") + currentStream.asInstanceOf[MicroBatchExecution].withProgressLocked { + actns.foreach(executeAction) + } + + case action: StreamAction => executeAction(action) } if (streamThreadDeathCause != null) { failTest("Stream Thread Died", streamThreadDeathCause) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 92087f68ad74a..11bdd13942dcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -462,15 +462,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 3, 4, 5), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with leftValue <= 4 should generate their outer join row now and // not get added to the state. CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch((20, 30, 40, "60")) @@ -493,15 +491,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 3, 4, 5), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with value <= 7 should never be added to the state. CheckLastBatch(Row(3, 10, 6, "9")), assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) @@ -524,15 +520,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 3, 4, 5), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with value <= 4 should never be added to the state. CheckLastBatch(Row(3, 10, 6, "9")), assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) @@ -555,15 +549,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 3, 4, 5), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with rightValue <= 7 should generate their outer join row now and // not get added to the state. CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch((20, 30, 40, "60")) @@ -575,13 +567,11 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. - AddData(leftInput, 1, 2, 3, 4, 5), - AddData(rightInput, 3, 4, 5, 6, 7), + MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), // Old state doesn't get dropped until the batch *after* it gets introduced, so the // nulls won't show up until the next batch after the watermark advances. - AddData(leftInput, 21), - AddData(rightInput, 22), + MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), assertNumStateRows(total = 12, updated = 2), AddData(leftInput, 22), @@ -595,13 +585,11 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. - AddData(leftInput, 1, 2, 3, 4, 5), - AddData(rightInput, 3, 4, 5, 6, 7), + MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), // Old state doesn't get dropped until the batch *after* it gets introduced, so the // nulls won't show up until the next batch after the watermark advances. - AddData(leftInput, 21), - AddData(rightInput, 22), + MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), assertNumStateRows(total = 12, updated = 2), AddData(leftInput, 22), @@ -676,11 +664,9 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // leftValue <= 10 should generate outer join rows even though it matches right keys - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3), CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), assertNumStateRows(total = 5, updated = 2), AddData(rightInput, 20), @@ -688,22 +674,18 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with Row(20, 30, 40, 60)), assertNumStateRows(total = 3, updated = 1), // leftValue and rightValue both satisfying condition should not generate outer join rows - AddData(leftInput, 40, 41), - AddData(rightInput, 40, 41), + MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), - AddData(leftInput, 70), - AddData(rightInput, 71), + MultiAddData(leftInput, 70)(rightInput, 71), CheckLastBatch(), assertNumStateRows(total = 6, updated = 2), AddData(rightInput, 70), CheckLastBatch((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), // rightValue between 300 and 1000 should generate outer join rows even though it matches left - AddData(leftInput, 101, 102, 103), - AddData(rightInput, 101, 102, 103), + MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), CheckLastBatch(), - AddData(leftInput, 1000), - AddData(rightInput, 1001), + MultiAddData(leftInput, 1000)(rightInput, 1001), CheckLastBatch(), assertNumStateRows(total = 8, updated = 2), AddData(rightInput, 1000), From 1a198ce8f580bcf35b9cbfab403fc40f821046a1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 23 Feb 2018 16:30:32 -0800 Subject: [PATCH 391/774] [SPARK-23459][SQL] Improve the error message when unknown column is specified in partition columns ## What changes were proposed in this pull request? This PR avoids to print schema internal information when unknown column is specified in partition columns. This PR prints column names in the schema with more readable format. The following is an example. Source code ``` test("save with an unknown partition column") { withTempDir { dir => val path = dir.getCanonicalPath Seq(1L -> "a").toDF("i", "j").write .format("parquet") .partitionBy("unknownColumn") .save(path) } ``` Output without this PR ``` Partition column unknownColumn not found in schema StructType(StructField(i,LongType,false), StructField(j,StringType,true)); ``` Output with this PR ``` Partition column unknownColumn not found in schema struct; ``` ## How was this patch tested? Manually tested Author: Kazuaki Ishizaki Closes #20653 from kiszk/SPARK-23459. --- .../datasources/PartitioningUtils.scala | 3 ++- .../apache/spark/sql/sources/SaveLoadSuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 379acb67f7c71..f9a24806953e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -486,7 +486,8 @@ object PartitioningUtils { val equality = columnNameEquality(caseSensitive) StructType(partitionColumns.map { col => schema.find(f => equality(f.name, col)).getOrElse { - throw new AnalysisException(s"Partition column $col not found in schema $schema") + val schemaCatalog = schema.catalogString + throw new AnalysisException(s"Partition column `$col` not found in schema $schemaCatalog") } }).asNullable } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 773d34dfaf9a8..12779b46bfe8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -126,4 +126,20 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA checkLoad(df2, "jsonTable2") } + + test("SPARK-23459: Improve error message when specified unknown column in partition columns") { + withTempDir { dir => + val path = dir.getCanonicalPath + val unknown = "unknownColumn" + val df = Seq(1L -> "a").toDF("i", "j") + val schemaCatalog = df.schema.catalogString + val e = intercept[AnalysisException] { + df.write + .format("parquet") + .partitionBy(unknown) + .save(path) + }.getMessage + assert(e.contains(s"Partition column `$unknown` not found in schema $schemaCatalog")) + } + } } From 3ca9a2c56513444d7b233088b020d2d43fa6b77a Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Sun, 25 Feb 2018 09:29:59 -0600 Subject: [PATCH 392/774] =?UTF-8?q?[SPARK-22886][ML][TESTS]=20ML=20test=20?= =?UTF-8?q?for=20structured=20streaming:=20ml.recomme=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Converting spark.ml.recommendation tests to also check code with structured streaming, using the ML testing infrastructure implemented in SPARK-22882. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #20362 from gaborgsomogyi/SPARK-22886. --- .../spark/ml/recommendation/ALSSuite.scala | 213 ++++++++++++------ .../apache/spark/ml/util/MLTestingUtils.scala | 44 ---- 2 files changed, 143 insertions(+), 114 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index addcd21d50aac..e3dfe2faf5698 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,8 +22,7 @@ import java.util.Random import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.WrappedArray +import scala.collection.mutable.{ArrayBuffer, WrappedArray} import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -35,21 +34,20 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.recommendation.ALS._ -import org.apache.spark.ml.recommendation.ALS.Rating -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.recommendation.MatrixFactorizationModelSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} -import org.apache.spark.sql.{DataFrame, Row, SparkSession} -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.{DataFrame, Encoder, Row, SparkSession} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.streaming.StreamingQueryException import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class ALSSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { +class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { override def beforeAll(): Unit = { super.beforeAll() @@ -413,34 +411,36 @@ class ALSSuite .setSeed(0) val alpha = als.getAlpha val model = als.fit(training.toDF()) - val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map { - case Row(rating: Float, prediction: Float) => - (rating.toDouble, prediction.toDouble) + testTransformerByGlobalCheckFunc[Rating[Int]](test.toDF(), model, "rating", "prediction") { + case rows: Seq[Row] => + val predictions = rows.map(row => (row.getFloat(0).toDouble, row.getFloat(1).toDouble)) + + val rmse = + if (implicitPrefs) { + // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. + // We limit the ratings and the predictions to interval [0, 1] and compute the + // weighted RMSE with the confidence scores as weights. + val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => + val confidence = 1.0 + alpha * math.abs(rating) + val rating01 = math.max(math.min(rating, 1.0), 0.0) + val prediction01 = math.max(math.min(prediction, 1.0), 0.0) + val err = prediction01 - rating01 + (confidence, confidence * err * err) + }.reduce[(Double, Double)] { case ((c0, e0), (c1, e1)) => + (c0 + c1, e0 + e1) + } + math.sqrt(weightedSumSq / totalWeight) + } else { + val errorSquares = predictions.map { case (rating, prediction) => + val err = rating - prediction + err * err + } + val mse = errorSquares.sum / errorSquares.length + math.sqrt(mse) + } + logInfo(s"Test RMSE is $rmse.") + assert(rmse < targetRMSE) } - val rmse = - if (implicitPrefs) { - // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. - // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE - // with the confidence scores as weights. - val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => - val confidence = 1.0 + alpha * math.abs(rating) - val rating01 = math.max(math.min(rating, 1.0), 0.0) - val prediction01 = math.max(math.min(prediction, 1.0), 0.0) - val err = prediction01 - rating01 - (confidence, confidence * err * err) - }.reduce { case ((c0, e0), (c1, e1)) => - (c0 + c1, e0 + e1) - } - math.sqrt(weightedSumSq / totalWeight) - } else { - val mse = predictions.map { case (rating, prediction) => - val err = rating - prediction - err * err - }.mean() - math.sqrt(mse) - } - logInfo(s"Test RMSE is $rmse.") - assert(rmse < targetRMSE) MLTestingUtils.checkCopyAndUids(als, model) } @@ -586,6 +586,68 @@ class ALSSuite allModelParamSettings, checkModelData) } + private def checkNumericTypesALS( + estimator: ALS, + spark: SparkSession, + column: String, + baseType: NumericType) + (check: (ALSModel, ALSModel) => Unit) + (check2: (ALSModel, ALSModel, DataFrame, Encoder[_]) => Unit): Unit = { + val dfs = genRatingsDFWithNumericCols(spark, column) + val df = dfs.find { + case (numericTypeWithEncoder, _) => numericTypeWithEncoder.numericType == baseType + } match { + case Some((_, df)) => df + } + val expected = estimator.fit(df) + val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2))) + actuals.foreach { case (_, actual) => check(expected, actual) } + actuals.foreach { case (t, actual) => check2(expected, actual, t._2, t._1.encoder) } + + val baseDF = dfs.find(_._1.numericType == baseType).get._2 + val others = baseDF.columns.toSeq.diff(Seq(column)).map(col) + val cols = Seq(col(column).cast(StringType)) ++ others + val strDF = baseDF.select(cols: _*) + val thrown = intercept[IllegalArgumentException] { + estimator.fit(strDF) + } + assert(thrown.getMessage.contains( + s"$column must be of type NumericType but was actually of type StringType")) + } + + private class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Int, Double)]) + + private def genRatingsDFWithNumericCols( + spark: SparkSession, + column: String) = { + + import testImplicits._ + + val df = spark.createDataFrame(Seq( + (0, 10, 1.0), + (1, 20, 2.0), + (2, 30, 3.0), + (3, 40, 4.0), + (4, 50, 5.0) + )).toDF("user", "item", "rating") + + val others = df.columns.toSeq.diff(Seq(column)).map(col) + val types = + Seq(new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()) + ) + types.map { t => + val cols = Seq(col(column).cast(t.numericType)) ++ others + t -> df.select(cols: _*) + } + } + test("input type validation") { val spark = this.spark import spark.implicits._ @@ -595,12 +657,16 @@ class ALSSuite val als = new ALS().setMaxIter(1).setRank(1) Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach { case (colName, sqlType) => - MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) { + checkNumericTypesALS(als, spark, colName, sqlType) { (ex, act) => - ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1) - } { (ex, act, _) => - ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~== - act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6 + ex.userFactors.first().getSeq[Float](1) === act.userFactors.first().getSeq[Float](1) + } { (ex, act, df, enc) => + val expected = ex.transform(df).selectExpr("prediction") + .first().getFloat(0) + testTransformerByGlobalCheckFunc(df, act, "prediction") { + case rows: Seq[Row] => + expected ~== rows.head.getFloat(0) absTol 1e-6 + }(enc) } } // check user/item ids falling outside of Int range @@ -628,18 +694,22 @@ class ALSSuite } withClue("transform should fail when ids exceed integer range. ") { val model = als.fit(df) - assert(intercept[SparkException] { - model.transform(df.select(df("user_big").as("user"), df("item"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("user_small").as("user"), df("item"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("item_big").as("item"), df("user"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("item_small").as("item"), df("user"))).first - }.getMessage.contains(msg)) + def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = { + assert(intercept[SparkException] { + model.transform(dataFrame).first + }.getMessage.contains(msg)) + assert(intercept[StreamingQueryException] { + testTransformer[A](dataFrame, model, "prediction") { _ => } + }.getMessage.contains(msg)) + } + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("user_small").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("item_big").as("item"), + df("user"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("item_small").as("item"), + df("user"))) } } @@ -662,28 +732,31 @@ class ALSSuite val knownItem = data.select(max("item")).as[Int].first() val unknownItem = knownItem + 20 val test = Seq( - (unknownUser, unknownItem), - (knownUser, unknownItem), - (unknownUser, knownItem), - (knownUser, knownItem) - ).toDF("user", "item") + (unknownUser, unknownItem, true), + (knownUser, unknownItem, true), + (unknownUser, knownItem, true), + (knownUser, knownItem, false) + ).toDF("user", "item", "expectedIsNaN") val als = new ALS().setMaxIter(1).setRank(1) // default is 'nan' val defaultModel = als.fit(data) - val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect() - assert(defaultPredictions.length == 4) - assert(defaultPredictions.slice(0, 3).forall(_.isNaN)) - assert(!defaultPredictions.last.isNaN) + testTransformer[(Int, Int, Boolean)](test, defaultModel, "expectedIsNaN", "prediction") { + case Row(expectedIsNaN: Boolean, prediction: Float) => + assert(prediction.isNaN === expectedIsNaN) + } // check 'drop' strategy should filter out rows with unknown users/items - val dropPredictions = defaultModel - .setColdStartStrategy("drop") - .transform(test) - .select("prediction").as[Float].collect() - assert(dropPredictions.length == 1) - assert(!dropPredictions.head.isNaN) - assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14) + val defaultPrediction = defaultModel.transform(test).select("prediction") + .as[Float].filter(!_.isNaN).first() + testTransformerByGlobalCheckFunc[(Int, Int, Boolean)](test, + defaultModel.setColdStartStrategy("drop"), "prediction") { + case rows: Seq[Row] => + val dropPredictions = rows.map(_.getFloat(0)) + assert(dropPredictions.length == 1) + assert(!dropPredictions.head.isNaN) + assert(dropPredictions.head ~== defaultPrediction relTol 1e-14) + } } test("case insensitive cold start param value") { @@ -693,7 +766,7 @@ class ALSSuite val data = ratings.toDF val model = new ALS().fit(data) Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s => - model.setColdStartStrategy(s).transform(data) + testTransformer[Rating[Int]](data, model.setColdStartStrategy(s), "prediction") { _ => } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index aef81c8c173a0..c328d81b4bc3a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -91,30 +91,6 @@ object MLTestingUtils extends SparkFunSuite { } } - def checkNumericTypesALS( - estimator: ALS, - spark: SparkSession, - column: String, - baseType: NumericType) - (check: (ALSModel, ALSModel) => Unit) - (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = { - val dfs = genRatingsDFWithNumericCols(spark, column) - val expected = estimator.fit(dfs(baseType)) - val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t)))) - actuals.foreach { case (_, actual) => check(expected, actual) } - actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) } - - val baseDF = dfs(baseType) - val others = baseDF.columns.toSeq.diff(Seq(column)).map(col) - val cols = Seq(col(column).cast(StringType)) ++ others - val strDF = baseDF.select(cols: _*) - val thrown = intercept[IllegalArgumentException] { - estimator.fit(strDF) - } - assert(thrown.getMessage.contains( - s"$column must be of type NumericType but was actually of type StringType")) - } - def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = { val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction") val expected = evaluator.evaluate(dfs(DoubleType)) @@ -176,26 +152,6 @@ object MLTestingUtils extends SparkFunSuite { }.toMap } - def genRatingsDFWithNumericCols( - spark: SparkSession, - column: String): Map[NumericType, DataFrame] = { - val df = spark.createDataFrame(Seq( - (0, 10, 1.0), - (1, 20, 2.0), - (2, 30, 3.0), - (3, 40, 4.0), - (4, 50, 5.0) - )).toDF("user", "item", "rating") - - val others = df.columns.toSeq.diff(Seq(column)).map(col) - val types: Seq[NumericType] = - Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) - types.map { t => - val cols = Seq(col(column).cast(t)) ++ others - t -> df.select(cols: _*) - }.toMap - } - def genEvaluatorDFWithNumericLabelCol( spark: SparkSession, labelColName: String = "label", From b308182f233b8840dfe0e6b5736d2f2746f40757 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 26 Feb 2018 08:39:44 -0800 Subject: [PATCH 393/774] [SPARK-23438][DSTREAMS] Fix DStreams data loss with WAL when driver crashes ## What changes were proposed in this pull request? There is a race condition introduced in SPARK-11141 which could cause data loss. The problem is that ReceivedBlockTracker.insertAllocatedBatch function assumes that all blocks from streamIdToUnallocatedBlockQueues allocated to the batch and clears the queue. In this PR only the allocated blocks will be removed from the queue which will prevent data loss. ## How was this patch tested? Additional unit test + manually. Author: Gabor Somogyi Closes #20620 from gaborgsomogyi/SPARK-23438. --- .../scheduler/ReceivedBlockTracker.scala | 11 +++++---- .../streaming/ReceivedBlockTrackerSuite.scala | 23 ++++++++++++++++++- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 5d9a8ac0d9297..dacff69d55dd2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -193,12 +193,15 @@ private[streaming] class ReceivedBlockTracker( getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo } - // Insert the recovered block-to-batch allocations and clear the queue of received blocks - // (when the blocks were originally allocated to the batch, the queue must have been cleared). + // Insert the recovered block-to-batch allocations and removes them from queue of + // received blocks. def insertAllocatedBatch(batchTime: Time, allocatedBlocks: AllocatedBlocks) { logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + s"${allocatedBlocks.streamIdToAllocatedBlocks}") - streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } + allocatedBlocks.streamIdToAllocatedBlocks.foreach { + case (streamId, allocatedBlocksInStream) => + getReceivedBlockQueue(streamId).dequeueAll(allocatedBlocksInStream.toSet) + } timeToAllocatedBlocks.put(batchTime, allocatedBlocks) lastAllocatedBatchTime = batchTime } @@ -227,7 +230,7 @@ private[streaming] class ReceivedBlockTracker( } /** Write an update to the tracker to the write ahead log */ - private def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { + private[streaming] def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { if (isWriteAheadLogEnabled) { logTrace(s"Writing record: $record") try { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 107c3f5dcc08d..4fa236bd39663 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult -import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.scheduler.{AllocatedBlocks, _} import org.apache.spark.streaming.util._ import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} @@ -94,6 +94,27 @@ class ReceivedBlockTrackerSuite receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos } + test("recovery with write ahead logs should remove only allocated blocks from received queue") { + val manualClock = new ManualClock + val batchTime = manualClock.getTimeMillis() + + val tracker1 = createTracker(clock = manualClock) + tracker1.isWriteAheadLogEnabled should be (true) + + val allocatedBlockInfos = generateBlockInfos() + val unallocatedBlockInfos = generateBlockInfos() + val receivedBlockInfos = allocatedBlockInfos ++ unallocatedBlockInfos + receivedBlockInfos.foreach { b => tracker1.writeToLog(BlockAdditionEvent(b)) } + val allocatedBlocks = AllocatedBlocks(Map(streamId -> allocatedBlockInfos)) + tracker1.writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) + tracker1.stop() + + val tracker2 = createTracker(clock = manualClock, recoverFromWriteAheadLog = true) + tracker2.getBlocksOfBatch(batchTime) shouldEqual allocatedBlocks.streamIdToAllocatedBlocks + tracker2.getUnallocatedBlocks(streamId) shouldEqual unallocatedBlockInfos + tracker2.stop() + } + test("recovery and cleanup with write ahead logs") { val manualClock = new ManualClock // Set the time increment level to twice the rotation interval so that every increment creates From 185f5bc7dd52cebe8fac9393ecb2bd0968bc5867 Mon Sep 17 00:00:00 2001 From: Andrew Korzhuev Date: Mon, 26 Feb 2018 10:28:45 -0800 Subject: [PATCH 394/774] [SPARK-23449][K8S] Preserve extraJavaOptions ordering For some JVM options, like `-XX:+UnlockExperimentalVMOptions` ordering is necessary. ## What changes were proposed in this pull request? Keep original `extraJavaOptions` ordering, when passing them through environment variables inside the Docker container. ## How was this patch tested? Ran base branch a couple of times and checked startup command in logs. Ordering differed every time. Added sorting, ordering was consistent to what user had in `extraJavaOptions`. Author: Andrew Korzhuev Closes #20628 from andrusha/patch-2. --- .../kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index b9090dc2852a5..3d67b0a702dd4 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -41,7 +41,7 @@ fi shift 1 SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" -env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt +env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" From 7ec83658fbc88505dfc2d8a6f76e90db747f1292 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 26 Feb 2018 11:28:44 -0800 Subject: [PATCH 395/774] [SPARK-23491][SS] Remove explicit job cancellation from ContinuousExecution reconfiguring ## What changes were proposed in this pull request? Remove queryExecutionThread.interrupt() from ContinuousExecution. As detailed in the JIRA, interrupting the thread is only relevant in the microbatch case; for continuous processing the query execution can quickly clean itself up without. ## How was this patch tested? existing tests Author: Jose Torres Closes #20622 from jose-torres/SPARK-23441. --- .../streaming/continuous/ContinuousExecution.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 2c1d6c509d21b..daebd1dd010ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -236,9 +236,7 @@ class ContinuousExecution( startTrigger() if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { - stopSources() if (queryExecutionThread.isAlive) { - sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() } false @@ -266,12 +264,20 @@ class ContinuousExecution( SQLExecution.withNewExecutionId( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } + } catch { + case t: Throwable + if StreamExecution.isInterruptionException(t) && state.get() == RECONFIGURING => + logInfo(s"Query $id ignoring exception from reconfiguring: $t") + // interrupted by reconfiguration - swallow exception so we can restart the query } finally { epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() epochUpdateThread.join() + + stopSources() + sparkSession.sparkContext.cancelJobGroup(runId.toString) } } From 8077bb04f350fd35df83ef896135c0672dc3f7b0 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Mon, 26 Feb 2018 23:37:31 -0800 Subject: [PATCH 396/774] [SPARK-23445] ColumnStat refactoring ## What changes were proposed in this pull request? Refactor ColumnStat to be more flexible. * Split `ColumnStat` and `CatalogColumnStat` just like `CatalogStatistics` is split from `Statistics`. This detaches how the statistics are stored from how they are processed in the query plan. `CatalogColumnStat` keeps `min` and `max` as `String`, making it not depend on dataType information. * For `CatalogColumnStat`, parse column names from property names in the metastore (`KEY_VERSION` property), not from metastore schema. This means that `CatalogColumnStat`s can be created for columns even if the schema itself is not stored in the metastore. * Make all fields optional. `min`, `max` and `histogram` for columns were optional already. Having them all optional is more consistent, and gives flexibility to e.g. drop some of the fields through transformations if they are difficult / impossible to calculate. The added flexibility will make it possible to have alternative implementations for stats, and separates stats collection from stats and estimation processing in plans. ## How was this patch tested? Refactored existing tests to work with refactored `ColumnStat` and `CatalogColumnStat`. New tests added in `StatisticsSuite` checking that backwards / forwards compatibility is not broken. Author: Juliusz Sompolski Closes #20624 from juliuszsompolski/SPARK-23445. --- .../sql/catalyst/catalog/interface.scala | 146 ++++++++- .../optimizer/StarSchemaDetection.scala | 6 +- .../catalyst/plans/logical/Statistics.scala | 256 ++-------------- .../statsEstimation/AggregateEstimation.scala | 6 +- .../statsEstimation/EstimationUtils.scala | 20 +- .../statsEstimation/FilterEstimation.scala | 98 +++--- .../statsEstimation/JoinEstimation.scala | 55 ++-- .../catalyst/optimizer/JoinReorderSuite.scala | 25 +- .../StarJoinCostBasedReorderSuite.scala | 96 ++---- .../optimizer/StarJoinReorderSuite.scala | 77 ++--- .../AggregateEstimationSuite.scala | 24 +- .../BasicStatsEstimationSuite.scala | 12 +- .../FilterEstimationSuite.scala | 279 +++++++++--------- .../statsEstimation/JoinEstimationSuite.scala | 138 +++++---- .../ProjectEstimationSuite.scala | 70 +++-- .../StatsEstimationTestBase.scala | 10 +- .../command/AnalyzeColumnCommand.scala | 138 ++++++++- .../spark/sql/execution/command/tables.scala | 9 +- .../spark/sql/StatisticsCollectionSuite.scala | 9 +- .../sql/StatisticsCollectionTestBase.scala | 168 +++++++++-- .../spark/sql/hive/HiveExternalCatalog.scala | 63 ++-- .../spark/sql/hive/StatisticsSuite.scala | 162 +++------- 22 files changed, 995 insertions(+), 872 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 95b6fbb0cd61a..f3e67dc4e975c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -21,7 +21,9 @@ import java.net.URI import java.util.Date import scala.collection.mutable +import scala.util.control.NonFatal +import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -30,7 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ /** @@ -361,7 +363,7 @@ object CatalogTable { case class CatalogStatistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, - colStats: Map[String, ColumnStat] = Map.empty) { + colStats: Map[String, CatalogColumnStat] = Map.empty) { /** * Convert [[CatalogStatistics]] to [[Statistics]], and match column stats to attributes based @@ -369,7 +371,8 @@ case class CatalogStatistics( */ def toPlanStats(planOutput: Seq[Attribute], cboEnabled: Boolean): Statistics = { if (cboEnabled && rowCount.isDefined) { - val attrStats = AttributeMap(planOutput.flatMap(a => colStats.get(a.name).map(a -> _))) + val attrStats = AttributeMap(planOutput + .flatMap(a => colStats.get(a.name).map(a -> _.toPlanStat(a.name, a.dataType)))) // Estimate size as number of rows * row size. val size = EstimationUtils.getOutputSize(planOutput, rowCount.get, attrStats) Statistics(sizeInBytes = size, rowCount = rowCount, attributeStats = attrStats) @@ -387,6 +390,143 @@ case class CatalogStatistics( } } +/** + * This class of statistics for a column is used in [[CatalogTable]] to interact with metastore. + */ +case class CatalogColumnStat( + distinctCount: Option[BigInt] = None, + min: Option[String] = None, + max: Option[String] = None, + nullCount: Option[BigInt] = None, + avgLen: Option[Long] = None, + maxLen: Option[Long] = None, + histogram: Option[Histogram] = None) { + + /** + * Returns a map from string to string that can be used to serialize the column stats. + * The key is the name of the column and name of the field (e.g. "colName.distinctCount"), + * and the value is the string representation for the value. + * min/max values are stored as Strings. They can be deserialized using + * [[CatalogColumnStat.fromExternalString]]. + * + * As part of the protocol, the returned map always contains a key called "version". + * Any of the fields that are null (None) won't appear in the map. + */ + def toMap(colName: String): Map[String, String] = { + val map = new scala.collection.mutable.HashMap[String, String] + map.put(s"${colName}.${CatalogColumnStat.KEY_VERSION}", "1") + distinctCount.foreach { v => + map.put(s"${colName}.${CatalogColumnStat.KEY_DISTINCT_COUNT}", v.toString) + } + nullCount.foreach { v => + map.put(s"${colName}.${CatalogColumnStat.KEY_NULL_COUNT}", v.toString) + } + avgLen.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_AVG_LEN}", v.toString) } + maxLen.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MAX_LEN}", v.toString) } + min.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MIN_VALUE}", v) } + max.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MAX_VALUE}", v) } + histogram.foreach { h => + map.put(s"${colName}.${CatalogColumnStat.KEY_HISTOGRAM}", HistogramSerializer.serialize(h)) + } + map.toMap + } + + /** Convert [[CatalogColumnStat]] to [[ColumnStat]]. */ + def toPlanStat( + colName: String, + dataType: DataType): ColumnStat = + ColumnStat( + distinctCount = distinctCount, + min = min.map(CatalogColumnStat.fromExternalString(_, colName, dataType)), + max = max.map(CatalogColumnStat.fromExternalString(_, colName, dataType)), + nullCount = nullCount, + avgLen = avgLen, + maxLen = maxLen, + histogram = histogram) +} + +object CatalogColumnStat extends Logging { + + // List of string keys used to serialize CatalogColumnStat + val KEY_VERSION = "version" + private val KEY_DISTINCT_COUNT = "distinctCount" + private val KEY_MIN_VALUE = "min" + private val KEY_MAX_VALUE = "max" + private val KEY_NULL_COUNT = "nullCount" + private val KEY_AVG_LEN = "avgLen" + private val KEY_MAX_LEN = "maxLen" + private val KEY_HISTOGRAM = "histogram" + + /** + * Converts from string representation of data type to the corresponding Catalyst data type. + */ + def fromExternalString(s: String, name: String, dataType: DataType): Any = { + dataType match { + case BooleanType => s.toBoolean + case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) + case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) + case ByteType => s.toByte + case ShortType => s.toShort + case IntegerType => s.toInt + case LongType => s.toLong + case FloatType => s.toFloat + case DoubleType => s.toDouble + case _: DecimalType => Decimal(s) + // This version of Spark does not use min/max for binary/string types so we ignore it. + case BinaryType | StringType => null + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column $name of data type: $dataType.") + } + } + + /** + * Converts the given value from Catalyst data type to string representation of external + * data type. + */ + def toExternalString(v: Any, colName: String, dataType: DataType): String = { + val externalValue = dataType match { + case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int]) + case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long]) + case BooleanType | _: IntegralType | FloatType | DoubleType => v + case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal + // This version of Spark does not use min/max for binary/string types so we ignore it. + case _ => + throw new AnalysisException("Column statistics serialization is not supported for " + + s"column $colName of data type: $dataType.") + } + externalValue.toString + } + + + /** + * Creates a [[CatalogColumnStat]] object from the given map. + * This is used to deserialize column stats from some external storage. + * The serialization side is defined in [[CatalogColumnStat.toMap]]. + */ + def fromMap( + table: String, + colName: String, + map: Map[String, String]): Option[CatalogColumnStat] = { + + try { + Some(CatalogColumnStat( + distinctCount = map.get(s"${colName}.${KEY_DISTINCT_COUNT}").map(v => BigInt(v.toLong)), + min = map.get(s"${colName}.${KEY_MIN_VALUE}"), + max = map.get(s"${colName}.${KEY_MAX_VALUE}"), + nullCount = map.get(s"${colName}.${KEY_NULL_COUNT}").map(v => BigInt(v.toLong)), + avgLen = map.get(s"${colName}.${KEY_AVG_LEN}").map(_.toLong), + maxLen = map.get(s"${colName}.${KEY_MAX_LEN}").map(_.toLong), + histogram = map.get(s"${colName}.${KEY_HISTOGRAM}").map(HistogramSerializer.deserialize) + )) + } catch { + case NonFatal(e) => + logWarning(s"Failed to parse column statistics for column ${colName} in table $table", e) + None + } + } +} + case class CatalogTableType private(name: String) object CatalogTableType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index 1f20b7661489e..2aa762e2595ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -187,11 +187,11 @@ object StarSchemaDetection extends PredicateHelper { stats.rowCount match { case Some(rowCount) if rowCount >= 0 => if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { - val colStats = stats.attributeStats.get(col) - if (colStats.get.nullCount > 0) { + val colStats = stats.attributeStats.get(col).get + if (!colStats.hasCountStats || colStats.nullCount.get > 0) { false } else { - val distinctCount = colStats.get.distinctCount + val distinctCount = colStats.distinctCount.get val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d) // ndvMaxErr adjusted based on TPCDS 1TB data results relDiff <= conf.ndvMaxError * 2 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 96b199d7f20b0..b3a48860aa63b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -27,6 +27,7 @@ import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils} @@ -79,11 +80,10 @@ case class Statistics( /** * Statistics collected for a column. * - * 1. Supported data types are defined in `ColumnStat.supportsType`. - * 2. The JVM data type stored in min/max is the internal data type for the corresponding + * 1. The JVM data type stored in min/max is the internal data type for the corresponding * Catalyst data type. For example, the internal type of DateType is Int, and that the internal * type of TimestampType is Long. - * 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms + * 2. There is no guarantee that the statistics collected are accurate. Approximation algorithms * (sketches) might have been used, and the data collected can also be stale. * * @param distinctCount number of distinct values @@ -95,240 +95,32 @@ case class Statistics( * @param histogram histogram of the values */ case class ColumnStat( - distinctCount: BigInt, - min: Option[Any], - max: Option[Any], - nullCount: BigInt, - avgLen: Long, - maxLen: Long, + distinctCount: Option[BigInt] = None, + min: Option[Any] = None, + max: Option[Any] = None, + nullCount: Option[BigInt] = None, + avgLen: Option[Long] = None, + maxLen: Option[Long] = None, histogram: Option[Histogram] = None) { - // We currently don't store min/max for binary/string type. This can change in the future and - // then we need to remove this require. - require(min.isEmpty || (!min.get.isInstanceOf[Array[Byte]] && !min.get.isInstanceOf[String])) - require(max.isEmpty || (!max.get.isInstanceOf[Array[Byte]] && !max.get.isInstanceOf[String])) - - /** - * Returns a map from string to string that can be used to serialize the column stats. - * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string - * representation for the value. min/max values are converted to the external data type. For - * example, for DateType we store java.sql.Date, and for TimestampType we store - * java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]]. - * - * As part of the protocol, the returned map always contains a key called "version". - * In the case min/max values are null (None), they won't appear in the map. - */ - def toMap(colName: String, dataType: DataType): Map[String, String] = { - val map = new scala.collection.mutable.HashMap[String, String] - map.put(ColumnStat.KEY_VERSION, "1") - map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString) - map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString) - map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString) - map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString) - min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) } - max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) } - histogram.foreach { h => map.put(ColumnStat.KEY_HISTOGRAM, HistogramSerializer.serialize(h)) } - map.toMap - } - - /** - * Converts the given value from Catalyst data type to string representation of external - * data type. - */ - private def toExternalString(v: Any, colName: String, dataType: DataType): String = { - val externalValue = dataType match { - case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int]) - case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long]) - case BooleanType | _: IntegralType | FloatType | DoubleType => v - case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal - // This version of Spark does not use min/max for binary/string types so we ignore it. - case _ => - throw new AnalysisException("Column statistics deserialization is not supported for " + - s"column $colName of data type: $dataType.") - } - externalValue.toString - } - -} + // Are distinctCount and nullCount statistics defined? + val hasCountStats = distinctCount.isDefined && nullCount.isDefined + // Are min and max statistics defined? + val hasMinMaxStats = min.isDefined && max.isDefined -object ColumnStat extends Logging { - - // List of string keys used to serialize ColumnStat - val KEY_VERSION = "version" - private val KEY_DISTINCT_COUNT = "distinctCount" - private val KEY_MIN_VALUE = "min" - private val KEY_MAX_VALUE = "max" - private val KEY_NULL_COUNT = "nullCount" - private val KEY_AVG_LEN = "avgLen" - private val KEY_MAX_LEN = "maxLen" - private val KEY_HISTOGRAM = "histogram" - - /** Returns true iff the we support gathering column statistics on column of the given type. */ - def supportsType(dataType: DataType): Boolean = dataType match { - case _: IntegralType => true - case _: DecimalType => true - case DoubleType | FloatType => true - case BooleanType => true - case DateType => true - case TimestampType => true - case BinaryType | StringType => true - case _ => false - } - - /** Returns true iff the we support gathering histogram on column of the given type. */ - def supportsHistogram(dataType: DataType): Boolean = dataType match { - case _: IntegralType => true - case _: DecimalType => true - case DoubleType | FloatType => true - case DateType => true - case TimestampType => true - case _ => false - } - - /** - * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats - * from some external storage. The serialization side is defined in [[ColumnStat.toMap]]. - */ - def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = { - try { - Some(ColumnStat( - distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong), - // Note that flatMap(Option.apply) turns Option(null) into None. - min = map.get(KEY_MIN_VALUE) - .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), - max = map.get(KEY_MAX_VALUE) - .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), - nullCount = BigInt(map(KEY_NULL_COUNT).toLong), - avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong, - maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong, - histogram = map.get(KEY_HISTOGRAM).map(HistogramSerializer.deserialize) - )) - } catch { - case NonFatal(e) => - logWarning(s"Failed to parse column statistics for column ${field.name} in table $table", e) - None - } - } - - /** - * Converts from string representation of external data type to the corresponding Catalyst data - * type. - */ - private def fromExternalString(s: String, name: String, dataType: DataType): Any = { - dataType match { - case BooleanType => s.toBoolean - case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) - case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) - case ByteType => s.toByte - case ShortType => s.toShort - case IntegerType => s.toInt - case LongType => s.toLong - case FloatType => s.toFloat - case DoubleType => s.toDouble - case _: DecimalType => Decimal(s) - // This version of Spark does not use min/max for binary/string types so we ignore it. - case BinaryType | StringType => null - case _ => - throw new AnalysisException("Column statistics deserialization is not supported for " + - s"column $name of data type: $dataType.") - } - } - - /** - * Constructs an expression to compute column statistics for a given column. - * - * The expression should create a single struct column with the following schema: - * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long, - * distinctCountsForIntervals: Array[Long] - * - * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and - * as a result should stay in sync with it. - */ - def statExprs( - col: Attribute, - conf: SQLConf, - colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = { - def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => - expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } - }) - val one = Literal(1, LongType) - - // the approximate ndv (num distinct value) should never be larger than the number of rows - val numNonNulls = if (col.nullable) Count(col) else Count(one) - val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls)) - val numNulls = Subtract(Count(one), numNonNulls) - val defaultSize = Literal(col.dataType.defaultSize, LongType) - val nullArray = Literal(null, ArrayType(LongType)) - - def fixedLenTypeStruct: CreateNamedStruct = { - val genHistogram = - ColumnStat.supportsHistogram(col.dataType) && colPercentiles.contains(col) - val intervalNdvsExpr = if (genHistogram) { - ApproxCountDistinctForIntervals(col, - Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError) - } else { - nullArray - } - // For fixed width types, avg size should be the same as max size. - struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls, - defaultSize, defaultSize, intervalNdvsExpr) - } - - col.dataType match { - case _: IntegralType => fixedLenTypeStruct - case _: DecimalType => fixedLenTypeStruct - case DoubleType | FloatType => fixedLenTypeStruct - case BooleanType => fixedLenTypeStruct - case DateType => fixedLenTypeStruct - case TimestampType => fixedLenTypeStruct - case BinaryType | StringType => - // For string and binary type, we don't compute min, max or histogram - val nullLit = Literal(null, col.dataType) - struct( - ndv, nullLit, nullLit, numNulls, - // Set avg/max size to default size if all the values are null or there is no value. - Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), - Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)), - nullArray) - case _ => - throw new AnalysisException("Analyzing column statistics is not supported for column " + - s"${col.name} of data type: ${col.dataType}.") - } - } - - /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */ - def rowToColumnStat( - row: InternalRow, - attr: Attribute, - rowCount: Long, - percentiles: Option[ArrayData]): ColumnStat = { - // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins. - val cs = ColumnStat( - distinctCount = BigInt(row.getLong(0)), - // for string/binary min/max, get should return null - min = Option(row.get(1, attr.dataType)), - max = Option(row.get(2, attr.dataType)), - nullCount = BigInt(row.getLong(3)), - avgLen = row.getLong(4), - maxLen = row.getLong(5) - ) - if (row.isNullAt(6)) { - cs - } else { - val ndvs = row.getArray(6).toLongArray() - assert(percentiles.get.numElements() == ndvs.length + 1) - val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble) - // Construct equi-height histogram - val bins = ndvs.zipWithIndex.map { case (ndv, i) => - HistogramBin(endpoints(i), endpoints(i + 1), ndv) - } - val nonNullRows = rowCount - cs.nullCount - val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins) - cs.copy(histogram = Some(histogram)) - } - } + // Are avgLen and maxLen statistics defined? + val hasLenStats = avgLen.isDefined && maxLen.isDefined + def toCatalogColumnStat(colName: String, dataType: DataType): CatalogColumnStat = + CatalogColumnStat( + distinctCount = distinctCount, + min = min.map(CatalogColumnStat.toExternalString(_, colName, dataType)), + max = max.map(CatalogColumnStat.toExternalString(_, colName, dataType)), + nullCount = nullCount, + avgLen = avgLen, + maxLen = maxLen, + histogram = histogram) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index c41fac4015ec0..111c594a53e52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -32,13 +32,15 @@ object AggregateEstimation { val childStats = agg.child.stats // Check if we have column stats for all group-by columns. val colStatsExist = agg.groupingExpressions.forall { e => - e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute]) + e.isInstanceOf[Attribute] && + childStats.attributeStats.get(e.asInstanceOf[Attribute]).exists(_.hasCountStats) } if (rowCountsExist(agg.child) && colStatsExist) { // Multiply distinct counts of group-by columns. This is an upper bound, which assumes // the data contains all combinations of distinct values of group-by columns. var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))( - (res, expr) => res * childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount) + (res, expr) => res * + childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount.get) outputRows = if (agg.groupingExpressions.isEmpty) { // If there's no group-by columns, the output is a single row containing values of aggregate diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index d793f77413d18..0f147f0ffb135 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode @@ -38,9 +39,18 @@ object EstimationUtils { } } + /** Check if each attribute has column stat containing distinct and null counts + * in the corresponding statistic. */ + def columnStatsWithCountsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = { + statsAndAttr.forall { case (stats, attr) => + stats.attributeStats.get(attr).map(_.hasCountStats).getOrElse(false) + } + } + + /** Statistics for a Column containing only NULLs. */ def nullColumnStat(dataType: DataType, rowCount: BigInt): ColumnStat = { - ColumnStat(distinctCount = 0, min = None, max = None, nullCount = rowCount, - avgLen = dataType.defaultSize, maxLen = dataType.defaultSize) + ColumnStat(distinctCount = Some(0), min = None, max = None, nullCount = Some(rowCount), + avgLen = Some(dataType.defaultSize), maxLen = Some(dataType.defaultSize)) } /** @@ -70,13 +80,13 @@ object EstimationUtils { // We assign a generic overhead for a Row object, the actual overhead is different for different // Row format. val sizePerRow = 8 + attributes.map { attr => - if (attrStats.contains(attr)) { + if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) { attr.dataType match { case StringType => // UTF8String: base + offset + numBytes - attrStats(attr).avgLen + 8 + 4 + attrStats(attr).avgLen.get + 8 + 4 case _ => - attrStats(attr).avgLen + attrStats(attr).avgLen.get } } else { attr.dataType.defaultSize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 4cc32de2d32d7..0538c9d88584b 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -225,7 +225,7 @@ case class FilterEstimation(plan: Filter) extends Logging { attr: Attribute, isNull: Boolean, update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attr)) { + if (!colStatsMap.contains(attr) || !colStatsMap(attr).hasCountStats) { logDebug("[CBO] No statistics for " + attr) return None } @@ -234,14 +234,14 @@ case class FilterEstimation(plan: Filter) extends Logging { val nullPercent: Double = if (rowCountValue == 0) { 0 } else { - (BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)).toDouble + (BigDecimal(colStat.nullCount.get) / BigDecimal(rowCountValue)).toDouble } if (update) { val newStats = if (isNull) { - colStat.copy(distinctCount = 0, min = None, max = None) + colStat.copy(distinctCount = Some(0), min = None, max = None) } else { - colStat.copy(nullCount = 0) + colStat.copy(nullCount = Some(0)) } colStatsMap.update(attr, newStats) } @@ -322,17 +322,21 @@ case class FilterEstimation(plan: Filter) extends Logging { // value. val newStats = attr.dataType match { case StringType | BinaryType => - colStat.copy(distinctCount = 1, nullCount = 0) + colStat.copy(distinctCount = Some(1), nullCount = Some(0)) case _ => - colStat.copy(distinctCount = 1, min = Some(literal.value), - max = Some(literal.value), nullCount = 0) + colStat.copy(distinctCount = Some(1), min = Some(literal.value), + max = Some(literal.value), nullCount = Some(0)) } colStatsMap.update(attr, newStats) } if (colStat.histogram.isEmpty) { - // returns 1/ndv if there is no histogram - Some(1.0 / colStat.distinctCount.toDouble) + if (!colStat.distinctCount.isEmpty) { + // returns 1/ndv if there is no histogram + Some(1.0 / colStat.distinctCount.get.toDouble) + } else { + None + } } else { Some(computeEqualityPossibilityByHistogram(literal, colStat)) } @@ -378,13 +382,13 @@ case class FilterEstimation(plan: Filter) extends Logging { attr: Attribute, hSet: Set[Any], update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attr)) { + if (!colStatsMap.hasDistinctCount(attr)) { logDebug("[CBO] No statistics for " + attr) return None } val colStat = colStatsMap(attr) - val ndv = colStat.distinctCount + val ndv = colStat.distinctCount.get val dataType = attr.dataType var newNdv = ndv @@ -407,8 +411,8 @@ case class FilterEstimation(plan: Filter) extends Logging { // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. newNdv = ndv.min(BigInt(validQuerySet.size)) if (update) { - val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin), - max = Some(newMax), nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(newNdv), min = Some(newMin), + max = Some(newMax), nullCount = Some(0)) colStatsMap.update(attr, newStats) } @@ -416,7 +420,7 @@ case class FilterEstimation(plan: Filter) extends Logging { case StringType | BinaryType => newNdv = ndv.min(BigInt(hSet.size)) if (update) { - val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(newNdv), nullCount = Some(0)) colStatsMap.update(attr, newStats) } } @@ -443,12 +447,17 @@ case class FilterEstimation(plan: Filter) extends Logging { literal: Literal, update: Boolean): Option[Double] = { + if (!colStatsMap.hasMinMaxStats(attr) || !colStatsMap.hasDistinctCount(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + val colStat = colStatsMap(attr) val statsInterval = ValueInterval(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericValueInterval] val max = statsInterval.max val min = statsInterval.min - val ndv = colStat.distinctCount.toDouble + val ndv = colStat.distinctCount.get.toDouble // determine the overlapping degree between predicate interval and column's interval val numericLiteral = EstimationUtils.toDouble(literal.value, literal.dataType) @@ -520,8 +529,8 @@ case class FilterEstimation(plan: Filter) extends Logging { newMax = newValue } - val newStats = colStat.copy(distinctCount = ceil(ndv * percent), - min = newMin, max = newMax, nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(ceil(ndv * percent)), + min = newMin, max = newMax, nullCount = Some(0)) colStatsMap.update(attr, newStats) } @@ -637,11 +646,11 @@ case class FilterEstimation(plan: Filter) extends Logging { attrRight: Attribute, update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attrLeft)) { + if (!colStatsMap.hasCountStats(attrLeft)) { logDebug("[CBO] No statistics for " + attrLeft) return None } - if (!colStatsMap.contains(attrRight)) { + if (!colStatsMap.hasCountStats(attrRight)) { logDebug("[CBO] No statistics for " + attrRight) return None } @@ -668,7 +677,7 @@ case class FilterEstimation(plan: Filter) extends Logging { val minRight = statsIntervalRight.min // determine the overlapping degree between predicate interval and column's interval - val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) + val allNotNull = (colStatLeft.nullCount.get == 0) && (colStatRight.nullCount.get == 0) val (noOverlap: Boolean, completeOverlap: Boolean) = op match { // Left < Right or Left <= Right // - no overlap: @@ -707,14 +716,14 @@ case class FilterEstimation(plan: Filter) extends Logging { case _: EqualTo => ((maxLeft < minRight) || (maxRight < minLeft), (minLeft == minRight) && (maxLeft == maxRight) && allNotNull - && (colStatLeft.distinctCount == colStatRight.distinctCount) + && (colStatLeft.distinctCount.get == colStatRight.distinctCount.get) ) case _: EqualNullSafe => // For null-safe equality, we use a very restrictive condition to evaluate its overlap. // If null values exists, we set it to partial overlap. (((maxLeft < minRight) || (maxRight < minLeft)) && allNotNull, (minLeft == minRight) && (maxLeft == maxRight) && allNotNull - && (colStatLeft.distinctCount == colStatRight.distinctCount) + && (colStatLeft.distinctCount.get == colStatRight.distinctCount.get) ) } @@ -731,9 +740,9 @@ case class FilterEstimation(plan: Filter) extends Logging { if (update) { // Need to adjust new min/max after the filter condition is applied - val ndvLeft = BigDecimal(colStatLeft.distinctCount) + val ndvLeft = BigDecimal(colStatLeft.distinctCount.get) val newNdvLeft = ceil(ndvLeft * percent) - val ndvRight = BigDecimal(colStatRight.distinctCount) + val ndvRight = BigDecimal(colStatRight.distinctCount.get) val newNdvRight = ceil(ndvRight * percent) var newMaxLeft = colStatLeft.max @@ -817,10 +826,10 @@ case class FilterEstimation(plan: Filter) extends Logging { } } - val newStatsLeft = colStatLeft.copy(distinctCount = newNdvLeft, min = newMinLeft, + val newStatsLeft = colStatLeft.copy(distinctCount = Some(newNdvLeft), min = newMinLeft, max = newMaxLeft) colStatsMap(attrLeft) = newStatsLeft - val newStatsRight = colStatRight.copy(distinctCount = newNdvRight, min = newMinRight, + val newStatsRight = colStatRight.copy(distinctCount = Some(newNdvRight), min = newMinRight, max = newMaxRight) colStatsMap(attrRight) = newStatsRight } @@ -849,17 +858,35 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) { def contains(a: Attribute): Boolean = updatedMap.contains(a.exprId) || originalMap.contains(a) /** - * Gets column stat for the given attribute. Prefer the column stat in updatedMap than that in - * originalMap, because updatedMap has the latest (updated) column stats. + * Gets an Option of column stat for the given attribute. + * Prefer the column stat in updatedMap than that in originalMap, + * because updatedMap has the latest (updated) column stats. */ - def apply(a: Attribute): ColumnStat = { + def get(a: Attribute): Option[ColumnStat] = { if (updatedMap.contains(a.exprId)) { - updatedMap(a.exprId)._2 + updatedMap.get(a.exprId).map(_._2) } else { - originalMap(a) + originalMap.get(a) } } + def hasCountStats(a: Attribute): Boolean = + get(a).map(_.hasCountStats).getOrElse(false) + + def hasDistinctCount(a: Attribute): Boolean = + get(a).map(_.distinctCount.isDefined).getOrElse(false) + + def hasMinMaxStats(a: Attribute): Boolean = + get(a).map(_.hasCountStats).getOrElse(false) + + /** + * Gets column stat for the given attribute. Prefer the column stat in updatedMap than that in + * originalMap, because updatedMap has the latest (updated) column stats. + */ + def apply(a: Attribute): ColumnStat = { + get(a).get + } + /** Updates column stats in updatedMap. */ def update(a: Attribute, stats: ColumnStat): Unit = updatedMap.update(a.exprId, a -> stats) @@ -871,11 +898,14 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) { : AttributeMap[ColumnStat] = { val newColumnStats = originalMap.map { case (attr, oriColStat) => val colStat = updatedMap.get(attr.exprId).map(_._2).getOrElse(oriColStat) - val newNdv = if (colStat.distinctCount > 1) { + val newNdv = if (colStat.distinctCount.isEmpty) { + // No NDV in the original stats. + None + } else if (colStat.distinctCount.get > 1) { // Update ndv based on the overall filter selectivity: scale down ndv if the number of rows // decreases; otherwise keep it unchanged. - EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter, - newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount) + Some(EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter, + newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount.get)) } else { // no need to scale down since it is already down to 1 (for skewed distribution case) colStat.distinctCount diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index f0294a4246703..2543e38a92c0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -85,7 +85,8 @@ case class JoinEstimation(join: Join) extends Logging { // 3. Update statistics based on the output of join val inputAttrStats = AttributeMap( leftStats.attributeStats.toSeq ++ rightStats.attributeStats.toSeq) - val attributesWithStat = join.output.filter(a => inputAttrStats.contains(a)) + val attributesWithStat = join.output.filter(a => + inputAttrStats.get(a).map(_.hasCountStats).getOrElse(false)) val (fromLeft, fromRight) = attributesWithStat.partition(join.left.outputSet.contains(_)) val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) { @@ -106,10 +107,10 @@ case class JoinEstimation(join: Join) extends Logging { case FullOuter => fromLeft.map { a => val oriColStat = inputAttrStats(a) - (a, oriColStat.copy(nullCount = oriColStat.nullCount + rightRows)) + (a, oriColStat.copy(nullCount = Some(oriColStat.nullCount.get + rightRows))) } ++ fromRight.map { a => val oriColStat = inputAttrStats(a) - (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows)) + (a, oriColStat.copy(nullCount = Some(oriColStat.nullCount.get + leftRows))) } case _ => assert(joinType == Inner || joinType == Cross) @@ -219,19 +220,27 @@ case class JoinEstimation(join: Join) extends Logging { private def computeByNdv( leftKey: AttributeReference, rightKey: AttributeReference, - newMin: Option[Any], - newMax: Option[Any]): (BigInt, ColumnStat) = { + min: Option[Any], + max: Option[Any]): (BigInt, ColumnStat) = { val leftKeyStat = leftStats.attributeStats(leftKey) val rightKeyStat = rightStats.attributeStats(rightKey) - val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount) + val maxNdv = leftKeyStat.distinctCount.get.max(rightKeyStat.distinctCount.get) // Compute cardinality by the basic formula. val card = BigDecimal(leftStats.rowCount.get * rightStats.rowCount.get) / BigDecimal(maxNdv) // Get the intersected column stat. - val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount) - val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) - val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 - val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) + val newNdv = Some(leftKeyStat.distinctCount.get.min(rightKeyStat.distinctCount.get)) + val newMaxLen = if (leftKeyStat.maxLen.isDefined && rightKeyStat.maxLen.isDefined) { + Some(math.min(leftKeyStat.maxLen.get, rightKeyStat.maxLen.get)) + } else { + None + } + val newAvgLen = if (leftKeyStat.avgLen.isDefined && rightKeyStat.avgLen.isDefined) { + Some((leftKeyStat.avgLen.get + rightKeyStat.avgLen.get) / 2) + } else { + None + } + val newStats = ColumnStat(newNdv, min, max, Some(0), newAvgLen, newMaxLen) (ceil(card), newStats) } @@ -267,9 +276,17 @@ case class JoinEstimation(join: Join) extends Logging { val leftKeyStat = leftStats.attributeStats(leftKey) val rightKeyStat = rightStats.attributeStats(rightKey) - val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) - val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 - val newStats = ColumnStat(ceil(totalNdv), newMin, newMax, 0, newAvgLen, newMaxLen) + val newMaxLen = if (leftKeyStat.maxLen.isDefined && rightKeyStat.maxLen.isDefined) { + Some(math.min(leftKeyStat.maxLen.get, rightKeyStat.maxLen.get)) + } else { + None + } + val newAvgLen = if (leftKeyStat.avgLen.isDefined && rightKeyStat.avgLen.isDefined) { + Some((leftKeyStat.avgLen.get + rightKeyStat.avgLen.get) / 2) + } else { + None + } + val newStats = ColumnStat(Some(ceil(totalNdv)), newMin, newMax, Some(0), newAvgLen, newMaxLen) (ceil(card), newStats) } @@ -292,10 +309,14 @@ case class JoinEstimation(join: Join) extends Logging { } else { val oldColStat = oldAttrStats(a) val oldNdv = oldColStat.distinctCount - val newNdv = if (join.left.outputSet.contains(a)) { - updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv) + val newNdv = if (oldNdv.isDefined) { + Some(if (join.left.outputSet.contains(a)) { + updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv.get) + } else { + updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv.get) + }) } else { - updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv) + None } val newColStat = oldColStat.copy(distinctCount = newNdv) // TODO: support nullCount updates for specific outer joins @@ -313,7 +334,7 @@ case class JoinEstimation(join: Join) extends Logging { // Note: join keys from EqualNullSafe also fall into this case (Coalesce), consider to // support it in the future by using `nullCount` in column stats. case (lk: AttributeReference, rk: AttributeReference) - if columnStatsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) + if columnStatsWithCountsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 2fb587d50a4cb..565b0a10154a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -62,24 +62,15 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { } } - /** Set up tables and columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("t1.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t1.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t2.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t3.v-1-100") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t4.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t4.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t5.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t5.v-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4) + attr("t1.k-1-2") -> rangeColumnStat(2, 0), + attr("t1.v-1-10") -> rangeColumnStat(10, 0), + attr("t2.k-1-5") -> rangeColumnStat(5, 0), + attr("t3.v-1-100") -> rangeColumnStat(100, 0), + attr("t4.k-1-2") -> rangeColumnStat(2, 0), + attr("t4.v-1-10") -> rangeColumnStat(10, 0), + attr("t5.k-1-5") -> rangeColumnStat(5, 0), + attr("t5.v-1-5") -> rangeColumnStat(5, 0) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala index ada6e2a43ea0f..d4d23ad69b2c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala @@ -68,88 +68,56 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( // F1 (fact table) - attr("f1_fk1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk3") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk1") -> rangeColumnStat(100, 0), + attr("f1_fk2") -> rangeColumnStat(100, 0), + attr("f1_fk3") -> rangeColumnStat(100, 0), + attr("f1_c1") -> rangeColumnStat(100, 0), + attr("f1_c2") -> rangeColumnStat(100, 0), // D1 (dimension) - attr("d1_pk") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c2") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c3") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_pk") -> rangeColumnStat(100, 0), + attr("d1_c2") -> rangeColumnStat(50, 0), + attr("d1_c3") -> rangeColumnStat(50, 0), // D2 (dimension) - attr("d2_pk") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_pk") -> rangeColumnStat(20, 0), + attr("d2_c2") -> rangeColumnStat(10, 0), + attr("d2_c3") -> rangeColumnStat(10, 0), // D3 (dimension) - attr("d3_pk") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_pk") -> rangeColumnStat(10, 0), + attr("d3_c2") -> rangeColumnStat(5, 0), + attr("d3_c3") -> rangeColumnStat(5, 0), // T1 (regular table i.e. outside star) - attr("t1_c1") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t1_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t1_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t1_c1") -> rangeColumnStat(20, 1), + attr("t1_c2") -> rangeColumnStat(10, 1), + attr("t1_c3") -> rangeColumnStat(10, 1), // T2 (regular table) - attr("t2_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t2_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t2_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t2_c1") -> rangeColumnStat(5, 1), + attr("t2_c2") -> rangeColumnStat(5, 1), + attr("t2_c3") -> rangeColumnStat(5, 1), // T3 (regular table) - attr("t3_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t3_c1") -> rangeColumnStat(5, 1), + attr("t3_c2") -> rangeColumnStat(5, 1), + attr("t3_c3") -> rangeColumnStat(5, 1), // T4 (regular table) - attr("t4_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t4_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t4_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t4_c1") -> rangeColumnStat(5, 1), + attr("t4_c2") -> rangeColumnStat(5, 1), + attr("t4_c3") -> rangeColumnStat(5, 1), // T5 (regular table) - attr("t5_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t5_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t5_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t5_c1") -> rangeColumnStat(5, 1), + attr("t5_c2") -> rangeColumnStat(5, 1), + attr("t5_c3") -> rangeColumnStat(5, 1), // T6 (regular table) - attr("t6_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t6_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t6_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4) + attr("t6_c1") -> rangeColumnStat(5, 1), + attr("t6_c2") -> rangeColumnStat(5, 1), + attr("t6_c3") -> rangeColumnStat(5, 1) )) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala index 777c5637201ed..4e0883e91e84a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -70,59 +70,40 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { // Tables' cardinality: f1 > d3 > d1 > d2 > s3 private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( // F1 - attr("f1_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk1") -> rangeColumnStat(3, 0), + attr("f1_fk2") -> rangeColumnStat(3, 0), + attr("f1_fk3") -> rangeColumnStat(4, 0), + attr("f1_c4") -> rangeColumnStat(4, 0), // D1 - attr("d1_pk1") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_pk1") -> rangeColumnStat(4, 0), + attr("d1_c2") -> rangeColumnStat(3, 0), + attr("d1_c3") -> rangeColumnStat(4, 0), + attr("d1_c4") -> ColumnStat(distinctCount = Some(2), min = Some("2"), max = Some("3"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // D2 - attr("d2_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("d2_pk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c3") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c2") -> ColumnStat(distinctCount = Some(3), min = Some("1"), max = Some("3"), + nullCount = Some(1), avgLen = Some(4), maxLen = Some(4)), + attr("d2_pk1") -> rangeColumnStat(3, 0), + attr("d2_c3") -> rangeColumnStat(3, 0), + attr("d2_c4") -> ColumnStat(distinctCount = Some(2), min = Some("3"), max = Some("4"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // D3 - attr("d3_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_pk1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_fk1") -> rangeColumnStat(3, 0), + attr("d3_c2") -> rangeColumnStat(3, 0), + attr("d3_pk1") -> rangeColumnStat(5, 0), + attr("d3_c4") -> ColumnStat(distinctCount = Some(2), min = Some("2"), max = Some("3"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // S3 - attr("s3_pk1") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c2") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c3") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_pk1") -> rangeColumnStat(2, 0), + attr("s3_c2") -> rangeColumnStat(1, 0), + attr("s3_c3") -> rangeColumnStat(1, 0), + attr("s3_c4") -> ColumnStat(distinctCount = Some(2), min = Some("3"), max = Some("4"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // F11 - attr("f11_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4) + attr("f11_fk1") -> rangeColumnStat(3, 0), + attr("f11_fk2") -> rangeColumnStat(3, 0), + attr("f11_fk3") -> rangeColumnStat(4, 0), + attr("f11_c4") -> rangeColumnStat(4, 0) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index 23f95a6cc2ac2..8213d568fe85e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -29,16 +29,16 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { /** Columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key12") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key22") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key31") -> ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, - avgLen = 4, maxLen = 4) + attr("key11") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key12") -> ColumnStat(distinctCount = Some(4), min = Some(10), max = Some(40), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key21") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key22") -> ColumnStat(distinctCount = Some(2), min = Some(10), max = Some(20), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key31") -> ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -63,8 +63,8 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { tableRowCount = 6, groupByColumns = Seq("key21", "key22"), // Row count = product of ndv - expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount * nameToColInfo("key22")._2 - .distinctCount) + expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount.get * + nameToColInfo("key22")._2.distinctCount.get) } test("empty group-by column") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 7d532ff343178..953094cb0dd52 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.types.IntegerType class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { val attribute = attr("key") - val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStat = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val plan = StatsTestPlan( outputList = Seq(attribute), @@ -116,13 +116,17 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { sizeInBytes = 40, rowCount = Some(10), attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4)))) + AttributeReference("c1", IntegerType)() -> ColumnStat(distinctCount = Some(10), + min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))))) val expectedCboStats = Statistics( sizeInBytes = 4, rowCount = Some(1), attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4)))) + AttributeReference("c1", IntegerType)() -> ColumnStat(distinctCount = Some(10), + min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))))) val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) checkStats( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 2b1fe987a7960..43440d51dede6 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -37,59 +37,61 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 val attrInt = AttributeReference("cint", IntegerType)() - val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cbool has only 2 distinct values val attrBool = AttributeReference("cbool", BooleanType)() - val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1) + val colStatBool = ColumnStat(distinctCount = Some(2), min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)) // column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10")) val attrDate = AttributeReference("cdate", DateType)() - val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatDate = ColumnStat(distinctCount = Some(10), + min = Some(dMin), max = Some(dMax), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. val decMin = Decimal("0.200000000000000000") val decMax = Decimal("0.800000000000000000") val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() - val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), - nullCount = 0, avgLen = 8, maxLen = 8) + val colStatDecimal = ColumnStat(distinctCount = Some(4), + min = Some(decMin), max = Some(decMax), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) // column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 val attrDouble = AttributeReference("cdouble", DoubleType)() - val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), - nullCount = 0, avgLen = 8, maxLen = 8) + val colStatDouble = ColumnStat(distinctCount = Some(10), min = Some(1.0), max = Some(10.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) // column cstring has 10 String values: // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" val attrString = AttributeReference("cstring", StringType)() - val colStatString = ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2) + val colStatString = ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)) // column cint2 has values: 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 // Hence, distinctCount:10, min:7, max:16, nullCount:0, avgLen:4, maxLen:4 // This column is created to test "cint < cint2 val attrInt2 = AttributeReference("cint2", IntegerType)() - val colStatInt2 = ColumnStat(distinctCount = 10, min = Some(7), max = Some(16), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt2 = ColumnStat(distinctCount = Some(10), min = Some(7), max = Some(16), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cint3 has values: 30, 31, 32, 33, 34, 35, 36, 37, 38, 39 // Hence, distinctCount:10, min:30, max:39, nullCount:0, avgLen:4, maxLen:4 // This column is created to test "cint = cint3 without overlap at all. val attrInt3 = AttributeReference("cint3", IntegerType)() - val colStatInt3 = ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt3 = ColumnStat(distinctCount = Some(10), min = Some(30), max = Some(39), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cint4 has values in the range from 1 to 10 // distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 // This column is created to test complete overlap val attrInt4 = AttributeReference("cint4", IntegerType)() - val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt4 = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cintHgm has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 with histogram. // Note that cintHgm has an even distribution with histogram information built. @@ -98,8 +100,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val hgmInt = Histogram(2.0, Array(HistogramBin(1.0, 2.0, 2), HistogramBin(2.0, 4.0, 2), HistogramBin(4.0, 6.0, 2), HistogramBin(6.0, 8.0, 2), HistogramBin(8.0, 10.0, 2))) - val colStatIntHgm = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt)) + val colStatIntHgm = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt)) // column cintSkewHgm has values: 1, 4, 4, 5, 5, 5, 5, 6, 6, 10 with histogram. // Note that cintSkewHgm has a skewed distribution with histogram information built. @@ -108,8 +110,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val hgmIntSkew = Histogram(2.0, Array(HistogramBin(1.0, 4.0, 2), HistogramBin(4.0, 5.0, 2), HistogramBin(5.0, 5.0, 1), HistogramBin(5.0, 6.0, 2), HistogramBin(6.0, 10.0, 2))) - val colStatIntSkewHgm = ColumnStat(distinctCount = 5, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew)) + val colStatIntSkewHgm = ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew)) val attributeMap = AttributeMap(Seq( attrInt -> colStatInt, @@ -172,7 +174,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 3)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(3))), expectedRowCount = 3) } @@ -180,7 +182,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 8)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(8))), expectedRowCount = 8) } @@ -196,23 +198,23 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrInt, Literal(3)), Not(Literal(null, IntegerType)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 8)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(8))), expectedRowCount = 8) } test("cint = 2") { validateEstimatedStats( Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(1), min = Some(2), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } test("cint <=> 2") { validateEstimatedStats( Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(1), min = Some(2), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } @@ -227,8 +229,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint < 3") { validateEstimatedStats( Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -243,16 +245,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint <= 3") { validateEstimatedStats( Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } test("cint > 6") { validateEstimatedStats( Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 5) } @@ -267,8 +269,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint >= 6") { validateEstimatedStats( Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 5) } @@ -282,8 +284,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint IS NOT NULL") { validateEstimatedStats( Filter(IsNotNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -301,8 +303,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -310,7 +312,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 2)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(2))), expectedRowCount = 2) } @@ -318,7 +320,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 6)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(6))), expectedRowCount = 6) } @@ -326,7 +328,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 5)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(5))), expectedRowCount = 5) } @@ -342,47 +344,47 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 9), - attrString -> colStatString.copy(distinctCount = 9)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(9)), + attrString -> colStatString.copy(distinctCount = Some(9))), expectedRowCount = 9) } test("cint IN (3, 4, 5)") { validateEstimatedStats( Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(3), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 7)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cbool IN (true)") { validateEstimatedStats( Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } test("cbool = true") { validateEstimatedStats( Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } test("cbool > false") { validateEstimatedStats( Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } @@ -391,18 +393,21 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(EqualTo(attrDate, Literal(d20170102, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(1), + min = Some(d20170102), max = Some(d20170102), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } test("cdate < cast('2017-01-03' AS DATE)") { + val d20170101 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03")) validateEstimatedStats( Filter(LessThan(attrDate, Literal(d20170103, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(dMin), max = Some(d20170103), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(3), + min = Some(d20170101), max = Some(d20170103), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -414,8 +419,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(3), + min = Some(d20170103), max = Some(d20170105), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -424,42 +430,45 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(EqualTo(attrDecimal, Literal(dec_0_40)), childStatsTestPlan(Seq(attrDecimal), 4L)), - Seq(attrDecimal -> ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDecimal -> ColumnStat(distinctCount = Some(1), + min = Some(dec_0_40), max = Some(dec_0_40), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 1) } test("cdecimal < 0.60 ") { + val dec_0_20 = Decimal("0.200000000000000000") val dec_0_60 = Decimal("0.600000000000000000") validateEstimatedStats( Filter(LessThan(attrDecimal, Literal(dec_0_60)), childStatsTestPlan(Seq(attrDecimal), 4L)), - Seq(attrDecimal -> ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDecimal -> ColumnStat(distinctCount = Some(3), + min = Some(dec_0_20), max = Some(dec_0_60), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 3) } test("cdouble < 3.0") { validateEstimatedStats( Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)), - Seq(attrDouble -> ColumnStat(distinctCount = 3, min = Some(1.0), max = Some(3.0), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDouble -> ColumnStat(distinctCount = Some(3), min = Some(1.0), max = Some(3.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 3) } test("cstring = 'A2'") { validateEstimatedStats( Filter(EqualTo(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), - Seq(attrString -> ColumnStat(distinctCount = 1, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2)), + Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), expectedRowCount = 1) } test("cstring < 'A2' - unsupported condition") { validateEstimatedStats( Filter(LessThan(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), - Seq(attrString -> ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2)), + Seq(attrString -> ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), expectedRowCount = 10) } @@ -468,8 +477,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // valid values in IN clause is greater than the number of distinct values for a given column. // For example, column has only 2 distinct values 1 and 6. // The predicate is: column IN (1, 2, 3, 4, 5). - val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4) + val cornerChildColStatInt = ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val cornerChildStatsTestplan = StatsTestPlan( outputList = Seq(attrInt), rowCount = 2L, @@ -477,16 +487,17 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) validateEstimatedStats( Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), - Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 2) } // This is a limitation test. We should remove it after the limitation is removed. test("don't estimate IsNull or IsNotNull if the child is a non-leaf node") { val attrIntLargerRange = AttributeReference("c1", IntegerType)() - val colStatIntLargerRange = ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 10, avgLen = 4, maxLen = 4) + val colStatIntLargerRange = ColumnStat(distinctCount = Some(20), + min = Some(1), max = Some(20), + nullCount = Some(10), avgLen = Some(4), maxLen = Some(4)) val smallerTable = childStatsTestPlan(Seq(attrInt), 10L) val largerTable = StatsTestPlan( outputList = Seq(attrIntLargerRange), @@ -508,10 +519,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -519,10 +530,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -530,10 +541,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(LessThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(16), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(16), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -541,10 +552,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // complete overlap case validateEstimatedStats( Filter(EqualTo(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -552,10 +563,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -571,10 +582,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // all table records qualify. validateEstimatedStats( Filter(LessThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt3 -> ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt3 -> ColumnStat(distinctCount = Some(10), min = Some(30), max = Some(39), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -592,11 +603,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt, attrInt4, attrString), 10L)), Seq( - attrInt -> ColumnStat(distinctCount = 5, min = Some(3), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4), - attrString -> colStatString.copy(distinctCount = 5)), + attrInt -> ColumnStat(distinctCount = Some(5), min = Some(3), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrString -> colStatString.copy(distinctCount = Some(5))), expectedRowCount = 5) } @@ -606,15 +617,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrIntHgm, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 7)), + Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cintHgm = 5") { validateEstimatedStats( Filter(EqualTo(attrIntHgm, Literal(5)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(1), min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 1) } @@ -629,8 +640,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm < 3") { validateEstimatedStats( Filter(LessThan(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 3) } @@ -645,16 +656,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm <= 3") { validateEstimatedStats( Filter(LessThanOrEqual(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 3) } test("cintHgm > 6") { validateEstimatedStats( Filter(GreaterThan(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(4), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 4) } @@ -669,8 +680,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm >= 6") { validateEstimatedStats( Filter(GreaterThanOrEqual(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 5) } @@ -679,8 +690,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Literal(3)), LessThanOrEqual(attrIntHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 4) } @@ -688,7 +699,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrIntHgm, Literal(3)), EqualTo(attrIntHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 3)), + Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = Some(3))), expectedRowCount = 3) } @@ -698,15 +709,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrIntSkewHgm, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 5)), + Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = Some(5))), expectedRowCount = 9) } test("cintSkewHgm = 5") { validateEstimatedStats( Filter(EqualTo(attrIntSkewHgm, Literal(5)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 4) } @@ -721,8 +732,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintSkewHgm < 3") { validateEstimatedStats( Filter(LessThan(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } @@ -738,16 +749,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(LessThanOrEqual(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } test("cintSkewHgm > 6") { validateEstimatedStats( Filter(GreaterThan(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } @@ -764,8 +775,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(GreaterThanOrEqual(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 2, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(2), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 3) } @@ -774,8 +785,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Literal(3)), LessThanOrEqual(attrIntSkewHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 8) } @@ -783,7 +794,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrIntSkewHgm, Literal(3)), EqualTo(attrIntSkewHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 2)), + Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = Some(2))), expectedRowCount = 3) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 26139d85d25fb..12c0a7be21292 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -33,16 +33,16 @@ class JoinEstimationSuite extends StatsEstimationTestBase { /** Set up tables and its columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("key-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-5-9") -> ColumnStat(distinctCount = 5, min = Some(5), max = Some(9), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-2-4") -> ColumnStat(distinctCount = 3, min = Some(2), max = Some(4), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-2-3") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + attr("key-1-5") -> ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-5-9") -> ColumnStat(distinctCount = Some(5), min = Some(5), max = Some(9), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-1-2") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-2-4") -> ColumnStat(distinctCount = Some(3), min = Some(2), max = Some(4), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-2-3") -> ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -70,8 +70,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { private def estimateByHistogram( leftHistogram: Histogram, rightHistogram: Histogram, - expectedMin: Double, - expectedMax: Double, + expectedMin: Any, + expectedMax: Any, expectedNdv: Long, expectedRows: Long): Unit = { val col1 = attr("key1") @@ -86,9 +86,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = Some(expectedRows), attributeStats = AttributeMap(Seq( col1 -> c1.stats.attributeStats(col1).copy( - distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)), + distinctCount = Some(expectedNdv), + min = Some(expectedMin), max = Some(expectedMax)), col2 -> c2.stats.attributeStats(col2).copy( - distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)))) + distinctCount = Some(expectedNdv), + min = Some(expectedMin), max = Some(expectedMax)))) ) // Join order should not affect estimation result. @@ -100,9 +102,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase { private def generateJoinChild( col: Attribute, histogram: Histogram, - expectedMin: Double, - expectedMax: Double): LogicalPlan = { - val colStat = inferColumnStat(histogram) + expectedMin: Any, + expectedMax: Any): LogicalPlan = { + val colStat = inferColumnStat(histogram, expectedMin, expectedMax) StatsTestPlan( outputList = Seq(col), rowCount = (histogram.height * histogram.bins.length).toLong, @@ -110,7 +112,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { } /** Column statistics should be consistent with histograms in tests. */ - private def inferColumnStat(histogram: Histogram): ColumnStat = { + private def inferColumnStat( + histogram: Histogram, + expectedMin: Any, + expectedMax: Any): ColumnStat = { + var ndv = 0L for (i <- histogram.bins.indices) { val bin = histogram.bins(i) @@ -118,8 +124,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase { ndv += bin.ndv } } - ColumnStat(distinctCount = ndv, min = Some(histogram.bins.head.lo), - max = Some(histogram.bins.last.hi), nullCount = 0, avgLen = 4, maxLen = 4, + ColumnStat(distinctCount = Some(ndv), + min = Some(expectedMin), max = Some(expectedMax), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(histogram)) } @@ -343,10 +350,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = Some(5 + 3), attributeStats = AttributeMap( // Update null count in column stats. - Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = 3), - nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3), - nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5), - nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5)))) + Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = Some(3)), + nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = Some(3)), + nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = Some(5)), + nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = Some(5))))) assert(join.stats == expectedStats) } @@ -356,11 +363,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { val join = Join(table1, table2, Inner, Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2")))) // Update column stats for equi-join keys (key-1-5 and key-1-2). - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // Update column stat for other column if #outputRow / #sideRow < 1 (key-5-9), or keep it // unchanged (key-2-4). - val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = 5 * 3 / 5) + val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = Some(5 * 3 / 5)) val expectedStats = Statistics( sizeInBytes = 3 * (8 + 4 * 4), @@ -379,10 +386,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase { EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))) // Update column stats for join keys. - val joinedColStat1 = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4) - val joinedColStat2 = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat1 = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) + val joinedColStat2 = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -398,8 +405,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table3, table2, LeftOuter, Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4")))) - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -416,8 +423,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table2, table3, RightOuter, Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))) - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -466,30 +473,40 @@ class JoinEstimationSuite extends StatsEstimationTestBase { val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08")) val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) mutable.LinkedHashMap[Attribute, ColumnStat]( - AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1, - min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1, - min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1, - min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2), - AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1, - min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1, - min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16), - AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 1, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 1, - min = Some(date), max = Some(date), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 1, - min = Some(timestamp), max = Some(timestamp), nullCount = 0, avgLen = 8, maxLen = 8) + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(1), + min = Some(false), max = Some(false), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.toByte), max = Some(1.toByte), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.toShort), max = Some(1.toShort), + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1), max = Some(1), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1L), max = Some(1L), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.0), max = Some(1.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.0f), max = Some(1.0f), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat( + distinctCount = Some(1), min = Some(dec), max = Some(dec), + nullCount = Some(0), avgLen = Some(16), maxLen = Some(16)), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = Some(1), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = Some(1), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = Some(1), + min = Some(date), max = Some(date), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(1), + min = Some(timestamp), max = Some(timestamp), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) ) } @@ -520,7 +537,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { test("join with null column") { val (nullColumn, nullColStat) = (attr("cnull"), - ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 1, avgLen = 4, maxLen = 4)) + ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(1), avgLen = Some(4), maxLen = Some(4))) val nullTable = StatsTestPlan( outputList = Seq(nullColumn), rowCount = 1, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index cda54fa9d64f4..dcb37017329fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -28,10 +28,10 @@ import org.apache.spark.sql.types._ class ProjectEstimationSuite extends StatsEstimationTestBase { test("project with alias") { - val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 2, min = Some(1), - max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4)) - val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = 1, min = Some(10), - max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4)) + val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(2), min = Some(1), + max = Some(2), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) + val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = Some(1), min = Some(10), + max = Some(10), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) val child = StatsTestPlan( outputList = Seq(ar1, ar2), @@ -49,8 +49,8 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { } test("project on empty table") { - val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 0, min = None, max = None, - nullCount = 0, avgLen = 4, maxLen = 4)) + val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) val child = StatsTestPlan( outputList = Seq(ar1), rowCount = 0, @@ -71,30 +71,40 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02")) val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2, - min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2, - min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2, - min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2), - AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2, - min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2, - min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16), - AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 2, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 2, - min = Some(d1), max = Some(d2), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 2, - min = Some(t1), max = Some(t2), nullCount = 0, avgLen = 8, maxLen = 8) + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(2), + min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(4), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1.0), max = Some(6.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1.0), max = Some(7.0), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat( + distinctCount = Some(2), min = Some(dec1), max = Some(dec2), + nullCount = Some(0), avgLen = Some(16), maxLen = Some(16)), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = Some(2), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = Some(2), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = Some(2), + min = Some(d1), max = Some(d2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(2), + min = Some(t1), max = Some(t2), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) )) val columnSizes: Map[Attribute, Long] = columnInfo.map(kv => (kv._1, getColSize(kv._1, kv._2))) val child = StatsTestPlan( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index 31dea2e3e7f1d..9dceca59f5b87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -42,8 +42,8 @@ trait StatsEstimationTestBase extends SparkFunSuite { def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match { // For UTF8String: base + offset + numBytes - case StringType => colStat.avgLen + 8 + 4 - case _ => colStat.avgLen + case StringType => colStat.avgLen.getOrElse(attribute.dataType.defaultSize.toLong) + 8 + 4 + case _ => colStat.avgLen.getOrElse(attribute.dataType.defaultSize) } def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)() @@ -54,6 +54,12 @@ trait StatsEstimationTestBase extends SparkFunSuite { val nameToAttr: Map[String, Attribute] = plan.output.map(a => (a.name, a)).toMap AttributeMap(colStats.map(kv => nameToAttr(kv._1) -> kv._2)) } + + /** Get a test ColumnStat with given distinctCount and nullCount */ + def rangeColumnStat(distinctCount: Int, nullCount: Int): ColumnStat = + ColumnStat(distinctCount = Some(distinctCount), + min = Some(1), max = Some(distinctCount), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 1122522ccb4cb..640e01336aa75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.command import scala.collection.mutable import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTableType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ /** @@ -64,12 +66,12 @@ case class AnalyzeColumnCommand( /** * Compute stats for the given columns. - * @return (row count, map from column name to ColumnStats) + * @return (row count, map from column name to CatalogColumnStats) */ private def computeColumnStats( sparkSession: SparkSession, tableIdent: TableIdentifier, - columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { + columnNames: Seq[String]): (Long, Map[String, CatalogColumnStat]) = { val conf = sparkSession.sessionState.conf val relation = sparkSession.table(tableIdent).logicalPlan @@ -81,7 +83,7 @@ case class AnalyzeColumnCommand( // Make sure the column types are supported for stats gathering. attributesToAnalyze.foreach { attr => - if (!ColumnStat.supportsType(attr.dataType)) { + if (!supportsType(attr.dataType)) { throw new AnalysisException( s"Column ${attr.name} in table $tableIdent is of type ${attr.dataType}, " + "and Spark does not support statistics collection on this column type.") @@ -103,7 +105,7 @@ case class AnalyzeColumnCommand( // will be structs containing all column stats. // The layout of each struct follows the layout of the ColumnStats. val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStat.statExprs(_, conf, attributePercentiles)) + attributesToAnalyze.map(statExprs(_, conf, attributePercentiles)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation)) @@ -111,9 +113,9 @@ case class AnalyzeColumnCommand( val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => - // according to `ColumnStat.statExprs`, the stats struct always have 7 fields. - (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, - attributePercentiles.get(attr))) + // according to `statExprs`, the stats struct always have 7 fields. + (attr.name, rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, + attributePercentiles.get(attr)).toCatalogColumnStat(attr.name, attr.dataType)) }.toMap (rowCount, columnStats) } @@ -124,7 +126,7 @@ case class AnalyzeColumnCommand( sparkSession: SparkSession, relation: LogicalPlan): AttributeMap[ArrayData] = { val attrsToGenHistogram = if (conf.histogramEnabled) { - attributesToAnalyze.filter(a => ColumnStat.supportsHistogram(a.dataType)) + attributesToAnalyze.filter(a => supportsHistogram(a.dataType)) } else { Nil } @@ -154,4 +156,120 @@ case class AnalyzeColumnCommand( AttributeMap(attributePercentiles.toSeq) } + /** Returns true iff the we support gathering column statistics on column of the given type. */ + private def supportsType(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case BooleanType => true + case DateType => true + case TimestampType => true + case BinaryType | StringType => true + case _ => false + } + + /** Returns true iff the we support gathering histogram on column of the given type. */ + private def supportsHistogram(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case DateType => true + case TimestampType => true + case _ => false + } + + /** + * Constructs an expression to compute column statistics for a given column. + * + * The expression should create a single struct column with the following schema: + * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long, + * distinctCountsForIntervals: Array[Long] + * + * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and + * as a result should stay in sync with it. + */ + private def statExprs( + col: Attribute, + conf: SQLConf, + colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = { + def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => + expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } + }) + val one = Literal(1, LongType) + + // the approximate ndv (num distinct value) should never be larger than the number of rows + val numNonNulls = if (col.nullable) Count(col) else Count(one) + val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls)) + val numNulls = Subtract(Count(one), numNonNulls) + val defaultSize = Literal(col.dataType.defaultSize, LongType) + val nullArray = Literal(null, ArrayType(LongType)) + + def fixedLenTypeStruct: CreateNamedStruct = { + val genHistogram = + supportsHistogram(col.dataType) && colPercentiles.contains(col) + val intervalNdvsExpr = if (genHistogram) { + ApproxCountDistinctForIntervals(col, + Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError) + } else { + nullArray + } + // For fixed width types, avg size should be the same as max size. + struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls, + defaultSize, defaultSize, intervalNdvsExpr) + } + + col.dataType match { + case _: IntegralType => fixedLenTypeStruct + case _: DecimalType => fixedLenTypeStruct + case DoubleType | FloatType => fixedLenTypeStruct + case BooleanType => fixedLenTypeStruct + case DateType => fixedLenTypeStruct + case TimestampType => fixedLenTypeStruct + case BinaryType | StringType => + // For string and binary type, we don't compute min, max or histogram + val nullLit = Literal(null, col.dataType) + struct( + ndv, nullLit, nullLit, numNulls, + // Set avg/max size to default size if all the values are null or there is no value. + Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), + Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)), + nullArray) + case _ => + throw new AnalysisException("Analyzing column statistics is not supported for column " + + s"${col.name} of data type: ${col.dataType}.") + } + } + + /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */ + private def rowToColumnStat( + row: InternalRow, + attr: Attribute, + rowCount: Long, + percentiles: Option[ArrayData]): ColumnStat = { + // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins. + val cs = ColumnStat( + distinctCount = Option(BigInt(row.getLong(0))), + // for string/binary min/max, get should return null + min = Option(row.get(1, attr.dataType)), + max = Option(row.get(2, attr.dataType)), + nullCount = Option(BigInt(row.getLong(3))), + avgLen = Option(row.getLong(4)), + maxLen = Option(row.getLong(5)) + ) + if (row.isNullAt(6) || cs.nullCount.isEmpty) { + cs + } else { + val ndvs = row.getArray(6).toLongArray() + assert(percentiles.get.numElements() == ndvs.length + 1) + val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble) + // Construct equi-height histogram + val bins = ndvs.zipWithIndex.map { case (ndv, i) => + HistogramBin(endpoints(i), endpoints(i + 1), ndv) + } + val nonNullRows = rowCount - cs.nullCount.get + val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins) + cs.copy(histogram = Some(histogram)) + } + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index e400975f19708..44749190c79eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -695,10 +695,11 @@ case class DescribeColumnCommand( // Show column stats when EXTENDED or FORMATTED is specified. buffer += Row("min", cs.flatMap(_.min.map(_.toString)).getOrElse("NULL")) buffer += Row("max", cs.flatMap(_.max.map(_.toString)).getOrElse("NULL")) - buffer += Row("num_nulls", cs.map(_.nullCount.toString).getOrElse("NULL")) - buffer += Row("distinct_count", cs.map(_.distinctCount.toString).getOrElse("NULL")) - buffer += Row("avg_col_len", cs.map(_.avgLen.toString).getOrElse("NULL")) - buffer += Row("max_col_len", cs.map(_.maxLen.toString).getOrElse("NULL")) + buffer += Row("num_nulls", cs.flatMap(_.nullCount.map(_.toString)).getOrElse("NULL")) + buffer += Row("distinct_count", + cs.flatMap(_.distinctCount.map(_.toString)).getOrElse("NULL")) + buffer += Row("avg_col_len", cs.flatMap(_.avgLen.map(_.toString)).getOrElse("NULL")) + buffer += Row("max_col_len", cs.flatMap(_.maxLen.map(_.toString)).getOrElse("NULL")) val histDesc = for { c <- cs hist <- c.histogram diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index b11e798532056..ed4ea0231f1a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.collection.mutable import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -95,7 +96,8 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(fetchedStats2.get.sizeInBytes == 0) val expectedColStat = - "key" -> ColumnStat(0, None, None, 0, IntegerType.defaultSize, IntegerType.defaultSize) + "key" -> CatalogColumnStat(Some(0), None, None, Some(0), + Some(IntegerType.defaultSize), Some(IntegerType.defaultSize)) // There won't be histogram for empty column. Seq("true", "false").foreach { histogramEnabled => @@ -156,7 +158,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared Seq(stats, statsWithHgms).foreach { s => s.zip(df.schema).foreach { case ((k, v), field) => withClue(s"column $k with type ${field.dataType}") { - val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType)) + val roundtrip = CatalogColumnStat.fromMap("table_is_foo", field.name, v.toMap(k)) assert(roundtrip == Some(v)) } } @@ -187,7 +189,8 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared }.mkString(", ")) val expectedColStats = dataTypes.map { case (tpe, idx) => - (s"col$idx", ColumnStat(0, None, None, 1, tpe.defaultSize.toLong, tpe.defaultSize.toLong)) + (s"col$idx", CatalogColumnStat(Some(0), None, None, Some(1), + Some(tpe.defaultSize.toLong), Some(tpe.defaultSize.toLong))) } // There won't be histograms for null columns. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index 65ccc1915882f..bf4abb6e625c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -24,8 +24,8 @@ import scala.collection.mutable import scala.util.Random import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, HiveTableRelation} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, LogicalPlan} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTable, HiveTableRelation} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, HistogramSerializer, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -67,18 +67,21 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils /** A mapping from column to the stats collected. */ protected val stats = mutable.LinkedHashMap( - "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), - "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), - "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), - "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), - "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), - "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), - "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), - "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), - "cstring" -> ColumnStat(2, None, None, 1, 3, 3), - "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(d1Internal), Some(d2Internal), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(t1Internal), Some(t2Internal), 1, 8, 8) + "cbool" -> CatalogColumnStat(Some(2), Some("false"), Some("true"), Some(1), Some(1), Some(1)), + "cbyte" -> CatalogColumnStat(Some(2), Some("1"), Some("2"), Some(1), Some(1), Some(1)), + "cshort" -> CatalogColumnStat(Some(2), Some("1"), Some("3"), Some(1), Some(2), Some(2)), + "cint" -> CatalogColumnStat(Some(2), Some("1"), Some("4"), Some(1), Some(4), Some(4)), + "clong" -> CatalogColumnStat(Some(2), Some("1"), Some("5"), Some(1), Some(8), Some(8)), + "cdouble" -> CatalogColumnStat(Some(2), Some("1.0"), Some("6.0"), Some(1), Some(8), Some(8)), + "cfloat" -> CatalogColumnStat(Some(2), Some("1.0"), Some("7.0"), Some(1), Some(4), Some(4)), + "cdecimal" -> CatalogColumnStat(Some(2), Some(dec1.toString), Some(dec2.toString), Some(1), + Some(16), Some(16)), + "cstring" -> CatalogColumnStat(Some(2), None, None, Some(1), Some(3), Some(3)), + "cbinary" -> CatalogColumnStat(Some(2), None, None, Some(1), Some(3), Some(3)), + "cdate" -> CatalogColumnStat(Some(2), Some(d1.toString), Some(d2.toString), Some(1), Some(4), + Some(4)), + "ctimestamp" -> CatalogColumnStat(Some(2), Some(t1.toString), Some(t2.toString), Some(1), + Some(8), Some(8)) ) /** @@ -110,6 +113,110 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils colStats } + val expectedSerializedColStats = Map( + "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", + "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", + "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", + "spark.sql.statistics.colStats.cbinary.version" -> "1", + "spark.sql.statistics.colStats.cbool.avgLen" -> "1", + "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbool.max" -> "true", + "spark.sql.statistics.colStats.cbool.maxLen" -> "1", + "spark.sql.statistics.colStats.cbool.min" -> "false", + "spark.sql.statistics.colStats.cbool.nullCount" -> "1", + "spark.sql.statistics.colStats.cbool.version" -> "1", + "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", + "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbyte.max" -> "2", + "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", + "spark.sql.statistics.colStats.cbyte.min" -> "1", + "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", + "spark.sql.statistics.colStats.cbyte.version" -> "1", + "spark.sql.statistics.colStats.cdate.avgLen" -> "4", + "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", + "spark.sql.statistics.colStats.cdate.maxLen" -> "4", + "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", + "spark.sql.statistics.colStats.cdate.nullCount" -> "1", + "spark.sql.statistics.colStats.cdate.version" -> "1", + "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", + "spark.sql.statistics.colStats.cdecimal.version" -> "1", + "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", + "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdouble.max" -> "6.0", + "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", + "spark.sql.statistics.colStats.cdouble.min" -> "1.0", + "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", + "spark.sql.statistics.colStats.cdouble.version" -> "1", + "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", + "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", + "spark.sql.statistics.colStats.cfloat.max" -> "7.0", + "spark.sql.statistics.colStats.cfloat.maxLen" -> "4", + "spark.sql.statistics.colStats.cfloat.min" -> "1.0", + "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", + "spark.sql.statistics.colStats.cfloat.version" -> "1", + "spark.sql.statistics.colStats.cint.avgLen" -> "4", + "spark.sql.statistics.colStats.cint.distinctCount" -> "2", + "spark.sql.statistics.colStats.cint.max" -> "4", + "spark.sql.statistics.colStats.cint.maxLen" -> "4", + "spark.sql.statistics.colStats.cint.min" -> "1", + "spark.sql.statistics.colStats.cint.nullCount" -> "1", + "spark.sql.statistics.colStats.cint.version" -> "1", + "spark.sql.statistics.colStats.clong.avgLen" -> "8", + "spark.sql.statistics.colStats.clong.distinctCount" -> "2", + "spark.sql.statistics.colStats.clong.max" -> "5", + "spark.sql.statistics.colStats.clong.maxLen" -> "8", + "spark.sql.statistics.colStats.clong.min" -> "1", + "spark.sql.statistics.colStats.clong.nullCount" -> "1", + "spark.sql.statistics.colStats.clong.version" -> "1", + "spark.sql.statistics.colStats.cshort.avgLen" -> "2", + "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", + "spark.sql.statistics.colStats.cshort.max" -> "3", + "spark.sql.statistics.colStats.cshort.maxLen" -> "2", + "spark.sql.statistics.colStats.cshort.min" -> "1", + "spark.sql.statistics.colStats.cshort.nullCount" -> "1", + "spark.sql.statistics.colStats.cshort.version" -> "1", + "spark.sql.statistics.colStats.cstring.avgLen" -> "3", + "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", + "spark.sql.statistics.colStats.cstring.maxLen" -> "3", + "spark.sql.statistics.colStats.cstring.nullCount" -> "1", + "spark.sql.statistics.colStats.cstring.version" -> "1", + "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", + "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", + "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", + "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", + "spark.sql.statistics.colStats.ctimestamp.version" -> "1" + ) + + val expectedSerializedHistograms = Map( + "spark.sql.statistics.colStats.cbyte.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cbyte").histogram.get), + "spark.sql.statistics.colStats.cshort.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cshort").histogram.get), + "spark.sql.statistics.colStats.cint.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cint").histogram.get), + "spark.sql.statistics.colStats.clong.histogram" -> + HistogramSerializer.serialize(statsWithHgms("clong").histogram.get), + "spark.sql.statistics.colStats.cdouble.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdouble").histogram.get), + "spark.sql.statistics.colStats.cfloat.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cfloat").histogram.get), + "spark.sql.statistics.colStats.cdecimal.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdecimal").histogram.get), + "spark.sql.statistics.colStats.cdate.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdate").histogram.get), + "spark.sql.statistics.colStats.ctimestamp.histogram" -> + HistogramSerializer.serialize(statsWithHgms("ctimestamp").histogram.get) + ) + private val randomName = new Random(31) def getCatalogTable(tableName: String): CatalogTable = { @@ -151,7 +258,7 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils */ def checkColStats( df: DataFrame, - colStats: mutable.LinkedHashMap[String, ColumnStat]): Unit = { + colStats: mutable.LinkedHashMap[String, CatalogColumnStat]): Unit = { val tableName = "column_stats_test_" + randomName.nextInt(1000) withTable(tableName) { df.write.saveAsTable(tableName) @@ -161,14 +268,24 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils colStats.keys.mkString(", ")) // Validate statistics - val table = getCatalogTable(tableName) - assert(table.stats.isDefined) - assert(table.stats.get.colStats.size == colStats.size) - - colStats.foreach { case (k, v) => - withClue(s"column $k") { - assert(table.stats.get.colStats(k) == v) - } + validateColStats(tableName, colStats) + } + } + + /** + * Validate if the given catalog table has the provided statistics. + */ + def validateColStats( + tableName: String, + colStats: mutable.LinkedHashMap[String, CatalogColumnStat]): Unit = { + + val table = getCatalogTable(tableName) + assert(table.stats.isDefined) + assert(table.stats.get.colStats.size == colStats.size) + + colStats.foreach { case (k, v) => + withClue(s"column $k") { + assert(table.stats.get.colStats(k) == v) } } } @@ -215,12 +332,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils case catalogRel: HiveTableRelation => (catalogRel, catalogRel.tableMeta) case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) }.head - val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) + val emptyColStat = ColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)) + val emptyCatalogColStat = CatalogColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)) // Check catalog statistics assert(catalogTable.stats.isDefined) assert(catalogTable.stats.get.sizeInBytes == 0) assert(catalogTable.stats.get.rowCount == Some(0)) - assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) + assert(catalogTable.stats.get.colStats == Map("c1" -> emptyCatalogColStat)) // Check relation statistics withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 1ee1d57b8ebe1..28c340a176d91 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -663,14 +663,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat requireTableExists(db, table) val rawTable = getRawTable(db, table) - // For datasource tables and hive serde tables created by spark 2.1 or higher, - // the data schema is stored in the table properties. - val schema = restoreTableMetadata(rawTable).schema - // convert table statistics to properties so that we can persist them through hive client val statsProperties = if (stats.isDefined) { - statsToProperties(stats.get, schema) + statsToProperties(stats.get) } else { new mutable.HashMap[String, String]() } @@ -1028,9 +1024,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat currentFullPath } - private def statsToProperties( - stats: CatalogStatistics, - schema: StructType): Map[String, String] = { + private def statsToProperties(stats: CatalogStatistics): Map[String, String] = { val statsProperties = new mutable.HashMap[String, String]() statsProperties += STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString() @@ -1038,11 +1032,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } - val colNameTypeMap: Map[String, DataType] = - schema.fields.map(f => (f.name, f.dataType)).toMap stats.colStats.foreach { case (colName, colStat) => - colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => - statsProperties += (columnStatKeyPropName(colName, k) -> v) + colStat.toMap(colName).foreach { case (k, v) => + // Fully qualified name used in table properties for a particular column stat. + // For example, for column "mycol", and "min" stat, this should return + // "spark.sql.statistics.colStats.mycol.min". + statsProperties += (STATISTICS_COL_STATS_PREFIX + k -> v) } } @@ -1058,23 +1053,20 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (statsProps.isEmpty) { None } else { + val colStats = new mutable.HashMap[String, CatalogColumnStat] + val colStatsProps = properties.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)).map { + case (k, v) => k.drop(STATISTICS_COL_STATS_PREFIX.length) -> v + } - val colStats = new mutable.HashMap[String, ColumnStat] - - // For each column, recover its column stats. Note that this is currently a O(n^2) operation, - // but given the number of columns it usually not enormous, this is probably OK as a start. - // If we want to map this a linear operation, we'd need a stronger contract between the - // naming convention used for serialization. - schema.foreach { field => - if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) { - // If "version" field is defined, then the column stat is defined. - val keyPrefix = columnStatKeyPropName(field.name, "") - val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) => - (k.drop(keyPrefix.length), v) - } - ColumnStat.fromMap(table, field, colStatMap).foreach { cs => - colStats += field.name -> cs - } + // Find all the column names by matching the KEY_VERSION properties for them. + colStatsProps.keys.filter { + k => k.endsWith(CatalogColumnStat.KEY_VERSION) + }.map { k => + k.dropRight(CatalogColumnStat.KEY_VERSION.length + 1) + }.foreach { fieldName => + // and for each, create a column stat. + CatalogColumnStat.fromMap(table, fieldName, colStatsProps).foreach { cs => + colStats += fieldName -> cs } } @@ -1093,14 +1085,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val rawTable = getRawTable(db, table) - // For datasource tables and hive serde tables created by spark 2.1 or higher, - // the data schema is stored in the table properties. - val schema = restoreTableMetadata(rawTable).schema - // convert partition statistics to properties so that we can persist them through hive api val withStatsProps = lowerCasedParts.map { p => if (p.stats.isDefined) { - val statsProperties = statsToProperties(p.stats.get, schema) + val statsProperties = statsToProperties(p.stats.get) p.copy(parameters = p.parameters ++ statsProperties) } else { p @@ -1310,15 +1298,6 @@ object HiveExternalCatalog { val EMPTY_DATA_SCHEMA = new StructType() .add("col", "array", nullable = true, comment = "from deserializer") - /** - * Returns the fully qualified name used in table properties for a particular column stat. - * For example, for column "mycol", and "min" stat, this should return - * "spark.sql.statistics.colStats.mycol.min". - */ - private def columnStatKeyPropName(columnName: String, statKey: String): String = { - STATISTICS_COL_STATS_PREFIX + columnName + "." + statKey - } - // A persisted data source table always store its schema in the catalog. private def getSchemaFromTableProperties(metadata: CatalogTable): StructType = { val errorMessage = "Could not read schema from the hive metastore because it is corrupted." diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 3af8af0814bb4..61cec82984795 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, HiveTableRelation} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, HiveTableRelation} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, HistogramBin, HistogramSerializer} import org.apache.spark.sql.catalyst.util.{DateTimeUtils, StringUtils} import org.apache.spark.sql.execution.command.DDLUtils @@ -177,8 +177,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats0 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)) assert(fetchedStats0.get.colStats == Map( - "a" -> ColumnStat(2, Some(1), Some(2), 0, 4, 4), - "b" -> ColumnStat(1, Some(1), Some(1), 0, 4, 4))) + "a" -> CatalogColumnStat(Some(2), Some("1"), Some("2"), Some(0), Some(4), Some(4)), + "b" -> CatalogColumnStat(Some(1), Some("1"), Some("1"), Some(0), Some(4), Some(4)))) } } @@ -208,8 +208,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats1.colStats == Map( - "C1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4))) + "C1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))) } } @@ -596,7 +596,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") val fetchedStats0 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) - assert(fetchedStats0.get.colStats == Map("c1" -> ColumnStat(0, None, None, 0, 4, 4))) + assert(fetchedStats0.get.colStats == + Map("c1" -> CatalogColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)))) // Insert new data and analyze: have the latest column stats. sql(s"INSERT INTO TABLE $table SELECT 1, 'a', 10.0") @@ -604,18 +605,18 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats1.colStats == Map( - "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4))) + "c1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))) // Analyze another column: since the table is not changed, the precious column stats are kept. sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") val fetchedStats2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats2.colStats == Map( - "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4), - "c2" -> ColumnStat(distinctCount = 1, min = None, max = None, nullCount = 0, - avgLen = 1, maxLen = 1))) + "c1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + "c2" -> CatalogColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)))) // Insert new data and analyze: stale column stats are removed and newly collected column // stats are added. @@ -624,10 +625,10 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats3 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)).get assert(fetchedStats3.colStats == Map( - "c1" -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - "c3" -> ColumnStat(distinctCount = 2, min = Some(10.0), max = Some(20.0), nullCount = 0, - avgLen = 8, maxLen = 8))) + "c1" -> CatalogColumnStat(distinctCount = Some(2), min = Some("1"), max = Some("2"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + "c3" -> CatalogColumnStat(distinctCount = Some(2), min = Some("10.0"), max = Some("20.0"), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)))) } } @@ -999,115 +1000,11 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("verify serialized column stats after analyzing columns") { import testImplicits._ - val tableName = "column_stats_test2" + val tableName = "column_stats_test_ser" // (data.head.productArity - 1) because the last column does not support stats collection. assert(stats.size == data.head.productArity - 1) val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) - val expectedSerializedColStats = Map( - "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", - "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", - "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", - "spark.sql.statistics.colStats.cbinary.version" -> "1", - "spark.sql.statistics.colStats.cbool.avgLen" -> "1", - "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbool.max" -> "true", - "spark.sql.statistics.colStats.cbool.maxLen" -> "1", - "spark.sql.statistics.colStats.cbool.min" -> "false", - "spark.sql.statistics.colStats.cbool.nullCount" -> "1", - "spark.sql.statistics.colStats.cbool.version" -> "1", - "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", - "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbyte.max" -> "2", - "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", - "spark.sql.statistics.colStats.cbyte.min" -> "1", - "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", - "spark.sql.statistics.colStats.cbyte.version" -> "1", - "spark.sql.statistics.colStats.cdate.avgLen" -> "4", - "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", - "spark.sql.statistics.colStats.cdate.maxLen" -> "4", - "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", - "spark.sql.statistics.colStats.cdate.nullCount" -> "1", - "spark.sql.statistics.colStats.cdate.version" -> "1", - "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", - "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", - "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", - "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", - "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", - "spark.sql.statistics.colStats.cdecimal.version" -> "1", - "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", - "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdouble.max" -> "6.0", - "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", - "spark.sql.statistics.colStats.cdouble.min" -> "1.0", - "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", - "spark.sql.statistics.colStats.cdouble.version" -> "1", - "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", - "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", - "spark.sql.statistics.colStats.cfloat.max" -> "7.0", - "spark.sql.statistics.colStats.cfloat.maxLen" -> "4", - "spark.sql.statistics.colStats.cfloat.min" -> "1.0", - "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", - "spark.sql.statistics.colStats.cfloat.version" -> "1", - "spark.sql.statistics.colStats.cint.avgLen" -> "4", - "spark.sql.statistics.colStats.cint.distinctCount" -> "2", - "spark.sql.statistics.colStats.cint.max" -> "4", - "spark.sql.statistics.colStats.cint.maxLen" -> "4", - "spark.sql.statistics.colStats.cint.min" -> "1", - "spark.sql.statistics.colStats.cint.nullCount" -> "1", - "spark.sql.statistics.colStats.cint.version" -> "1", - "spark.sql.statistics.colStats.clong.avgLen" -> "8", - "spark.sql.statistics.colStats.clong.distinctCount" -> "2", - "spark.sql.statistics.colStats.clong.max" -> "5", - "spark.sql.statistics.colStats.clong.maxLen" -> "8", - "spark.sql.statistics.colStats.clong.min" -> "1", - "spark.sql.statistics.colStats.clong.nullCount" -> "1", - "spark.sql.statistics.colStats.clong.version" -> "1", - "spark.sql.statistics.colStats.cshort.avgLen" -> "2", - "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", - "spark.sql.statistics.colStats.cshort.max" -> "3", - "spark.sql.statistics.colStats.cshort.maxLen" -> "2", - "spark.sql.statistics.colStats.cshort.min" -> "1", - "spark.sql.statistics.colStats.cshort.nullCount" -> "1", - "spark.sql.statistics.colStats.cshort.version" -> "1", - "spark.sql.statistics.colStats.cstring.avgLen" -> "3", - "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", - "spark.sql.statistics.colStats.cstring.maxLen" -> "3", - "spark.sql.statistics.colStats.cstring.nullCount" -> "1", - "spark.sql.statistics.colStats.cstring.version" -> "1", - "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", - "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", - "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", - "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", - "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", - "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", - "spark.sql.statistics.colStats.ctimestamp.version" -> "1" - ) - - val expectedSerializedHistograms = Map( - "spark.sql.statistics.colStats.cbyte.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cbyte").histogram.get), - "spark.sql.statistics.colStats.cshort.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cshort").histogram.get), - "spark.sql.statistics.colStats.cint.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cint").histogram.get), - "spark.sql.statistics.colStats.clong.histogram" -> - HistogramSerializer.serialize(statsWithHgms("clong").histogram.get), - "spark.sql.statistics.colStats.cdouble.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdouble").histogram.get), - "spark.sql.statistics.colStats.cfloat.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cfloat").histogram.get), - "spark.sql.statistics.colStats.cdecimal.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdecimal").histogram.get), - "spark.sql.statistics.colStats.cdate.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdate").histogram.get), - "spark.sql.statistics.colStats.ctimestamp.histogram" -> - HistogramSerializer.serialize(statsWithHgms("ctimestamp").histogram.get) - ) - def checkColStatsProps(expected: Map[String, String]): Unit = { sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) val table = hiveClient.getTable("default", tableName) @@ -1129,6 +1026,29 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + test("verify column stats can be deserialized from tblproperties") { + import testImplicits._ + + val tableName = "column_stats_test_de" + // (data.head.productArity - 1) because the last column does not support stats collection. + assert(stats.size == data.head.productArity - 1) + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + + withTable(tableName) { + df.write.saveAsTable(tableName) + + // Put in stats properties manually. + val table = getCatalogTable(tableName) + val newTable = table.copy( + properties = table.properties ++ + expectedSerializedColStats ++ expectedSerializedHistograms + + ("spark.sql.statistics.totalSize" -> "1") /* totalSize always required */) + hiveClient.alterTable(newTable) + + validateColStats(tableName, statsWithHgms) + } + } + test("serialization and deserialization of histograms to/from hive metastore") { import testImplicits._ From 649ed9c5732f85ef1306576fdd3a9278a2a6410c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 27 Feb 2018 08:18:41 -0600 Subject: [PATCH 397/774] [SPARK-23509][BUILD] Upgrade commons-net from 2.2 to 3.1 ## What changes were proposed in this pull request? This PR avoids version conflicts of `commons-net` by upgrading commons-net from 2.2 to 3.1. We are seeing the following message during the build using sbt. ``` [warn] Found version conflict(s) in library dependencies; some are suspected to be binary incompatible: ... [warn] * commons-net:commons-net:3.1 is selected over 2.2 [warn] +- org.apache.hadoop:hadoop-common:2.6.5 (depends on 3.1) [warn] +- org.apache.spark:spark-core_2.11:2.4.0-SNAPSHOT (depends on 2.2) [warn] ``` [Here](https://commons.apache.org/proper/commons-net/changes-report.html) is a release history. [Here](https://commons.apache.org/proper/commons-net/migration.html) is a migration guide from 2.x to 3.0. ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #20672 from kiszk/SPARK-23509. --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index ed310507d14ed..c3d1dd444b506 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -48,7 +48,7 @@ commons-lang-2.6.jar commons-lang3-3.5.jar commons-logging-1.1.3.jar commons-math3-3.4.1.jar -commons-net-2.2.jar +commons-net-3.1.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar core-1.1.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 04dec04796af4..290867035f91d 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -48,7 +48,7 @@ commons-lang-2.6.jar commons-lang3-3.5.jar commons-logging-1.1.3.jar commons-math3-3.4.1.jar -commons-net-2.2.jar +commons-net-3.1.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar core-1.1.2.jar diff --git a/pom.xml b/pom.xml index ac30107066389..b8396166f6b1b 100644 --- a/pom.xml +++ b/pom.xml @@ -579,7 +579,7 @@ commons-net commons-net - 2.2 + 3.1 io.netty From eac0b067222a3dfa52be20360a453cb7bd420bf2 Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Tue, 27 Feb 2018 08:21:11 -0600 Subject: [PATCH 398/774] [SPARK-17147][STREAMING][KAFKA] Allow non-consecutive offsets ## What changes were proposed in this pull request? Add a configuration spark.streaming.kafka.allowNonConsecutiveOffsets to allow streaming jobs to proceed on compacted topics (or other situations involving gaps between offsets in the log). ## How was this patch tested? Added new unit test justinrmiller has been testing this branch in production for a few weeks Author: cody koeninger Closes #20572 from koeninger/SPARK-17147. --- .../kafka010/CachedKafkaConsumer.scala | 55 +++- .../spark/streaming/kafka010/KafkaRDD.scala | 236 +++++++++++++----- .../streaming/kafka010/KafkaRDDSuite.scala | 106 ++++++++ .../streaming/kafka010/KafkaTestUtils.scala | 25 +- .../kafka010/mocks/MockScheduler.scala | 96 +++++++ .../streaming/kafka010/mocks/MockTime.scala | 51 ++++ 6 files changed, 487 insertions(+), 82 deletions(-) create mode 100644 external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala create mode 100644 external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala index fa3ea6131a507..aeb8c1dc342b3 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala @@ -22,10 +22,8 @@ import java.{ util => ju } import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer } import org.apache.kafka.common.{ KafkaException, TopicPartition } -import org.apache.spark.SparkConf import org.apache.spark.internal.Logging - /** * Consumer of single topicpartition, intended for cached reuse. * Underlying consumer is not threadsafe, so neither is this, @@ -38,7 +36,7 @@ class CachedKafkaConsumer[K, V] private( val partition: Int, val kafkaParams: ju.Map[String, Object]) extends Logging { - assert(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), + require(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), "groupId used for cache key must match the groupId in kafkaParams") val topicPartition = new TopicPartition(topic, partition) @@ -53,7 +51,7 @@ class CachedKafkaConsumer[K, V] private( // TODO if the buffer was kept around as a random-access structure, // could possibly optimize re-calculating of an RDD in the same batch - protected var buffer = ju.Collections.emptyList[ConsumerRecord[K, V]]().iterator + protected var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() protected var nextOffset = -2L def close(): Unit = consumer.close() @@ -71,7 +69,7 @@ class CachedKafkaConsumer[K, V] private( } if (!buffer.hasNext()) { poll(timeout) } - assert(buffer.hasNext(), + require(buffer.hasNext(), s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") var record = buffer.next() @@ -79,17 +77,56 @@ class CachedKafkaConsumer[K, V] private( logInfo(s"Buffer miss for $groupId $topic $partition $offset") seek(offset) poll(timeout) - assert(buffer.hasNext(), + require(buffer.hasNext(), s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") record = buffer.next() - assert(record.offset == offset, - s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset") + require(record.offset == offset, + s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset " + + s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + + "spark.streaming.kafka.allowNonConsecutiveOffsets" + ) } nextOffset = offset + 1 record } + /** + * Start a batch on a compacted topic + */ + def compactedStart(offset: Long, timeout: Long): Unit = { + logDebug(s"compacted start $groupId $topic $partition starting $offset") + // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics + if (offset != nextOffset) { + logInfo(s"Initial fetch for compacted $groupId $topic $partition $offset") + seek(offset) + poll(timeout) + } + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + */ + def compactedNext(timeout: Long): ConsumerRecord[K, V] = { + if (!buffer.hasNext()) { + poll(timeout) + } + require(buffer.hasNext(), + s"Failed to get records for compacted $groupId $topic $partition after polling for $timeout") + val record = buffer.next() + nextOffset = record.offset + 1 + record + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + buffer.previous() + } + private def seek(offset: Long): Unit = { logDebug(s"Seeking to $topicPartition $offset") consumer.seek(topicPartition, offset) @@ -99,7 +136,7 @@ class CachedKafkaConsumer[K, V] private( val p = consumer.poll(timeout) val r = p.records(topicPartition) logDebug(s"Polled ${p.partitions()} ${r.size}") - buffer = r.iterator + buffer = r.listIterator } } diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index d9fc9cc206647..07239eda64d2e 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -55,12 +55,12 @@ private[spark] class KafkaRDD[K, V]( useConsumerCache: Boolean ) extends RDD[ConsumerRecord[K, V]](sc, Nil) with Logging with HasOffsetRanges { - assert("none" == + require("none" == kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG).asInstanceOf[String], ConsumerConfig.AUTO_OFFSET_RESET_CONFIG + " must be set to none for executor kafka params, else messages may not match offsetRange") - assert(false == + require(false == kafkaParams.get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG).asInstanceOf[Boolean], ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG + " must be set to false for executor kafka params, else offsets may commit before processing") @@ -74,6 +74,8 @@ private[spark] class KafkaRDD[K, V]( conf.getInt("spark.streaming.kafka.consumer.cache.maxCapacity", 64) private val cacheLoadFactor = conf.getDouble("spark.streaming.kafka.consumer.cache.loadFactor", 0.75).toFloat + private val compacted = + conf.getBoolean("spark.streaming.kafka.allowNonConsecutiveOffsets", false) override def persist(newLevel: StorageLevel): this.type = { logError("Kafka ConsumerRecord is not serializable. " + @@ -87,48 +89,63 @@ private[spark] class KafkaRDD[K, V]( }.toArray } - override def count(): Long = offsetRanges.map(_.count).sum + override def count(): Long = + if (compacted) { + super.count() + } else { + offsetRanges.map(_.count).sum + } override def countApprox( timeout: Long, confidence: Double = 0.95 - ): PartialResult[BoundedDouble] = { - val c = count - new PartialResult(new BoundedDouble(c, 1.0, c, c), true) - } - - override def isEmpty(): Boolean = count == 0L - - override def take(num: Int): Array[ConsumerRecord[K, V]] = { - val nonEmptyPartitions = this.partitions - .map(_.asInstanceOf[KafkaRDDPartition]) - .filter(_.count > 0) + ): PartialResult[BoundedDouble] = + if (compacted) { + super.countApprox(timeout, confidence) + } else { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } - if (num < 1 || nonEmptyPartitions.isEmpty) { - return new Array[ConsumerRecord[K, V]](0) + override def isEmpty(): Boolean = + if (compacted) { + super.isEmpty() + } else { + count == 0L } - // Determine in advance how many messages need to be taken from each partition - val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => - val remain = num - result.values.sum - if (remain > 0) { - val taken = Math.min(remain, part.count) - result + (part.index -> taken.toInt) + override def take(num: Int): Array[ConsumerRecord[K, V]] = + if (compacted) { + super.take(num) + } else if (num < 1) { + Array.empty[ConsumerRecord[K, V]] + } else { + val nonEmptyPartitions = this.partitions + .map(_.asInstanceOf[KafkaRDDPartition]) + .filter(_.count > 0) + + if (nonEmptyPartitions.isEmpty) { + Array.empty[ConsumerRecord[K, V]] } else { - result + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain > 0) { + val taken = Math.min(remain, part.count) + result + (part.index -> taken.toInt) + } else { + result + } + } + + context.runJob( + this, + (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) => + it.take(parts(tc.partitionId)).toArray, parts.keys.toArray + ).flatten } } - val buf = new ArrayBuffer[ConsumerRecord[K, V]] - val res = context.runJob( - this, - (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) => - it.take(parts(tc.partitionId)).toArray, parts.keys.toArray - ) - res.foreach(buf ++= _) - buf.toArray - } - private def executors(): Array[ExecutorCacheTaskLocation] = { val bm = sparkContext.env.blockManager bm.master.getPeers(bm.blockManagerId).toArray @@ -172,57 +189,138 @@ private[spark] class KafkaRDD[K, V]( override def compute(thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[K, V]] = { val part = thePart.asInstanceOf[KafkaRDDPartition] - assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) + require(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) if (part.fromOffset == part.untilOffset) { logInfo(s"Beginning offset ${part.fromOffset} is the same as ending offset " + s"skipping ${part.topic} ${part.partition}") Iterator.empty } else { - new KafkaRDDIterator(part, context) + logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " + + s"offsets ${part.fromOffset} -> ${part.untilOffset}") + if (compacted) { + new CompactedKafkaRDDIterator[K, V]( + part, + context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) + } else { + new KafkaRDDIterator[K, V]( + part, + context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) + } } } +} - /** - * An iterator that fetches messages directly from Kafka for the offsets in partition. - * Uses a cached consumer where possible to take advantage of prefetching - */ - private class KafkaRDDIterator( - part: KafkaRDDPartition, - context: TaskContext) extends Iterator[ConsumerRecord[K, V]] { - - logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " + - s"offsets ${part.fromOffset} -> ${part.untilOffset}") - - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] +/** + * An iterator that fetches messages directly from Kafka for the offsets in partition. + * Uses a cached consumer where possible to take advantage of prefetching + */ +private class KafkaRDDIterator[K, V]( + part: KafkaRDDPartition, + context: TaskContext, + kafkaParams: ju.Map[String, Object], + useConsumerCache: Boolean, + pollTimeout: Long, + cacheInitialCapacity: Int, + cacheMaxCapacity: Int, + cacheLoadFactor: Float +) extends Iterator[ConsumerRecord[K, V]] { + + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + + context.addTaskCompletionListener(_ => closeIfNeeded()) + + val consumer = if (useConsumerCache) { + CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) + if (context.attemptNumber >= 1) { + // just in case the prior attempt failures were cache related + CachedKafkaConsumer.remove(groupId, part.topic, part.partition) + } + CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) + } else { + CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + } - context.addTaskCompletionListener{ context => closeIfNeeded() } + var requestOffset = part.fromOffset - val consumer = if (useConsumerCache) { - CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber >= 1) { - // just in case the prior attempt failures were cache related - CachedKafkaConsumer.remove(groupId, part.topic, part.partition) - } - CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) - } else { - CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + def closeIfNeeded(): Unit = { + if (!useConsumerCache && consumer != null) { + consumer.close() } + } - var requestOffset = part.fromOffset + override def hasNext(): Boolean = requestOffset < part.untilOffset - def closeIfNeeded(): Unit = { - if (!useConsumerCache && consumer != null) { - consumer.close - } + override def next(): ConsumerRecord[K, V] = { + if (!hasNext) { + throw new ju.NoSuchElementException("Can't call getNext() once untilOffset has been reached") } + val r = consumer.get(requestOffset, pollTimeout) + requestOffset += 1 + r + } +} - override def hasNext(): Boolean = requestOffset < part.untilOffset - - override def next(): ConsumerRecord[K, V] = { - assert(hasNext(), "Can't call getNext() once untilOffset has been reached") - val r = consumer.get(requestOffset, pollTimeout) - requestOffset += 1 - r +/** + * An iterator that fetches messages directly from Kafka for the offsets in partition. + * Uses a cached consumer where possible to take advantage of prefetching. + * Intended for compacted topics, or other cases when non-consecutive offsets are ok. + */ +private class CompactedKafkaRDDIterator[K, V]( + part: KafkaRDDPartition, + context: TaskContext, + kafkaParams: ju.Map[String, Object], + useConsumerCache: Boolean, + pollTimeout: Long, + cacheInitialCapacity: Int, + cacheMaxCapacity: Int, + cacheLoadFactor: Float + ) extends KafkaRDDIterator[K, V]( + part, + context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) { + + consumer.compactedStart(part.fromOffset, pollTimeout) + + private var nextRecord = consumer.compactedNext(pollTimeout) + + private var okNext: Boolean = true + + override def hasNext(): Boolean = okNext + + override def next(): ConsumerRecord[K, V] = { + if (!hasNext) { + throw new ju.NoSuchElementException("Can't call getNext() once untilOffset has been reached") + } + val r = nextRecord + if (r.offset + 1 >= part.untilOffset) { + okNext = false + } else { + nextRecord = consumer.compactedNext(pollTimeout) + if (nextRecord.offset >= part.untilOffset) { + okNext = false + consumer.compactedPrevious() + } } + r } } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala index be373af0599cc..271adea1df731 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala @@ -18,16 +18,22 @@ package org.apache.spark.streaming.kafka010 import java.{ util => ju } +import java.io.File import scala.collection.JavaConverters._ import scala.util.Random +import kafka.common.TopicAndPartition +import kafka.log._ +import kafka.message._ +import kafka.utils.Pool import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.serialization.StringDeserializer import org.scalatest.BeforeAndAfterAll import org.apache.spark._ import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.streaming.kafka010.mocks.MockTime class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -64,6 +70,41 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private val preferredHosts = LocationStrategies.PreferConsistent + private def compactLogs(topic: String, partition: Int, messages: Array[(String, String)]) { + val mockTime = new MockTime() + // LogCleaner in 0.10 version of Kafka is still expecting the old TopicAndPartition api + val logs = new Pool[TopicAndPartition, Log]() + val logDir = kafkaTestUtils.brokerLogDir + val dir = new File(logDir, topic + "-" + partition) + dir.mkdirs() + val logProps = new ju.Properties() + logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + logProps.put(LogConfig.MinCleanableDirtyRatioProp, java.lang.Float.valueOf(0.1f)) + val log = new Log( + dir, + LogConfig(logProps), + 0L, + mockTime.scheduler, + mockTime + ) + messages.foreach { case (k, v) => + val msg = new ByteBufferMessageSet( + NoCompressionCodec, + new Message(v.getBytes, k.getBytes, Message.NoTimestamp, Message.CurrentMagicValue)) + log.append(msg) + } + log.roll() + logs.put(TopicAndPartition(topic, partition), log) + + val cleaner = new LogCleaner(CleanerConfig(), logDirs = Array(dir), logs = logs) + cleaner.startup() + cleaner.awaitCleaned(topic, partition, log.activeSegment.baseOffset, 1000) + + cleaner.shutdown() + mockTime.scheduler.shutdown() + } + + test("basic usage") { val topic = s"topicbasic-${Random.nextInt}-${System.currentTimeMillis}" kafkaTestUtils.createTopic(topic) @@ -102,6 +143,71 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { } } + test("compacted topic") { + val compactConf = sparkConf.clone() + compactConf.set("spark.streaming.kafka.allowNonConsecutiveOffsets", "true") + sc.stop() + sc = new SparkContext(compactConf) + val topic = s"topiccompacted-${Random.nextInt}-${System.currentTimeMillis}" + + val messages = Array( + ("a", "1"), + ("a", "2"), + ("b", "1"), + ("c", "1"), + ("c", "2"), + ("b", "2"), + ("b", "3") + ) + val compactedMessages = Array( + ("a", "2"), + ("b", "3"), + ("c", "2") + ) + + compactLogs(topic, 0, messages) + + val props = new ju.Properties() + props.put("cleanup.policy", "compact") + props.put("flush.messages", "1") + props.put("segment.ms", "1") + props.put("segment.bytes", "256") + kafkaTestUtils.createTopic(topic, 1, props) + + + val kafkaParams = getKafkaParams() + + val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) + + val rdd = KafkaUtils.createRDD[String, String]( + sc, kafkaParams, offsetRanges, preferredHosts + ).map(m => m.key -> m.value) + + val received = rdd.collect.toSet + assert(received === compactedMessages.toSet) + + // size-related method optimizations return sane results + assert(rdd.count === compactedMessages.size) + assert(rdd.countApprox(0).getFinalValue.mean === compactedMessages.size) + assert(!rdd.isEmpty) + assert(rdd.take(1).size === 1) + assert(rdd.take(1).head === compactedMessages.head) + assert(rdd.take(messages.size + 10).size === compactedMessages.size) + + val emptyRdd = KafkaUtils.createRDD[String, String]( + sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0)), preferredHosts) + + assert(emptyRdd.isEmpty) + + // invalid offset ranges throw exceptions + val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1)) + intercept[SparkException] { + val result = KafkaUtils.createRDD[String, String](sc, kafkaParams, badRanges, preferredHosts) + .map(_.value) + .collect() + } + } + test("iterator boundary conditions") { // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd val topic = s"topicboundary-${Random.nextInt}-${System.currentTimeMillis}" diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala index 6c7024ea4b5a5..70b579d96d692 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala @@ -162,17 +162,22 @@ private[kafka010] class KafkaTestUtils extends Logging { } /** Create a Kafka topic and wait until it is propagated to the whole cluster */ - def createTopic(topic: String, partitions: Int): Unit = { - AdminUtils.createTopic(zkUtils, topic, partitions, 1) + def createTopic(topic: String, partitions: Int, config: Properties): Unit = { + AdminUtils.createTopic(zkUtils, topic, partitions, 1, config) // wait until metadata is propagated (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) } } + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ + def createTopic(topic: String, partitions: Int): Unit = { + createTopic(topic, partitions, new Properties()) + } + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ def createTopic(topic: String): Unit = { - createTopic(topic, 1) + createTopic(topic, 1, new Properties()) } /** Java-friendly function for sending messages to the Kafka broker */ @@ -196,12 +201,24 @@ private[kafka010] class KafkaTestUtils extends Logging { producer = null } + /** Send the array of (key, value) messages to the Kafka broker */ + def sendMessages(topic: String, messages: Array[(String, String)]): Unit = { + producer = new KafkaProducer[String, String](producerConfiguration) + messages.foreach { message => + producer.send(new ProducerRecord[String, String](topic, message._1, message._2)) + } + producer.close() + producer = null + } + + val brokerLogDir = Utils.createTempDir().getAbsolutePath + private def brokerConfiguration: Properties = { val props = new Properties() props.put("broker.id", "0") props.put("host.name", "localhost") props.put("port", brokerPort.toString) - props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("log.dir", brokerLogDir) props.put("zookeeper.connect", zkAddress) props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala new file mode 100644 index 0000000000000..928e1a6ef54b9 --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010.mocks + +import java.util.concurrent.TimeUnit + +import scala.collection.mutable.PriorityQueue + +import kafka.utils.{Scheduler, Time} + +/** + * A mock scheduler that executes tasks synchronously using a mock time instance. + * Tasks are executed synchronously when the time is advanced. + * This class is meant to be used in conjunction with MockTime. + * + * Example usage + * + * val time = new MockTime + * time.scheduler.schedule("a task", println("hello world: " + time.milliseconds), delay = 1000) + * time.sleep(1001) // this should cause our scheduled task to fire + * + * + * Incrementing the time to the exact next execution time of a task will result in that task + * executing (it as if execution itself takes no time). + */ +private[kafka010] class MockScheduler(val time: Time) extends Scheduler { + + /* a priority queue of tasks ordered by next execution time */ + var tasks = new PriorityQueue[MockTask]() + + def isStarted: Boolean = true + + def startup(): Unit = {} + + def shutdown(): Unit = synchronized { + tasks.foreach(_.fun()) + tasks.clear() + } + + /** + * Check for any tasks that need to execute. Since this is a mock scheduler this check only occurs + * when this method is called and the execution happens synchronously in the calling thread. + * If you are using the scheduler associated with a MockTime instance this call + * will be triggered automatically. + */ + def tick(): Unit = synchronized { + val now = time.milliseconds + while(!tasks.isEmpty && tasks.head.nextExecution <= now) { + /* pop and execute the task with the lowest next execution time */ + val curr = tasks.dequeue + curr.fun() + /* if the task is periodic, reschedule it and re-enqueue */ + if(curr.periodic) { + curr.nextExecution += curr.period + this.tasks += curr + } + } + } + + def schedule( + name: String, + fun: () => Unit, + delay: Long = 0, + period: Long = -1, + unit: TimeUnit = TimeUnit.MILLISECONDS): Unit = synchronized { + tasks += MockTask(name, fun, time.milliseconds + delay, period = period) + tick() + } + +} + +case class MockTask( + val name: String, + val fun: () => Unit, + var nextExecution: Long, + val period: Long) extends Ordered[MockTask] { + def periodic: Boolean = period >= 0 + def compare(t: MockTask): Int = { + java.lang.Long.compare(t.nextExecution, nextExecution) + } +} diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala new file mode 100644 index 0000000000000..a68f94db1f689 --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010.mocks + +import java.util.concurrent._ + +import kafka.utils.Time + +/** + * A class used for unit testing things which depend on the Time interface. + * + * This class never manually advances the clock, it only does so when you call + * sleep(ms) + * + * It also comes with an associated scheduler instance for managing background tasks in + * a deterministic way. + */ +private[kafka010] class MockTime(@volatile private var currentMs: Long) extends Time { + + val scheduler = new MockScheduler(this) + + def this() = this(System.currentTimeMillis) + + def milliseconds: Long = currentMs + + def nanoseconds: Long = + TimeUnit.NANOSECONDS.convert(currentMs, TimeUnit.MILLISECONDS) + + def sleep(ms: Long) { + this.currentMs += ms + scheduler.tick() + } + + override def toString(): String = s"MockTime($milliseconds)" + +} From 414ee867ba0835b97aae2e8d4e489e1879c251dd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 27 Feb 2018 08:44:25 -0800 Subject: [PATCH 399/774] [SPARK-23523][SQL] Fix the incorrect result caused by the rule OptimizeMetadataOnlyQuery ## What changes were proposed in this pull request? ```Scala val tablePath = new File(s"${path.getCanonicalPath}/cOl3=c/cOl1=a/cOl5=e") Seq(("a", "b", "c", "d", "e")).toDF("cOl1", "cOl2", "cOl3", "cOl4", "cOl5") .write.json(tablePath.getCanonicalPath) val df = spark.read.json(path.getCanonicalPath).select("CoL1", "CoL5", "CoL3").distinct() df.show() ``` It generates a wrong result. ``` [c,e,a] ``` We have a bug in the rule `OptimizeMetadataOnlyQuery `. We should respect the attribute order in the original leaf node. This PR is to fix it. ## How was this patch tested? Added a test case Author: gatorsmile Closes #20684 from gatorsmile/optimizeMetadataOnly. --- .../plans/logical/LocalRelation.scala | 9 ++++---- .../execution/OptimizeMetadataOnlyQuery.scala | 12 ++++++++-- .../datasources/HadoopFsRelation.scala | 3 +++ .../OptimizeMetadataOnlyQuerySuite.scala | 22 +++++++++++++++++++ 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index d73d7e73f28d5..b05508db786ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -43,10 +43,11 @@ object LocalRelation { } } -case class LocalRelation(output: Seq[Attribute], - data: Seq[InternalRow] = Nil, - // Indicates whether this relation has data from a streaming source. - override val isStreaming: Boolean = false) +case class LocalRelation( + output: Seq[Attribute], + data: Seq[InternalRow] = Nil, + // Indicates whether this relation has data from a streaming source. + override val isStreaming: Boolean = false) extends LeafNode with analysis.MultiInstanceRelation { // A local relation must have resolved output. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 18f6f697bc857..0613d9053f826 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution +import java.util.Locale + +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ @@ -80,8 +83,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic private def getPartitionAttrs( partitionColumnNames: Seq[String], relation: LogicalPlan): Seq[Attribute] = { - val partColumns = partitionColumnNames.map(_.toLowerCase).toSet - relation.output.filter(a => partColumns.contains(a.name.toLowerCase)) + val attrMap = relation.output.map(_.name.toLowerCase(Locale.ROOT)).zip(relation.output).toMap + partitionColumnNames.map { colName => + attrMap.getOrElse(colName.toLowerCase(Locale.ROOT), + throw new AnalysisException(s"Unable to find the column `$colName` " + + s"given [${relation.output.map(_.name).mkString(", ")}]") + ) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 6b34638529770..ac574b07ec497 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -67,6 +67,9 @@ case class HadoopFsRelation( } } + // When data schema and partition schema have the overlapped columns, the output + // schema respects the order of data schema for the overlapped columns, but respect + // the data types of partition schema val schema: StructType = { StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala index 78c1e5dae566d..a543eb8351656 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.execution +import java.io.File + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY import org.apache.spark.sql.test.SharedSQLContext class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext { @@ -125,4 +128,23 @@ class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT COUNT(DISTINCT p) FROM t_1000").collect() } } + + test("Incorrect result caused by the rule OptimizeMetadataOnlyQuery") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { + withTempPath { path => + val tablePath = new File(s"${path.getCanonicalPath}/cOl3=c/cOl1=a/cOl5=e") + Seq(("a", "b", "c", "d", "e")).toDF("cOl1", "cOl2", "cOl3", "cOl4", "cOl5") + .write.json(tablePath.getCanonicalPath) + + val df = spark.read.json(path.getCanonicalPath).select("CoL1", "CoL5", "CoL3").distinct() + checkAnswer(df, Row("a", "e", "c")) + + val localRelation = df.queryExecution.optimizedPlan.collectFirst { + case l: LocalRelation => l + } + assert(localRelation.nonEmpty, "expect to see a LocalRelation") + assert(localRelation.get.output.map(_.name) == Seq("cOl3", "cOl1", "cOl5")) + } + } + } } From ecb8b383af1cf1b67f3111c148229e00c9c17c40 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 27 Feb 2018 11:12:32 -0800 Subject: [PATCH 400/774] [SPARK-23365][CORE] Do not adjust num executors when killing idle executors. The ExecutorAllocationManager should not adjust the target number of executors when killing idle executors, as it has already adjusted the target number down based on the task backlog. The name `replace` was misleading with DynamicAllocation on, as the target number of executors is changed outside of the call to `killExecutors`, so I adjusted that name. Also separated out the logic of `countFailures` as you don't always want that tied to `replace`. While I was there I made two changes that weren't directly related to this: 1) Fixed `countFailures` in a couple cases where it was getting an incorrect value since it used to be tied to `replace`, eg. when killing executors on a blacklisted node. 2) hard error if you call `sc.killExecutors` with dynamic allocation on, since that's another way the ExecutorAllocationManager and the CoarseGrainedSchedulerBackend would get out of sync. Added a unit test case which verifies that the calls to ExecutorAllocationClient do not adjust the number of executors. Author: Imran Rashid Closes #20604 from squito/SPARK-23365. --- .../spark/ExecutorAllocationClient.scala | 15 +++-- .../spark/ExecutorAllocationManager.scala | 20 ++++-- .../scala/org/apache/spark/SparkContext.scala | 13 +++- .../spark/scheduler/BlacklistTracker.scala | 3 +- .../CoarseGrainedSchedulerBackend.scala | 22 ++++--- .../ExecutorAllocationManagerSuite.scala | 66 ++++++++++++++++++- .../StandaloneDynamicAllocationSuite.scala | 3 +- .../scheduler/BlacklistTrackerSuite.scala | 14 ++-- 8 files changed, 121 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index 9112d93a86b2a..63d87b4cd385c 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -55,18 +55,18 @@ private[spark] trait ExecutorAllocationClient { /** * Request that the cluster manager kill the specified executors. * - * When asking the executor to be replaced, the executor loss is considered a failure, and - * killed tasks that are running on the executor will count towards the failure limits. If no - * replacement is being requested, then the tasks will not count towards the limit. - * * @param executorIds identifiers of executors to kill - * @param replace whether to replace the killed executors with new ones, default false + * @param adjustTargetNumExecutors whether the target number of executors will be adjusted down + * after these executors have been killed + * @param countFailures if there are tasks running on the executors when they are killed, whether + * to count those failures toward task failure limits * @param force whether to force kill busy executors, default false * @return the ids of the executors acknowledged by the cluster manager to be removed. */ def killExecutors( executorIds: Seq[String], - replace: Boolean = false, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean = false): Seq[String] /** @@ -81,7 +81,8 @@ private[spark] trait ExecutorAllocationClient { * @return whether the request is acknowledged by the cluster manager. */ def killExecutor(executorId: String): Boolean = { - val killedExecutors = killExecutors(Seq(executorId)) + val killedExecutors = killExecutors(Seq(executorId), adjustTargetNumExecutors = true, + countFailures = false) killedExecutors.nonEmpty && killedExecutors(0).equals(executorId) } } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 6c59038f2a6c1..189d91333c045 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS} import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ +import org.apache.spark.storage.BlockManagerMaster import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** @@ -81,7 +82,8 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} private[spark] class ExecutorAllocationManager( client: ExecutorAllocationClient, listenerBus: LiveListenerBus, - conf: SparkConf) + conf: SparkConf, + blockManagerMaster: BlockManagerMaster) extends Logging { allocationManager => @@ -151,7 +153,7 @@ private[spark] class ExecutorAllocationManager( private var clock: Clock = new SystemClock() // Listener for Spark events that impact the allocation policy - private val listener = new ExecutorAllocationListener + val listener = new ExecutorAllocationListener // Executor that handles the scheduling task. private val executor = @@ -334,6 +336,11 @@ private[spark] class ExecutorAllocationManager( // If the new target has not changed, avoid sending a message to the cluster manager if (numExecutorsTarget < oldNumExecutorsTarget) { + // We lower the target number of executors but don't actively kill any yet. Killing is + // controlled separately by an idle timeout. It's still helpful to reduce the target number + // in case an executor just happens to get lost (eg., bad hardware, or the cluster manager + // preempts it) -- in that case, there is no point in trying to immediately get a new + // executor, since we wouldn't even use it yet. client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") @@ -455,7 +462,10 @@ private[spark] class ExecutorAllocationManager( val executorsRemoved = if (testing) { executorIdsToBeRemoved } else { - client.killExecutors(executorIdsToBeRemoved) + // We don't want to change our target number of executors, because we already did that + // when the task backlog decreased. + client.killExecutors(executorIdsToBeRemoved, adjustTargetNumExecutors = false, + countFailures = false, force = false) } // [SPARK-21834] killExecutors api reduces the target number of executors. // So we need to update the target with desired value. @@ -575,7 +585,7 @@ private[spark] class ExecutorAllocationManager( // Note that it is not necessary to query the executors since all the cached // blocks we are concerned with are reported to the driver. Note that this // does not include broadcast blocks. - val hasCachedBlocks = SparkEnv.get.blockManager.master.hasCachedBlocks(executorId) + val hasCachedBlocks = blockManagerMaster.hasCachedBlocks(executorId) val now = clock.getTimeMillis() val timeout = { if (hasCachedBlocks) { @@ -610,7 +620,7 @@ private[spark] class ExecutorAllocationManager( * This class is intentionally conservative in its assumptions about the relative ordering * and consistency of events returned by the listener. */ - private class ExecutorAllocationListener extends SparkListener { + private[spark] class ExecutorAllocationListener extends SparkListener { private val stageIdToNumTasks = new mutable.HashMap[Int, Int] // Number of running tasks per stage including speculative tasks. diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index dc531e3337014..5e8595603cc90 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -534,7 +534,8 @@ class SparkContext(config: SparkConf) extends Logging { schedulerBackend match { case b: ExecutorAllocationClient => Some(new ExecutorAllocationManager( - schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf)) + schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf, + _env.blockManager.master)) case _ => None } @@ -1633,6 +1634,8 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. * + * This is not supported when dynamic allocation is turned on. + * * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. If the application wishes to replace the executors it kills * through this method with new ones, it should follow up explicitly with a call to @@ -1644,7 +1647,10 @@ class SparkContext(config: SparkConf) extends Logging { def killExecutors(executorIds: Seq[String]): Boolean = { schedulerBackend match { case b: ExecutorAllocationClient => - b.killExecutors(executorIds, replace = false, force = true).nonEmpty + require(executorAllocationManager.isEmpty, + "killExecutors() unsupported with Dynamic Allocation turned on") + b.killExecutors(executorIds, adjustTargetNumExecutors = true, countFailures = false, + force = true).nonEmpty case _ => logWarning("Killing executors is not supported by current scheduler.") false @@ -1682,7 +1688,8 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def killAndReplaceExecutor(executorId: String): Boolean = { schedulerBackend match { case b: ExecutorAllocationClient => - b.killExecutors(Seq(executorId), replace = true, force = true).nonEmpty + b.killExecutors(Seq(executorId), adjustTargetNumExecutors = false, countFailures = true, + force = true).nonEmpty case _ => logWarning("Killing executors is not supported by current scheduler.") false diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index cd8e61d6d0208..952598f6de19d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -152,7 +152,8 @@ private[scheduler] class BlacklistTracker ( case Some(a) => logInfo(s"Killing blacklisted executor id $exec " + s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.") - a.killExecutors(Seq(exec), true, true) + a.killExecutors(Seq(exec), adjustTargetNumExecutors = false, countFailures = false, + force = true) case None => logWarning(s"Not attempting to kill blacklisted executor id $exec " + s"since allocation client is not defined.") diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 4d75063fbf1c5..5627a557a12f3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -147,7 +147,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case KillExecutorsOnHost(host) => scheduler.getExecutorsAliveOnHost(host).foreach { exec => - killExecutors(exec.toSeq, replace = true, force = true) + killExecutors(exec.toSeq, adjustTargetNumExecutors = false, countFailures = false, + force = true) } case UpdateDelegationTokens(newDelegationTokens) => @@ -584,18 +585,18 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. * - * When asking the executor to be replaced, the executor loss is considered a failure, and - * killed tasks that are running on the executor will count towards the failure limits. If no - * replacement is being requested, then the tasks will not count towards the limit. - * * @param executorIds identifiers of executors to kill - * @param replace whether to replace the killed executors with new ones, default false + * @param adjustTargetNumExecutors whether the target number of executors be adjusted down + * after these executors have been killed + * @param countFailures if there are tasks running on the executors when they are killed, whether + * those failures be counted to task failure limits? * @param force whether to force kill busy executors, default false * @return the ids of the executors acknowledged by the cluster manager to be removed. */ final override def killExecutors( executorIds: Seq[String], - replace: Boolean, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean): Seq[String] = { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") @@ -610,7 +611,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val executorsToKill = knownExecutors .filter { id => !executorsPendingToRemove.contains(id) } .filter { id => force || !scheduler.isExecutorBusy(id) } - executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } + executorsToKill.foreach { id => executorsPendingToRemove(id) = !countFailures } logInfo(s"Actual list of executor(s) to be killed is ${executorsToKill.mkString(", ")}") @@ -618,12 +619,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // with the cluster manager to avoid allocating new ones. When computing the new target, // take into account executors that are pending to be added or removed. val adjustTotalExecutors = - if (!replace) { + if (adjustTargetNumExecutors) { requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0) if (requestedTotalExecutors != (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { logDebug( - s"""killExecutors($executorIds, $replace, $force): Executor counts do not match: + s"""killExecutors($executorIds, $adjustTargetNumExecutors, $countFailures, $force): + |Executor counts do not match: |requestedTotalExecutors = $requestedTotalExecutors |numExistingExecutors = $numExistingExecutors |numPendingExecutors = $numPendingExecutors diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index a0cae5a9e011c..9807d1269e3d4 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark import scala.collection.mutable +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.{mock, never, verify, when} import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics @@ -26,6 +28,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ExternalClusterManager import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.local.LocalSchedulerBackend +import org.apache.spark.storage.BlockManagerMaster import org.apache.spark.util.ManualClock /** @@ -1050,6 +1053,66 @@ class ExecutorAllocationManagerSuite assert(removeTimes(manager) === Map.empty) } + test("SPARK-23365 Don't update target num executors when killing idle executors") { + val minExecutors = 1 + val initialExecutors = 1 + val maxExecutors = 2 + val conf = new SparkConf() + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.shuffle.service.enabled", "true") + .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) + .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString) + .set("spark.dynamicAllocation.initialExecutors", initialExecutors.toString) + .set("spark.dynamicAllocation.schedulerBacklogTimeout", "1000ms") + .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", "1000ms") + .set("spark.dynamicAllocation.executorIdleTimeout", s"3000ms") + val mockAllocationClient = mock(classOf[ExecutorAllocationClient]) + val mockBMM = mock(classOf[BlockManagerMaster]) + val manager = new ExecutorAllocationManager( + mockAllocationClient, mock(classOf[LiveListenerBus]), conf, mockBMM) + val clock = new ManualClock() + manager.setClock(clock) + + when(mockAllocationClient.requestTotalExecutors(meq(2), any(), any())).thenReturn(true) + // test setup -- job with 2 tasks, scale up to two executors + assert(numExecutorsTarget(manager) === 1) + manager.listener.onExecutorAdded(SparkListenerExecutorAdded( + clock.getTimeMillis(), "executor-1", new ExecutorInfo("host1", 1, Map.empty))) + manager.listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 2))) + clock.advance(1000) + manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.getTimeMillis()) + assert(numExecutorsTarget(manager) === 2) + val taskInfo0 = createTaskInfo(0, 0, "executor-1") + manager.listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo0)) + manager.listener.onExecutorAdded(SparkListenerExecutorAdded( + clock.getTimeMillis(), "executor-2", new ExecutorInfo("host1", 1, Map.empty))) + val taskInfo1 = createTaskInfo(1, 1, "executor-2") + manager.listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo1)) + assert(numExecutorsTarget(manager) === 2) + + // have one task finish -- we should adjust the target number of executors down + // but we should *not* kill any executors yet + manager.listener.onTaskEnd(SparkListenerTaskEnd(0, 0, null, Success, taskInfo0, null)) + assert(maxNumExecutorsNeeded(manager) === 1) + assert(numExecutorsTarget(manager) === 2) + clock.advance(1000) + manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.getTimeMillis()) + assert(numExecutorsTarget(manager) === 1) + verify(mockAllocationClient, never).killExecutors(any(), any(), any(), any()) + + // now we cross the idle timeout for executor-1, so we kill it. the really important + // thing here is that we do *not* ask the executor allocation client to adjust the target + // number of executors down + when(mockAllocationClient.killExecutors(Seq("executor-1"), false, false, false)) + .thenReturn(Seq("executor-1")) + clock.advance(3000) + schedule(manager) + assert(maxNumExecutorsNeeded(manager) === 1) + assert(numExecutorsTarget(manager) === 1) + // here's the important verify -- we did kill the executors, but did not adjust the target count + verify(mockAllocationClient).killExecutors(Seq("executor-1"), false, false, false) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, @@ -1268,7 +1331,8 @@ private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend override def killExecutors( executorIds: Seq[String], - replace: Boolean, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean): Seq[String] = executorIds override def start(): Unit = sb.start() diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index c21ee7d26f8ca..27cc47496c805 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -573,7 +573,8 @@ class StandaloneDynamicAllocationSuite syncExecutors(sc) sc.schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.killExecutors(Seq(executorId), replace = false, force) + b.killExecutors(Seq(executorId), adjustTargetNumExecutors = true, countFailures = false, + force) case _ => fail("expected coarse grained scheduler") } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index afebcdd7b9e31..06d7afaaff55c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -479,7 +479,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M test("blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { val allocationClientMock = mock[ExecutorAllocationClient] - when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called")) when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist // is updated before we ask the executor allocation client to kill all the executors @@ -517,7 +517,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) - verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutors(any(), any(), any(), any()) verify(allocationClientMock, never).killExecutorsOnHost(any()) // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. @@ -533,7 +533,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) - verify(allocationClientMock).killExecutors(Seq("1"), true, true) + verify(allocationClientMock).killExecutors(Seq("1"), false, false, true) val taskSetBlacklist3 = createTaskSetBlacklist(stageId = 1) // Fail 4 tasks in one task set on executor 2, so that executor gets blacklisted for the whole @@ -545,13 +545,13 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist3.execToFailures) - verify(allocationClientMock).killExecutors(Seq("2"), true, true) + verify(allocationClientMock).killExecutors(Seq("2"), false, false, true) verify(allocationClientMock).killExecutorsOnHost("hostA") } test("fetch failure blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { val allocationClientMock = mock[ExecutorAllocationClient] - when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called")) when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist // is updated before we ask the executor allocation client to kill all the executors @@ -571,7 +571,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M conf.set(config.BLACKLIST_KILL_ENABLED, false) blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") - verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutors(any(), any(), any(), any()) verify(allocationClientMock, never).killExecutorsOnHost(any()) // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. @@ -580,7 +580,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M clock.advance(1000) blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") - verify(allocationClientMock).killExecutors(Seq("1"), true, true) + verify(allocationClientMock).killExecutors(Seq("1"), false, false, true) verify(allocationClientMock, never).killExecutorsOnHost(any()) assert(blacklist.executorIdToBlacklistStatus.contains("1")) From 598446b74b61fee272d3aee3a2e9a3fc90a70d6a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 27 Feb 2018 11:33:10 -0800 Subject: [PATCH 401/774] [SPARK-23501][UI] Refactor AllStagesPage in order to avoid redundant code As suggested in #20651, the code is very redundant in `AllStagesPage` and modifying it is a copy-and-paste work. We should avoid such a pattern, which is error prone, and have a cleaner solution which avoids code redundancy. existing UTs Author: Marco Gaido Closes #20663 from mgaido91/SPARK-23475_followup. --- .../apache/spark/ui/jobs/AllStagesPage.scala | 261 +++++++----------- 1 file changed, 102 insertions(+), 159 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 38450b9126ff0..4658aa1cea3f1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -19,46 +19,20 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, NodeSeq} +import scala.xml.{Attribute, Elem, Node, NodeSeq, Null, Text} import org.apache.spark.scheduler.Schedulable -import org.apache.spark.status.PoolData -import org.apache.spark.status.api.v1._ +import org.apache.spark.status.{AppSummary, PoolData} +import org.apache.spark.status.api.v1.{StageData, StageStatus} import org.apache.spark.ui.{UIUtils, WebUIPage} /** Page showing list of all ongoing and recently finished stages and pools */ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { private val sc = parent.sc + private val subPath = "stages" private def isFairScheduler = parent.isFairScheduler def render(request: HttpServletRequest): Seq[Node] = { - val allStages = parent.store.stageList(null) - - val activeStages = allStages.filter(_.status == StageStatus.ACTIVE) - val pendingStages = allStages.filter(_.status == StageStatus.PENDING) - val skippedStages = allStages.filter(_.status == StageStatus.SKIPPED) - val completedStages = allStages.filter(_.status == StageStatus.COMPLETE) - val failedStages = allStages.filter(_.status == StageStatus.FAILED).reverse - - val numFailedStages = failedStages.size - val subPath = "stages" - - val activeStagesTable = - new StageTableBase(parent.store, request, activeStages, "active", "activeStage", - parent.basePath, subPath, parent.isFairScheduler, parent.killEnabled, false) - val pendingStagesTable = - new StageTableBase(parent.store, request, pendingStages, "pending", "pendingStage", - parent.basePath, subPath, parent.isFairScheduler, false, false) - val completedStagesTable = - new StageTableBase(parent.store, request, completedStages, "completed", "completedStage", - parent.basePath, subPath, parent.isFairScheduler, false, false) - val skippedStagesTable = - new StageTableBase(parent.store, request, skippedStages, "skipped", "skippedStage", - parent.basePath, subPath, parent.isFairScheduler, false, false) - val failedStagesTable = - new StageTableBase(parent.store, request, failedStages, "failed", "failedStage", - parent.basePath, subPath, parent.isFairScheduler, false, true) - // For now, pool information is only accessible in live UIs val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]).map { pool => val uiPool = parent.store.asOption(parent.store.pool(pool.name)).getOrElse( @@ -67,152 +41,121 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { }.toMap val poolTable = new PoolTable(pools, parent) - val shouldShowActiveStages = activeStages.nonEmpty - val shouldShowPendingStages = pendingStages.nonEmpty - val shouldShowCompletedStages = completedStages.nonEmpty - val shouldShowSkippedStages = skippedStages.nonEmpty - val shouldShowFailedStages = failedStages.nonEmpty + val allStatuses = Seq(StageStatus.ACTIVE, StageStatus.PENDING, StageStatus.COMPLETE, + StageStatus.SKIPPED, StageStatus.FAILED) + val allStages = parent.store.stageList(null) val appSummary = parent.store.appSummary() - val completedStageNumStr = if (appSummary.numCompletedStages == completedStages.size) { - s"${appSummary.numCompletedStages}" - } else { - s"${appSummary.numCompletedStages}, only showing ${completedStages.size}" - } + + val (summaries, tables) = allStatuses.map( + summaryAndTableForStatus(allStages, appSummary, _, request)).unzip val summary: NodeSeq =
      - { - if (shouldShowActiveStages) { -
    • - Active Stages: - {activeStages.size} -
    • - } - } - { - if (shouldShowPendingStages) { -
    • - Pending Stages: - {pendingStages.size} -
    • - } - } - { - if (shouldShowCompletedStages) { -
    • - Completed Stages: - {completedStageNumStr} -
    • - } - } - { - if (shouldShowSkippedStages) { -
    • - Skipped Stages: - {skippedStages.size} -
    • - } - } - { - if (shouldShowFailedStages) { -
    • - Failed Stages: - {numFailedStages} -
    • - } - } + {summaries.flatten}
    - var content = summary ++ - { - if (sc.isDefined && isFairScheduler) { - -

    - - Fair Scheduler Pools ({pools.size}) -

    -
    ++ -
    - {poolTable.toNodeSeq} -
    - } else { - Seq.empty[Node] - } - } - if (shouldShowActiveStages) { - content ++= - -

    - - Active Stages ({activeStages.size}) -

    -
    ++ -
    - {activeStagesTable.toNodeSeq} -
    - } - if (shouldShowPendingStages) { - content ++= - + val poolsDescription = if (sc.isDefined && isFairScheduler) { +

    - Pending Stages ({pendingStages.size}) + Fair Scheduler Pools ({pools.size})

    ++ -
    - {pendingStagesTable.toNodeSeq} +
    + {poolTable.toNodeSeq}
    + } else { + Seq.empty[Node] + } + + val content = summary ++ poolsDescription ++ tables.flatten.flatten + + UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + } + + private def summaryAndTableForStatus( + allStages: Seq[StageData], + appSummary: AppSummary, + status: StageStatus, + request: HttpServletRequest): (Option[Elem], Option[NodeSeq]) = { + val stages = if (status == StageStatus.FAILED) { + allStages.filter(_.status == status).reverse + } else { + allStages.filter(_.status == status) } - if (shouldShowCompletedStages) { - content ++= - -

    - - Completed Stages ({completedStageNumStr}) -

    -
    ++ -
    - {completedStagesTable.toNodeSeq} -
    + + if (stages.isEmpty) { + (None, None) + } else { + val killEnabled = status == StageStatus.ACTIVE && parent.killEnabled + val isFailedStage = status == StageStatus.FAILED + + val stagesTable = + new StageTableBase(parent.store, request, stages, statusName(status), stageTag(status), + parent.basePath, subPath, parent.isFairScheduler, killEnabled, isFailedStage) + val stagesSize = stages.size + (Some(summary(appSummary, status, stagesSize)), + Some(table(appSummary, status, stagesTable, stagesSize))) } - if (shouldShowSkippedStages) { - content ++= - -

    - - Skipped Stages ({skippedStages.size}) -

    -
    ++ -
    - {skippedStagesTable.toNodeSeq} -
    + } + + private def statusName(status: StageStatus): String = status match { + case StageStatus.ACTIVE => "active" + case StageStatus.COMPLETE => "completed" + case StageStatus.FAILED => "failed" + case StageStatus.PENDING => "pending" + case StageStatus.SKIPPED => "skipped" + } + + private def stageTag(status: StageStatus): String = s"${statusName(status)}Stage" + + private def headerDescription(status: StageStatus): String = statusName(status).capitalize + + private def summaryContent(appSummary: AppSummary, status: StageStatus, size: Int): String = { + if (status == StageStatus.COMPLETE && appSummary.numCompletedStages != size) { + s"${appSummary.numCompletedStages}, only showing $size" + } else { + s"$size" } - if (shouldShowFailedStages) { - content ++= - -

    - - Failed Stages ({numFailedStages}) -

    -
    ++ -
    - {failedStagesTable.toNodeSeq} -
    + } + + private def summary(appSummary: AppSummary, status: StageStatus, size: Int): Elem = { + val summary = +
  • + + {headerDescription(status)} Stages: + + {summaryContent(appSummary, status, size)} +
  • + + if (status == StageStatus.COMPLETE) { + summary % Attribute(None, "id", Text("completed-summary"), Null) + } else { + summary } - UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + } + + private def table( + appSummary: AppSummary, + status: StageStatus, + stagesTable: StageTableBase, + size: Int): NodeSeq = { + val classSuffix = s"${statusName(status).capitalize}Stages" + +

    + + {headerDescription(status)} Stages ({summaryContent(appSummary, status, size)}) +

    +
    ++ +
    + {stagesTable.toNodeSeq} +
    } } From 23ac3aaba4a33bc3d31d01f21e93c4681ef6de03 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Wed, 28 Feb 2018 09:25:02 +0900 Subject: [PATCH 402/774] [SPARK-23417][PYTHON] Fix the build instructions supplied by exception messages in python streaming tests ## What changes were proposed in this pull request? Fix the build instructions supplied by exception messages in python streaming tests. I also added -DskipTests to the maven instructions to avoid the 170 minutes of scala tests that occurs each time one wants to add a jar to the assembly directory. ## How was this patch tested? - clone branch - run build/sbt package - run python/run-tests --modules "pyspark-streaming" , expect error message - follow instructions in error message. i.e., run build/sbt assembly/package streaming-kafka-0-8-assembly/assembly - rerun python tests, expect error message - follow instructions in error message. i.e run build/sbt -Pflume assembly/package streaming-flume-assembly/assembly - rerun python tests, see success. - repeated all of the above for mvn version of the process. Author: Bruce Robbins Closes #20638 from bersprockets/SPARK-23417_propa. --- python/pyspark/streaming/tests.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5b86c1cb2c390..71f8101e34c50 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1477,8 +1477,8 @@ def search_kafka_assembly_jar(): raise Exception( ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " - "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or " - "'build/mvn -Pkafka-0-8 package' before running this test.") + "'build/sbt -Pkafka-0-8 assembly/package streaming-kafka-0-8-assembly/assembly' or " + "'build/mvn -DskipTests -Pkafka-0-8 package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1494,8 +1494,8 @@ def search_flume_assembly_jar(): raise Exception( ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + "You need to build Spark with " - "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn -Pflume package' before running this test.") + "'build/sbt -Pflume assembly/package streaming-flume-assembly/assembly' or " + "'build/mvn -DskipTests -Pflume package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) From b14993e1fcb68e1c946a671c6048605ab4afdf58 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 28 Feb 2018 11:00:54 +0900 Subject: [PATCH 403/774] [SPARK-23448][SQL] Clarify JSON and CSV parser behavior in document ## What changes were proposed in this pull request? Clarify JSON and CSV reader behavior in document. JSON doesn't support partial results for corrupted records. CSV only supports partial results for the records with more or less tokens. ## How was this patch tested? Pass existing tests. Author: Liang-Chi Hsieh Closes #20666 from viirya/SPARK-23448-2. --- python/pyspark/sql/readwriter.py | 30 ++++++++++--------- python/pyspark/sql/streaming.py | 30 ++++++++++--------- .../sql/catalyst/json/JacksonParser.scala | 3 ++ .../apache/spark/sql/DataFrameReader.scala | 22 +++++++------- .../datasources/csv/UnivocityParser.scala | 5 ++++ .../sql/streaming/DataStreamReader.scala | 22 +++++++------- 6 files changed, 64 insertions(+), 48 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 49af1bcee5ef8..9d05ac7cb39be 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -209,13 +209,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ - schema. If a schema does not have the field, it drops corrupt records during \ - parsing. When inferring a schema, it implicitly adds a \ - ``columnNameOfCorruptRecord`` field in an output schema. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + When inferring a schema, it implicitly adds a ``columnNameOfCorruptRecord`` \ + field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -393,13 +393,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an \ - user-defined schema. If a schema does not have the field, it drops corrupt \ - records during parsing. When a length of parsed CSV tokens is shorter than \ - an expected length of a schema, it sets `null` for extra fields. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + A record with less/more tokens than schema is not a corrupted record to CSV. \ + When it meets a record having fewer tokens than the length of the schema, \ + sets ``null`` to extra fields. When the record has more tokens than the \ + length of the schema, it drops extra tokens. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index e2a97acb5e2a7..cc622decfd682 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -442,13 +442,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ - schema. If a schema does not have the field, it drops corrupt records during \ - parsing. When inferring a schema, it implicitly adds a \ - ``columnNameOfCorruptRecord`` field in an output schema. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + When inferring a schema, it implicitly adds a ``columnNameOfCorruptRecord`` \ + field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -621,13 +621,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an \ - user-defined schema. If a schema does not have the field, it drops corrupt \ - records during parsing. When a length of parsed CSV tokens is shorter than \ - an expected length of a schema, it sets `null` for extra fields. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + A record with less/more tokens than schema is not a corrupted record to CSV. \ + When it meets a record having fewer tokens than the length of the schema, \ + sets ``null`` to extra fields. When the record has more tokens than the \ + length of the schema, it drops extra tokens. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index bd144c9575c72..7f6956994f31f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -357,6 +357,9 @@ class JacksonParser( } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => + // JSON parser currently doesn't support partial results for corrupted records. + // For such records, all fields other than the field configured by + // `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => recordLiteral(record), () => None, e) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 4274f120a375a..0139913aaa4e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -345,12 +345,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep - * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` - * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord` - * field in an output schema.
    • + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To + * keep corrupt records, an user can set a string type field named + * `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have the + * field, it drops corrupt records during parsing. When inferring a schema, it implicitly + * adds a `columnNameOfCorruptRecord` field in an output schema.
    • *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    @@ -550,12 +550,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. It supports the following case-insensitive modes. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When a length of parsed CSV tokens is shorter than an expected length - * of a schema, it sets `null` for extra fields.
    • + * during parsing. A record with less/more tokens than schema is not a corrupted record to + * CSV. When it meets a record having fewer tokens than the length of the schema, sets + * `null` to extra fields. When the record has more tokens than the length of the schema, + * it drops extra tokens. *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 7d6d7e7eef926..3d6cc30f2ba83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -203,6 +203,8 @@ class UnivocityParser( case _: BadRecordException => None } } + // For records with less or more tokens than the schema, tries to return partial results + // if possible. throw BadRecordException( () => getCurrentInput, () => getPartialResult(), @@ -218,6 +220,9 @@ class UnivocityParser( row } catch { case NonFatal(e) => + // For corrupted records with the number of tokens same as the schema, + // CSV reader doesn't support partial results. All fields other than the field + // configured by `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => getCurrentInput, () => None, e) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f23851655350a..61e22fac854f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -236,12 +236,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep - * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` - * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord` - * field in an output schema.
    • + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To + * keep corrupt records, an user can set a string type field named + * `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have the + * field, it drops corrupt records during parsing. When inferring a schema, it implicitly + * adds a `columnNameOfCorruptRecord` field in an output schema.
    • *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    @@ -316,12 +316,14 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. It supports the following case-insensitive modes. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When a length of parsed CSV tokens is shorter than an expected length - * of a schema, it sets `null` for extra fields.
    • + * during parsing. A record with less/more tokens than schema is not a corrupted record to + * CSV. When it meets a record having fewer tokens than the length of the schema, sets + * `null` to extra fields. When the record has more tokens than the length of the schema, + * it drops extra tokens. *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    From 6a8abe29ef3369b387d9bc2ee3459a6611246ab1 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Wed, 28 Feb 2018 23:16:29 +0800 Subject: [PATCH 404/774] [SPARK-23508][CORE] Fix BlockmanagerId in case blockManagerIdCache cause oom MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … cause oom ## What changes were proposed in this pull request? blockManagerIdCache in BlockManagerId will not remove old values which may cause oom `val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()` Since whenever we apply a new BlockManagerId, it will put into this map. This patch will use guava cahce for blockManagerIdCache instead. A heap dump show in [SPARK-23508](https://issues.apache.org/jira/browse/SPARK-23508) ## How was this patch tested? Exist tests. Author: zhoukang Closes #20667 from caneGuy/zhoukang/fix-history. --- .../org/apache/spark/storage/BlockManagerId.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 2c3da0ee85e06..d4a59c33b974c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -18,7 +18,8 @@ package org.apache.spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} -import java.util.concurrent.ConcurrentHashMap + +import com.google.common.cache.{CacheBuilder, CacheLoader} import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi @@ -132,10 +133,17 @@ private[spark] object BlockManagerId { getCachedBlockManagerId(obj) } - val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() + /** + * The max cache size is hardcoded to 10000, since the size of a BlockManagerId + * object is about 48B, the total memory cost should be below 1MB which is feasible. + */ + val blockManagerIdCache = CacheBuilder.newBuilder() + .maximumSize(10000) + .build(new CacheLoader[BlockManagerId, BlockManagerId]() { + override def load(id: BlockManagerId) = id + }) def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { - blockManagerIdCache.putIfAbsent(id, id) blockManagerIdCache.get(id) } } From fab563b9bd1581112462c0fc0b299ad6510b6564 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 1 Mar 2018 00:44:13 +0900 Subject: [PATCH 405/774] [SPARK-23517][PYTHON] Make `pyspark.util._exception_message` produce the trace from Java side by Py4JJavaError ## What changes were proposed in this pull request? This PR proposes for `pyspark.util._exception_message` to produce the trace from Java side by `Py4JJavaError`. Currently, in Python 2, it uses `message` attribute which `Py4JJavaError` didn't happen to have: ```python >>> from pyspark.util import _exception_message >>> try: ... sc._jvm.java.lang.String(None) ... except Exception as e: ... pass ... >>> e.message '' ``` Seems we should use `str` instead for now: https://github.com/bartdag/py4j/blob/aa6c53b59027925a426eb09b58c453de02c21b7c/py4j-python/src/py4j/protocol.py#L412 but this doesn't address the problem with non-ascii string from Java side - `https://github.com/bartdag/py4j/issues/306` So, we could directly call `__str__()`: ```python >>> e.__str__() u'An error occurred while calling None.java.lang.String.\n: java.lang.NullPointerException\n\tat java.lang.String.(String.java:588)\n\tat sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method)\n\tat sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62)\n\tat sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45)\n\tat java.lang.reflect.Constructor.newInstance(Constructor.java:422)\n\tat py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:247)\n\tat py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\n\tat py4j.Gateway.invoke(Gateway.java:238)\n\tat py4j.commands.ConstructorCommand.invokeConstructor(ConstructorCommand.java:80)\n\tat py4j.commands.ConstructorCommand.execute(ConstructorCommand.java:69)\n\tat py4j.GatewayConnection.run(GatewayConnection.java:214)\n\tat java.lang.Thread.run(Thread.java:745)\n' ``` which doesn't type coerce unicodes to `str` in Python 2. This can be actually a problem: ```python from pyspark.sql.functions import udf spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.range(1).select(udf(lambda x: [[]])()).toPandas() ``` **Before** ``` Traceback (most recent call last): File "", line 1, in File "/.../spark/python/pyspark/sql/dataframe.py", line 2009, in toPandas raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) RuntimeError: Note: toPandas attempted Arrow optimization because 'spark.sql.execution.arrow.enabled' is set to true. Please set it to false to disable this. ``` **After** ``` Traceback (most recent call last): File "", line 1, in File "/.../spark/python/pyspark/sql/dataframe.py", line 2009, in toPandas raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) RuntimeError: An error occurred while calling o47.collectAsArrowToPython. : org.apache.spark.SparkException: Job aborted due to stage failure: Task 7 in stage 0.0 failed 1 times, most recent failure: Lost task 7.0 in stage 0.0 (TID 7, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last): File "/.../spark/python/pyspark/worker.py", line 245, in main process() File "/.../spark/python/pyspark/worker.py", line 240, in process ... Note: toPandas attempted Arrow optimization because 'spark.sql.execution.arrow.enabled' is set to true. Please set it to false to disable this. ``` ## How was this patch tested? Manually tested and unit tests were added. Author: hyukjinkwon Closes #20680 from HyukjinKwon/SPARK-23517. --- python/pyspark/tests.py | 11 +++++++++++ python/pyspark/util.py | 7 +++++++ 2 files changed, 18 insertions(+) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 511585763cb01..9111dbbed5929 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2293,6 +2293,17 @@ def set(self, x=None, other=None, other_x=None): self.assertEqual(b._x, 2) +class UtilTests(PySparkTestCase): + def test_py4j_exception_message(self): + from pyspark.util import _exception_message + + with self.assertRaises(Py4JJavaError) as context: + # This attempts java.lang.String(null) which throws an NPE. + self.sc._jvm.java.lang.String(None) + + self.assertTrue('NullPointerException' in _exception_message(context.exception)) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index e5d332ce54429..ad4a0bc68ef41 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from py4j.protocol import Py4JJavaError __all__ = [] @@ -33,6 +34,12 @@ def _exception_message(excp): >>> msg == _exception_message(excp) True """ + if isinstance(excp, Py4JJavaError): + # 'Py4JJavaError' doesn't contain the stack trace available on the Java side in 'message' + # attribute in Python 2. We should call 'str' function on this exception in general but + # 'Py4JJavaError' has an issue about addressing non-ascii strings. So, here we work + # around by the direct call, '__str__()'. Please see SPARK-23517. + return excp.__str__() if hasattr(excp, "message"): return excp.message return str(excp) From 476a7f026bc45462067ebd39cd269147e84cd641 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 28 Feb 2018 08:44:53 -0800 Subject: [PATCH 406/774] [SPARK-23514] Use SessionState.newHadoopConf() to propage hadoop configs set in SQLConf. ## What changes were proposed in this pull request? A few places in `spark-sql` were using `sc.hadoopConfiguration` directly. They should be using `sessionState.newHadoopConf()` to blend in configs that were set through `SQLConf`. Also, for better UX, for these configs blended in from `SQLConf`, we should consider removing the `spark.hadoop` prefix, so that the settings are recognized whether or not they were specified by the user. ## How was this patch tested? Tested that AlterTableRecoverPartitions now correctly recognizes settings that are passed in to the FileSystem through SQLConf. Author: Juliusz Sompolski Closes #20679 from juliuszsompolski/SPARK-23514. --- .../scala/org/apache/spark/sql/execution/command/ddl.scala | 6 +++--- .../scala/org/apache/spark/sql/hive/test/TestHive.scala | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 0142f17ce62e2..964cbca049b27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -610,10 +610,10 @@ case class AlterTableRecoverPartitionsCommand( val root = new Path(table.location) logInfo(s"Recover all the partitions in $root") - val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + val hadoopConf = spark.sessionState.newHadoopConf() + val fs = root.getFileSystem(hadoopConf) val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt - val hadoopConf = spark.sparkContext.hadoopConfiguration val pathFilter = getPathFilter(hadoopConf) val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) @@ -697,7 +697,7 @@ case class AlterTableRecoverPartitionsCommand( pathFilter: PathFilter, threshold: Int): GenMap[String, PartitionStatistics] = { if (partitionSpecsAndLocs.length > threshold) { - val hadoopConf = spark.sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() val serializableConfiguration = new SerializableConfiguration(hadoopConf) val serializedPaths = partitionSpecsAndLocs.map(_._2.toString).toArray diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 19028939f3673..fcf2025d34432 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -518,8 +518,9 @@ private[hive] class TestHiveSparkSession( // an HDFS scratch dir: ${hive.exec.scratchdir}/ is created, with // ${hive.scratch.dir.permission}. To resolve the permission issue, the simplest way is to // delete it. Later, it will be re-created with the right permission. - val location = new Path(sc.hadoopConfiguration.get(ConfVars.SCRATCHDIR.varname)) - val fs = location.getFileSystem(sc.hadoopConfiguration) + val hadoopConf = sessionState.newHadoopConf() + val location = new Path(hadoopConf.get(ConfVars.SCRATCHDIR.varname)) + val fs = location.getFileSystem(hadoopConf) fs.delete(location, true) // Some tests corrupt this value on purpose, which breaks the RESET call below. From 25c2776dd9ae3f9792048c78be2cbd958fd99841 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 28 Feb 2018 12:16:26 -0800 Subject: [PATCH 407/774] [SPARK-23523][SQL][FOLLOWUP] Minor refactor of OptimizeMetadataOnlyQuery ## What changes were proposed in this pull request? Inside `OptimizeMetadataOnlyQuery.getPartitionAttrs`, avoid using `zip` to generate attribute map. Also include other minor update of comments and format. ## How was this patch tested? Existing test cases. Author: Xingbo Jiang Closes #20693 from jiangxb1987/SPARK-23523. --- .../spark/sql/execution/OptimizeMetadataOnlyQuery.scala | 2 +- .../spark/sql/execution/datasources/HadoopFsRelation.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 0613d9053f826..dc4aff9f12580 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -83,7 +83,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic private def getPartitionAttrs( partitionColumnNames: Seq[String], relation: LogicalPlan): Seq[Attribute] = { - val attrMap = relation.output.map(_.name.toLowerCase(Locale.ROOT)).zip(relation.output).toMap + val attrMap = relation.output.map(a => a.name.toLowerCase(Locale.ROOT) -> a).toMap partitionColumnNames.map { colName => attrMap.getOrElse(colName.toLowerCase(Locale.ROOT), throw new AnalysisException(s"Unable to find the column `$colName` " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index ac574b07ec497..b2f73b7f8d1fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -67,9 +67,9 @@ case class HadoopFsRelation( } } - // When data schema and partition schema have the overlapped columns, the output - // schema respects the order of data schema for the overlapped columns, but respect - // the data types of partition schema + // When data and partition schemas have overlapping columns, the output + // schema respects the order of the data schema for the overlapping columns, and it + // respects the data types of the partition schema. val schema: StructType = { StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) From 22f3d3334c85c042c6e90f5a02f308d7cd1c1498 Mon Sep 17 00:00:00 2001 From: liuxian Date: Thu, 1 Mar 2018 14:28:28 +0800 Subject: [PATCH 408/774] [SPARK-23389][CORE] When the shuffle dependency specifies aggregation ,and `dependency.mapSideCombine =false`, we should be able to use serialized sorting. ## What changes were proposed in this pull request? When the shuffle dependency specifies aggregation ,and `dependency.mapSideCombine=false`, in the map side,there is no need for aggregation and sorting, so we should be able to use serialized sorting. ## How was this patch tested? Existing unit test Author: liuxian Closes #20576 from 10110346/mapsidecombine. --- .../scala/org/apache/spark/Dependency.scala | 3 +++ .../spark/shuffle/BlockStoreShuffleReader.scala | 1 - .../spark/shuffle/sort/SortShuffleManager.scala | 6 +++--- .../spark/shuffle/sort/SortShuffleWriter.scala | 2 -- .../shuffle/sort/SortShuffleManagerSuite.scala | 17 +++++++++-------- 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index ca52ecafa2cc8..9ea6d2fa2fd95 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -76,6 +76,9 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val mapSideCombine: Boolean = false) extends Dependency[Product2[K, V]] { + if (mapSideCombine) { + require(aggregator.isDefined, "Map-side combine without Aggregator specified!") + } override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] private[spark] val keyClassName: String = reflect.classTag[K].runtimeClass.getName diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0562d45ff57c5..edd69715c9602 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -90,7 +90,6 @@ private[spark] class BlockStoreShuffleReader[K, C]( dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { - require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index bfb4dc698e325..d9fad64f34c7c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -188,9 +188,9 @@ private[spark] object SortShuffleManager extends Logging { log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + s"${dependency.serializer.getClass.getName}, does not support object relocation") false - } else if (dependency.aggregator.isDefined) { - log.debug( - s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined") + } else if (dependency.mapSideCombine) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because we need to do " + + s"map-side aggregation") false } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 636b88e792bf3..274399b9cc1f3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -50,7 +50,6 @@ private[spark] class SortShuffleWriter[K, V, C]( /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { sorter = if (dep.mapSideCombine) { - require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") new ExternalSorter[K, V, C]( context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) } else { @@ -107,7 +106,6 @@ private[spark] object SortShuffleWriter { def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { // We cannot bypass sorting if we need to do map-side aggregation. if (dep.mapSideCombine) { - require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") false } else { val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala index 55cebe7c8b6a8..f29dac965c803 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala @@ -85,6 +85,14 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) + // We support serialized shuffle if we do not need to do map-side aggregation + assert(canUseSerializedShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = None, + aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), + mapSideCombine = false + ))) } test("unsupported shuffle dependencies for serialized shuffle") { @@ -111,14 +119,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) - // We do not support shuffles that perform aggregation - assert(!canUseSerializedShuffle(shuffleDep( - partitioner = new HashPartitioner(2), - serializer = kryo, - keyOrdering = None, - aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), - mapSideCombine = false - ))) + // We do not support serialized shuffle if we need to do map-side aggregation assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, From ff1480189b827af0be38605d566a4ee71b4c36f6 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 1 Mar 2018 16:26:11 +0800 Subject: [PATCH 409/774] [SPARK-23510][SQL] Support Hive 2.2 and Hive 2.3 metastore ## What changes were proposed in this pull request? This is based on https://github.com/apache/spark/pull/20668 for supporting Hive 2.2 and Hive 2.3 metastore. When we merge the PR, we should give the major credit to wangyum ## How was this patch tested? Added the test cases Author: Yuming Wang Author: gatorsmile Closes #20671 from gatorsmile/pr-20668. --- .../org/apache/spark/sql/hive/HiveUtils.scala | 2 +- .../sql/hive/client/HiveClientImpl.scala | 3 +- .../spark/sql/hive/client/HiveShim.scala | 8 +-- .../hive/client/IsolatedClientLoader.scala | 2 + .../spark/sql/hive/client/package.scala | 10 +++- .../sql/hive/execution/SaveAsHiveFile.scala | 3 +- .../sql/hive/client/HiveClientVersions.scala | 3 +- .../sql/hive/client/HiveVersionSuite.scala | 2 +- .../spark/sql/hive/client/VersionsSuite.scala | 51 +++++++++++++++++-- 9 files changed, 72 insertions(+), 12 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c448c5a9821be..10c9603745379 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -62,7 +62,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - s"0.12.0 through 2.1.1.") + s"0.12.0 through 2.3.2.") .stringConf .createWithDefault(builtinHiveVersion) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 146fa54a1bce4..da9fe2d3088b4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -25,7 +25,6 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.conf.HiveConf @@ -104,6 +103,8 @@ private[hive] class HiveClientImpl( case hive.v1_2 => new Shim_v1_2() case hive.v2_0 => new Shim_v2_0() case hive.v2_1 => new Shim_v2_1() + case hive.v2_2 => new Shim_v2_2() + case hive.v2_3 => new Shim_v2_3() } // Create an internal session state for this HiveClientImpl. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 1eac70dbf19cd..948ba542b5733 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -880,9 +880,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { } -private[client] class Shim_v1_0 extends Shim_v0_14 { - -} +private[client] class Shim_v1_0 extends Shim_v0_14 private[client] class Shim_v1_1 extends Shim_v1_0 { @@ -1146,3 +1144,7 @@ private[client] class Shim_v2_1 extends Shim_v2_0 { alterPartitionsMethod.invoke(hive, tableName, newParts, environmentContextInAlterTable) } } + +private[client] class Shim_v2_2 extends Shim_v2_1 + +private[client] class Shim_v2_3 extends Shim_v2_1 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index dac0e333b63bc..12975bc85b971 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -97,6 +97,8 @@ private[hive] object IsolatedClientLoader extends Logging { case "1.2" | "1.2.0" | "1.2.1" | "1.2.2" => hive.v1_2 case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 + case "2.2" | "2.2.0" => hive.v2_2 + case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" => hive.v2_3 } private def downloadVersion( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index c14154a3b3c21..681ee9200f02b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -71,7 +71,15 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0, v2_1) + case object v2_2 extends HiveVersion("2.2.0", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + case object v2_3 extends HiveVersion("2.3.2", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0, v2_1, v2_2, v2_3) } // scalastyle:on diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index e484356906e87..6a7b25b36d9a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -114,7 +114,8 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { // staging directory under the table director for Hive prior to 1.1, the staging directory will // be removed by Hive when Hive is trying to empty the table directory. val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0) - val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0, v2_1) + val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = + Set(v1_1, v1_2, v2_0, v2_1, v2_2, v2_3) // Ensure all the supported versions are considered here. assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath == diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala index 2e7dfde8b2fa5..30592a3f85428 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala @@ -22,5 +22,6 @@ import scala.collection.immutable.IndexedSeq import org.apache.spark.SparkFunSuite private[client] trait HiveClientVersions { - protected val versions = IndexedSeq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1") + protected val versions = + IndexedSeq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1", "2.2", "2.3") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index a70fb6464cc1d..e5963d03f6b52 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -34,7 +34,7 @@ private[client] abstract class HiveVersionSuite(version: String) extends SparkFu // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and // hive.metastore.schema.verification from false to true since 2.0 // For details, see the JIRA HIVE-6113 and HIVE-12463 - if (version == "2.0" || version == "2.1") { + if (version == "2.0" || version == "2.1" || version == "2.2" || version == "2.3") { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 72536b833481a..6176273c88db1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -21,6 +21,7 @@ import java.io.{ByteArrayOutputStream, File, PrintStream, PrintWriter} import java.net.URI import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred.TextInputFormat @@ -110,7 +111,8 @@ class VersionsSuite extends SparkFunSuite with Logging { assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1") + private val versions = + Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1", "2.2", "2.3") private var client: HiveClient = null @@ -125,7 +127,7 @@ class VersionsSuite extends SparkFunSuite with Logging { // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and // hive.metastore.schema.verification from false to true since 2.0 // For details, see the JIRA HIVE-6113 and HIVE-12463 - if (version == "2.0" || version == "2.1") { + if (version == "2.0" || version == "2.1" || version == "2.2" || version == "2.3") { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } @@ -422,15 +424,18 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: alterPartitions") { val spec = Map("key1" -> "1", "key2" -> "2") + val parameters = Map(StatsSetupConst.TOTAL_SIZE -> "0", StatsSetupConst.NUM_FILES -> "1") val newLocation = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) val storage = storageFormat.copy( locationUri = Some(newLocation), // needed for 0.12 alter partitions serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - val partition = CatalogTablePartition(spec, storage) + val partition = CatalogTablePartition(spec, storage, parameters) client.alterPartitions("default", "src_part", Seq(partition)) assert(client.getPartition("default", "src_part", spec) .storage.locationUri == Some(newLocation)) + assert(client.getPartition("default", "src_part", spec) + .parameters.get(StatsSetupConst.TOTAL_SIZE) == Some("0")) } test(s"$version: dropPartitions") { @@ -633,6 +638,46 @@ class VersionsSuite extends SparkFunSuite with Logging { } } + test(s"$version: CREATE Partitioned TABLE AS SELECT") { + withTable("tbl") { + versionSpark.sql( + """ + |CREATE TABLE tbl(c1 string) + |PARTITIONED BY (ds STRING) + """.stripMargin) + versionSpark.sql("INSERT OVERWRITE TABLE tbl partition (ds='2') SELECT '1'") + + assert(versionSpark.table("tbl").collect().toSeq == Seq(Row("1", "2"))) + val partMeta = versionSpark.sessionState.catalog.getPartition( + TableIdentifier("tbl"), spec = Map("ds" -> "2")).parameters + val totalSize = partMeta.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val numFiles = partMeta.get(StatsSetupConst.NUM_FILES).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(totalSize.isEmpty && numFiles.isEmpty) + } else { + assert(totalSize.nonEmpty && numFiles.nonEmpty) + } + + versionSpark.sql( + """ + |ALTER TABLE tbl PARTITION (ds='2') + |SET SERDEPROPERTIES ('newKey' = 'vvv') + """.stripMargin) + val newPartMeta = versionSpark.sessionState.catalog.getPartition( + TableIdentifier("tbl"), spec = Map("ds" -> "2")).parameters + + val newTotalSize = newPartMeta.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val newNumFiles = newPartMeta.get(StatsSetupConst.NUM_FILES).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(newTotalSize.isEmpty && newNumFiles.isEmpty) + } else { + assert(newTotalSize.nonEmpty && newNumFiles.nonEmpty) + } + } + } + test(s"$version: Delete the temporary staging directory and files after each insert") { withTempDir { tmpDir => withTable("tab") { From cdcccd7b41c43d79edff2fec7a84cd00e9524f75 Mon Sep 17 00:00:00 2001 From: KaiXinXiaoLei <584620569@qq.com> Date: Fri, 2 Mar 2018 00:09:44 +0800 Subject: [PATCH 410/774] [SPARK-23405] Generate additional constraints for Join's children ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) I run a sql: `select ls.cs_order_number from ls left semi join catalog_sales cs on ls.cs_order_number = cs.cs_order_number`, The `ls` table is a small table ,and the number is one. The `catalog_sales` table is a big table, and the number is 10 billion. The task will be hang up. And i find the many null values of `cs_order_number` in the `catalog_sales` table. I think the null value should be removed in the logical plan. >== Optimized Logical Plan == >Join LeftSemi, (cs_order_number#1 = cs_order_number#22) >:- Project cs_order_number#1 > : +- Filter isnotnull(cs_order_number#1) > : +- MetastoreRelation 100t, ls >+- Project cs_order_number#22 > +- MetastoreRelation 100t, catalog_sales Now, use this patch, the plan will be: >== Optimized Logical Plan == >Join LeftSemi, (cs_order_number#1 = cs_order_number#22) >:- Project cs_order_number#1 > : +- Filter isnotnull(cs_order_number#1) > : +- MetastoreRelation 100t, ls >+- Project cs_order_number#22 > : **+- Filter isnotnull(cs_order_number#22)** > :+- MetastoreRelation 100t, catalog_sales ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: KaiXinXiaoLei <584620569@qq.com> Author: hanghang <584620569@qq.com> Closes #20670 from KaiXinXiaoLei/Spark-23405. --- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../plans/logical/QueryPlanConstraints.scala | 27 ++++++++++--------- .../InferFiltersFromConstraintsSuite.scala | 12 +++++++++ 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a28b6a0feb8f9..91208479be03b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -661,7 +661,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe case join @ Join(left, right, joinType, conditionOpt) => // Only consider constraints that can be pushed down completely to either the left or the // right child - val constraints = join.constraints.filter { c => + val constraints = join.allConstraints.filter { c => c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet) } // Remove those constraints that are already enforced by either the left or the right child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 5c7b8e5b97883..046848875548b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -23,25 +23,28 @@ import org.apache.spark.sql.catalyst.expressions._ trait QueryPlanConstraints { self: LogicalPlan => /** - * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For - * example, if this set contains the expression `a = 2` then that expression is guaranteed to - * evaluate to `true` for all rows produced. + * An [[ExpressionSet]] that contains an additional set of constraints, such as equality + * constraints and `isNotNull` constraints, etc. */ - lazy val constraints: ExpressionSet = { + lazy val allConstraints: ExpressionSet = { if (conf.constraintPropagationEnabled) { - ExpressionSet( - validConstraints - .union(inferAdditionalConstraints(validConstraints)) - .union(constructIsNotNullConstraints(validConstraints)) - .filter { c => - c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic - } - ) + ExpressionSet(validConstraints + .union(inferAdditionalConstraints(validConstraints)) + .union(constructIsNotNullConstraints(validConstraints))) } else { ExpressionSet(Set.empty) } } + /** + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. + */ + lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter { c => + c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic + }) + /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 178c4b8c270a0..f78c2356e35a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -192,4 +192,16 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(Optimize.execute(original.analyze), correct.analyze) } + + test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, LeftSemi, condition).analyze + val left = x.where(IsNotNull('a)) + val right = y.where(IsNotNull('a)) + val correctAnswer = left.join(right, LeftSemi, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } } From 34811e0b908449fd59bca476604612b1d200778d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 1 Mar 2018 17:26:39 -0800 Subject: [PATCH 411/774] [SPARK-23551][BUILD] Exclude `hadoop-mapreduce-client-core` dependency from `orc-mapreduce` ## What changes were proposed in this pull request? This PR aims to prevent `orc-mapreduce` dependency from making IDEs and maven confused. **BEFORE** Please note that `2.6.4` at `Spark Project SQL`. ``` $ mvn dependency:tree -Phadoop-2.7 -Dincludes=org.apache.hadoop:hadoop-mapreduce-client-core ... [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Catalyst 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-catalyst_2.11 --- [INFO] org.apache.spark:spark-catalyst_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.spark:spark-core_2.11:jar:2.4.0-SNAPSHOT:compile [INFO] \- org.apache.hadoop:hadoop-client:jar:2.7.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.7.3:compile [INFO] [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project SQL 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-sql_2.11 --- [INFO] org.apache.spark:spark-sql_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.orc:orc-mapreduce:jar:nohive:1.4.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.6.4:compile ``` **AFTER** ``` $ mvn dependency:tree -Phadoop-2.7 -Dincludes=org.apache.hadoop:hadoop-mapreduce-client-core ... [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Catalyst 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-catalyst_2.11 --- [INFO] org.apache.spark:spark-catalyst_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.spark:spark-core_2.11:jar:2.4.0-SNAPSHOT:compile [INFO] \- org.apache.hadoop:hadoop-client:jar:2.7.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.7.3:compile [INFO] [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project SQL 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-sql_2.11 --- [INFO] org.apache.spark:spark-sql_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.spark:spark-core_2.11:jar:2.4.0-SNAPSHOT:compile [INFO] \- org.apache.hadoop:hadoop-client:jar:2.7.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.7.3:compile ``` ## How was this patch tested? 1. Pass the Jenkins with `dev/test-dependencies.sh` with the existing dependencies. 2. Manually do the following and see the change. ``` mvn dependency:tree -Phadoop-2.7 -Dincludes=org.apache.hadoop:hadoop-mapreduce-client-core ``` Author: Dongjoon Hyun Closes #20704 from dongjoon-hyun/SPARK-23551. --- pom.xml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pom.xml b/pom.xml index b8396166f6b1b..0a711f287a53f 100644 --- a/pom.xml +++ b/pom.xml @@ -1753,6 +1753,10 @@ org.apache.hadoop hadoop-common + + org.apache.hadoop + hadoop-mapreduce-client-core + org.apache.orc orc-core From 119f6a0e4729aa952e811d2047790a32ee90bf69 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 1 Mar 2018 21:04:01 -0800 Subject: [PATCH 412/774] [SPARK-22883][ML][TEST] Streaming tests for spark.ml.feature, from A to H ## What changes were proposed in this pull request? Adds structured streaming tests using testTransformer for these suites: * BinarizerSuite * BucketedRandomProjectionLSHSuite * BucketizerSuite * ChiSqSelectorSuite * CountVectorizerSuite * DCTSuite.scala * ElementwiseProductSuite * FeatureHasherSuite * HashingTFSuite ## How was this patch tested? It tests itself because it is a bunch of tests! Author: Joseph K. Bradley Closes #20111 from jkbradley/SPARK-22883-streaming-featureAM. --- .../spark/ml/feature/BinarizerSuite.scala | 8 ++-- .../BucketedRandomProjectionLSHSuite.scala | 26 ++++++++--- .../spark/ml/feature/BucketizerSuite.scala | 11 +++-- .../spark/ml/feature/ChiSqSelectorSuite.scala | 36 +++++++-------- .../ml/feature/CountVectorizerSuite.scala | 23 +++++----- .../apache/spark/ml/feature/DCTSuite.scala | 14 +++--- .../ml/feature/ElementwiseProductSuite.scala | 30 ++++++++++--- .../spark/ml/feature/FeatureHasherSuite.scala | 45 +++++++++---------- .../spark/ml/feature/HashingTFSuite.scala | 34 ++++++++------ 9 files changed, 126 insertions(+), 101 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 4455d35210878..05d4a6ee2dabf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BinarizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -47,7 +45,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau .setInputCol("feature") .setOutputCol("binarized_feature") - binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, binarizer, "binarized_feature", "expected") { case Row(x: Double, y: Double) => assert(x === y, "The feature value is not correct after binarization.") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index 7175c721bff36..ed9a39d8d1512 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -20,16 +20,15 @@ package org.apache.spark.ml.feature import breeze.numerics.{cos, sin} import breeze.numerics.constants.Pi -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{Dataset, Row} -class BucketedRandomProjectionLSHSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ @transient var dataset: Dataset[_] = _ @@ -98,6 +97,21 @@ class BucketedRandomProjectionLSHSuite MLTestingUtils.checkCopyAndUids(brp, brpModel) } + test("BucketedRandomProjectionLSH: streaming transform") { + val brp = new BucketedRandomProjectionLSH() + .setNumHashTables(2) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(1.0) + .setSeed(12345) + val brpModel = brp.fit(dataset) + + testTransformer[Tuple1[Vector]](dataset.toDF(), brpModel, "values") { + case Row(values: Seq[_]) => + assert(values.length === brp.getNumHashTables) + } + } + test("BucketedRandomProjectionLSH: test of LSH property") { // Project from 2 dimensional Euclidean Space to 1 dimensions val brp = new BucketedRandomProjectionLSH() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 41cf72fe3470a..9ea15e1918532 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -23,14 +23,13 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BucketizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -50,7 +49,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("result") .setSplits(splits) - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -84,7 +83,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("result") .setSplits(splits) - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -103,7 +102,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setSplits(splits) bucketizer.setHandleInvalid("keep") - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index c83909c4498f2..c843df9f33e3e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ @transient var dataset: Dataset[_] = _ @@ -119,32 +118,32 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext test("Test Chi-Square selector: numTopFeatures") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1) - val model = ChiSqSelectorSuite.testSelector(selector, dataset) + val model = testSelector(selector, dataset) MLTestingUtils.checkCopyAndUids(selector, model) } test("Test Chi-Square selector: percentile") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.17) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fpr") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fpr").setFpr(0.02) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fdr") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fdr").setFdr(0.12) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fwe") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fwe").setFwe(0.12) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("read/write") { @@ -163,18 +162,19 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext assert(expected.selectedFeatures === actual.selectedFeatures) } } -} -object ChiSqSelectorSuite { - - private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): ChiSqSelectorModel = { - val selectorModel = selector.fit(dataset) - selectorModel.transform(dataset).select("filtered", "topFeature").collect() - .foreach { case Row(vec1: Vector, vec2: Vector) => + private def testSelector(selector: ChiSqSelector, data: Dataset[_]): ChiSqSelectorModel = { + val selectorModel = selector.fit(data) + testTransformer[(Double, Vector, Vector)](data.toDF(), selectorModel, + "filtered", "topFeature") { + case Row(vec1: Vector, vec2: Vector) => assert(vec1 ~== vec2 absTol 1e-1) - } + } selectorModel } +} + +object ChiSqSelectorSuite { /** * Mapping from all Params to valid settings which differ from the defaults. diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 1784c07ca23e3..61217669d9277 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -16,16 +16,13 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class CountVectorizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -50,7 +47,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) .setInputCol("words") .setOutputCol("features") - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -72,7 +69,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext MLTestingUtils.checkCopyAndUids(cv, cvm) assert(cvm.vocabulary.toSet === Set("a", "b", "c", "d", "e")) - cvm.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvm, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -100,7 +97,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .fit(df) assert(cvModel2.vocabulary === Array("a", "b")) - cvModel2.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvModel2, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -113,7 +110,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .fit(df) assert(cvModel3.vocabulary === Array("a", "b")) - cvModel3.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvModel3, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -219,7 +216,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setMinTF(3) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -238,7 +235,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setMinTF(0.3) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -258,7 +255,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setOutputCol("features") .setBinary(true) .fit(df) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -268,7 +265,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setBinary(true) - cv2.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv2, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 8dd3dd75e1be5..6734336aac39c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -21,16 +21,14 @@ import scala.beans.BeanInfo import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.Row @BeanInfo case class DCTTestData(vec: Vector, wantedVec: Vector) -class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class DCTSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -72,11 +70,9 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead .setOutputCol("resultVec") .setInverse(inverse) - transformer.transform(dataset) - .select("resultVec", "wantedVec") - .collect() - .foreach { case Row(resultVec: Vector, wantedVec: Vector) => - assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) + testTransformer[(Vector, Vector)](dataset, transformer, "resultVec", "wantedVec") { + case Row(resultVec: Vector, wantedVec: Vector) => + assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala index a4cca27be7815..3a8d0762e2ab7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala @@ -17,13 +17,31 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.sql.Row -class ElementwiseProductSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class ElementwiseProductSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ + + test("streaming transform") { + val scalingVec = Vectors.dense(0.1, 10.0) + val data = Seq( + (Vectors.dense(0.1, 1.0), Vectors.dense(0.01, 10.0)), + (Vectors.dense(0.0, -1.1), Vectors.dense(0.0, -11.0)) + ) + val df = spark.createDataFrame(data).toDF("features", "expected") + val ep = new ElementwiseProduct() + .setInputCol("features") + .setOutputCol("actual") + .setScalingVec(scalingVec) + testTransformer[(Vector, Vector)](df, ep, "actual", "expected") { + case Row(actual: Vector, expected: Vector) => + assert(actual ~== expected relTol 1e-14) + } + } test("read/write") { val ep = new ElementwiseProduct() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala index 7bc1825b69c43..d799ba6011fa8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala @@ -17,27 +17,24 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class FeatureHasherSuite extends SparkFunSuite - with MLlibTestSparkContext - with DefaultReadWriteTest { +class FeatureHasherSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ import FeatureHasherSuite.murmur3FeatureIdx - implicit private val vectorEncoder = ExpressionEncoder[Vector]() + implicit private val vectorEncoder: ExpressionEncoder[Vector] = ExpressionEncoder[Vector]() test("params") { ParamsSuite.checkParams(new FeatureHasher) @@ -52,31 +49,31 @@ class FeatureHasherSuite extends SparkFunSuite } test("feature hashing") { + val numFeatures = 100 + // Assume perfect hash on field names in computing expected results + def idx: Any => Int = murmur3FeatureIdx(numFeatures) + val df = Seq( - (2.0, true, "1", "foo"), - (3.0, false, "2", "bar") - ).toDF("real", "bool", "stringNum", "string") + (2.0, true, "1", "foo", + Vectors.sparse(numFeatures, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0), + (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0)))), + (3.0, false, "2", "bar", + Vectors.sparse(numFeatures, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0), + (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0)))) + ).toDF("real", "bool", "stringNum", "string", "expected") - val n = 100 val hasher = new FeatureHasher() .setInputCols("real", "bool", "stringNum", "string") .setOutputCol("features") - .setNumFeatures(n) + .setNumFeatures(numFeatures) val output = hasher.transform(df) val attrGroup = AttributeGroup.fromStructField(output.schema("features")) - assert(attrGroup.numAttributes === Some(n)) + assert(attrGroup.numAttributes === Some(numFeatures)) - val features = output.select("features").as[Vector].collect() - // Assume perfect hash on field names - def idx: Any => Int = murmur3FeatureIdx(n) - // check expected indices - val expected = Seq( - Vectors.sparse(n, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0), - (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0))), - Vectors.sparse(n, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0), - (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0))) - ) - assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 }) + testTransformer[(Double, Boolean, String, String, Vector)](df, hasher, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14 ) + } } test("setting explicit numerical columns to treat as categorical") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index a46272fdce1fb..c5183ecfef7d7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -17,17 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{HashingTF => MLlibHashingTF} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row import org.apache.spark.util.Utils -class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class HashingTFSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ import HashingTFSuite.murmur3FeatureIdx @@ -37,21 +36,28 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } test("hashingTF") { - val df = Seq((0, "a a b b c d".split(" ").toSeq)).toDF("id", "words") - val n = 100 + val numFeatures = 100 + // Assume perfect hash when computing expected features. + def idx: Any => Int = murmur3FeatureIdx(numFeatures) + val data = Seq( + ("a a b b c d".split(" ").toSeq, + Vectors.sparse(numFeatures, + Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))) + ) + + val df = data.toDF("words", "expected") val hashingTF = new HashingTF() .setInputCol("words") .setOutputCol("features") - .setNumFeatures(n) + .setNumFeatures(numFeatures) val output = hashingTF.transform(df) val attrGroup = AttributeGroup.fromStructField(output.schema("features")) - require(attrGroup.numAttributes === Some(n)) - val features = output.select("features").first().getAs[Vector](0) - // Assume perfect hash on "a", "b", "c", and "d". - def idx: Any => Int = murmur3FeatureIdx(n) - val expected = Vectors.sparse(n, - Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) - assert(features ~== expected absTol 1e-14) + require(attrGroup.numAttributes === Some(numFeatures)) + + testTransformer[(Seq[String], Vector)](df, hashingTF, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } } test("applying binary term freqs") { From 0b6ceadeb563205cbd6bd03bc88e608086273b5b Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Fri, 2 Mar 2018 09:23:39 -0800 Subject: [PATCH 413/774] [SPARKR][DOC] fix link in vignettes ## What changes were proposed in this pull request? Fix doc link that was changed in 2.3 shivaram Author: Felix Cheung Closes #20711 from felixcheung/rvigmean. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index feca617c2554c..d4713de7806a1 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -46,7 +46,7 @@ Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt, sep = " ")) ## Overview -SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](http://spark.apache.org/mllib/). +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](https://spark.apache.org/mllib/). ## Getting Started @@ -132,7 +132,7 @@ sparkR.session.stop() Different from many other R packages, to use SparkR, you need an additional installation of Apache Spark. The Spark installation will be used to run a backend process that will compile and execute SparkR programs. -After installing the SparkR package, you can call `sparkR.session` as explained in the previous section to start and it will check for the Spark installation. If you are working with SparkR from an interactive shell (eg. R, RStudio) then Spark is downloaded and cached automatically if it is not found. Alternatively, we provide an easy-to-use function `install.spark` for running this manually. If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). +After installing the SparkR package, you can call `sparkR.session` as explained in the previous section to start and it will check for the Spark installation. If you are working with SparkR from an interactive shell (eg. R, RStudio) then Spark is downloaded and cached automatically if it is not found. Alternatively, we provide an easy-to-use function `install.spark` for running this manually. If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](https://spark.apache.org/downloads.html). ```{r, eval=FALSE} install.spark() @@ -147,7 +147,7 @@ sparkR.session(sparkHome = "/HOME/spark") ### Spark Session {#SetupSparkSession} -In addition to `sparkHome`, many other options can be specified in `sparkR.session`. For a complete list, see [Starting up: SparkSession](http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession) and [SparkR API doc](http://spark.apache.org/docs/latest/api/R/sparkR.session.html). +In addition to `sparkHome`, many other options can be specified in `sparkR.session`. For a complete list, see [Starting up: SparkSession](https://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession) and [SparkR API doc](https://spark.apache.org/docs/latest/api/R/sparkR.session.html). In particular, the following Spark driver properties can be set in `sparkConfig`. @@ -169,7 +169,7 @@ sparkR.session(spark.sql.warehouse.dir = spark_warehouse_path) #### Cluster Mode -SparkR can connect to remote Spark clusters. [Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) is a good introduction to different Spark cluster modes. +SparkR can connect to remote Spark clusters. [Cluster Mode Overview](https://spark.apache.org/docs/latest/cluster-overview.html) is a good introduction to different Spark cluster modes. When connecting SparkR to a remote Spark cluster, make sure that the Spark version and Hadoop version on the machine match the corresponding versions on the cluster. Current SparkR package is compatible with ```{r, echo=FALSE, tidy = TRUE} @@ -177,7 +177,7 @@ paste("Spark", packageVersion("SparkR")) ``` It should be used both on the local computer and on the remote cluster. -To connect, pass the URL of the master node to `sparkR.session`. A complete list can be seen in [Spark Master URLs](http://spark.apache.org/docs/latest/submitting-applications.html#master-urls). +To connect, pass the URL of the master node to `sparkR.session`. A complete list can be seen in [Spark Master URLs](https://spark.apache.org/docs/latest/submitting-applications.html#master-urls). For example, to connect to a local standalone Spark master, we can call ```{r, eval=FALSE} @@ -317,7 +317,7 @@ A common flow of grouping and aggregation is 2. Feed the `GroupedData` object to `agg` or `summarize` functions, with some provided aggregation functions to compute a number within each group. -A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for `mean`](http://spark.apache.org/docs/latest/api/R/mean.html) and other `agg_funcs` linked there. +A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for aggregate functions](https://spark.apache.org/docs/latest/api/R/column_aggregate_functions.html) linked there. For example we can compute a histogram of the number of cylinders in the `mtcars` dataset as shown below. @@ -935,7 +935,7 @@ perplexity #### Alternating Least Squares -`spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). +`spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](https://dl.acm.org/citation.cfm?id=1608614). There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, and `nonnegative`. For a complete list, refer to the help file. @@ -1171,11 +1171,11 @@ env | map ## References -* [Spark Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) +* [Spark Cluster Mode Overview](https://spark.apache.org/docs/latest/cluster-overview.html) -* [Submitting Spark Applications](http://spark.apache.org/docs/latest/submitting-applications.html) +* [Submitting Spark Applications](https://spark.apache.org/docs/latest/submitting-applications.html) -* [Machine Learning Library Guide (MLlib)](http://spark.apache.org/docs/latest/ml-guide.html) +* [Machine Learning Library Guide (MLlib)](https://spark.apache.org/docs/latest/ml-guide.html) * [SparkR: Scaling R Programs with Spark](https://people.csail.mit.edu/matei/papers/2016/sigmod_sparkr.pdf), Shivaram Venkataraman, Zongheng Yang, Davies Liu, Eric Liang, Hossein Falaki, Xiangrui Meng, Reynold Xin, Ali Ghodsi, Michael Franklin, Ion Stoica, and Matei Zaharia. SIGMOD 2016. June 2016. From 3a4d15e5d2b9ddbaeb2a6ab2d86d059ada6407b2 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Fri, 2 Mar 2018 10:38:50 -0800 Subject: [PATCH 414/774] [SPARK-23518][SQL] Avoid metastore access when the users only want to read and write data frames ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/18944 added one patch, which allowed a spark session to be created when the hive metastore server is down. However, it did not allow running any commands with the spark session. This brings troubles to the user who only wants to read / write data frames without metastore setup. ## How was this patch tested? Added some unit tests to read and write data frames based on the original HiveMetastoreLazyInitializationSuite. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20681 from liufengdb/completely-lazy. --- R/pkg/tests/fulltests/test_sparkSQL.R | 2 ++ .../sql/catalyst/catalog/SessionCatalog.scala | 11 +++++++---- .../sql/internal/BaseSessionStateBuilder.scala | 4 ++-- .../HiveMetastoreLazyInitializationSuite.scala | 14 ++++++++++++++ .../apache/spark/sql/hive/HiveSessionCatalog.scala | 8 ++++---- .../spark/sql/hive/HiveSessionStateBuilder.scala | 10 +++++----- 6 files changed, 34 insertions(+), 15 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 5197838eaac66..bd0a0dcd0674c 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -67,6 +67,8 @@ sparkSession <- if (windows_with_hadoop()) { sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) } sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) +# materialize the catalog implementation +listTables() mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 4b119c75260a7..64e7ca11270b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -54,8 +54,8 @@ object SessionCatalog { * This class must be thread-safe. */ class SessionCatalog( - val externalCatalog: ExternalCatalog, - globalTempViewManager: GlobalTempViewManager, + externalCatalogBuilder: () => ExternalCatalog, + globalTempViewManagerBuilder: () => GlobalTempViewManager, functionRegistry: FunctionRegistry, conf: SQLConf, hadoopConf: Configuration, @@ -70,8 +70,8 @@ class SessionCatalog( functionRegistry: FunctionRegistry, conf: SQLConf) { this( - externalCatalog, - new GlobalTempViewManager("global_temp"), + () => externalCatalog, + () => new GlobalTempViewManager("global_temp"), functionRegistry, conf, new Configuration(), @@ -87,6 +87,9 @@ class SessionCatalog( new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) } + lazy val externalCatalog = externalCatalogBuilder() + lazy val globalTempViewManager = globalTempViewManagerBuilder() + /** List of temporary views, mapping from table name to their logical plan. */ @GuardedBy("this") protected val tempViews = new mutable.HashMap[String, LogicalPlan] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 007f8760edf82..3a0db7e16c23a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -130,8 +130,8 @@ abstract class BaseSessionStateBuilder( */ protected lazy val catalog: SessionCatalog = { val catalog = new SessionCatalog( - session.sharedState.externalCatalog, - session.sharedState.globalTempViewManager, + () => session.sharedState.externalCatalog, + () => session.sharedState.globalTempViewManager, functionRegistry, conf, SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala index 3f135cc864983..277df548aefd0 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala @@ -38,6 +38,20 @@ class HiveMetastoreLazyInitializationSuite extends SparkFunSuite { // We should be able to run Spark jobs without Hive client. assert(spark.sparkContext.range(0, 1).count() === 1) + // We should be able to use Spark SQL if no table references. + assert(spark.sql("select 1 + 1").count() === 1) + assert(spark.range(0, 1).count() === 1) + + // We should be able to use fs + val path = Utils.createTempDir() + path.delete() + try { + spark.range(0, 1).write.parquet(path.getAbsolutePath) + assert(spark.read.parquet(path.getAbsolutePath).count() === 1) + } finally { + Utils.deleteRecursively(path) + } + // Make sure that we are not using the local derby metastore. val exceptionString = Utils.exceptionString(intercept[AnalysisException] { spark.sql("show tables") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 1f11adbd4f62e..e5aff3b99d0b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -39,8 +39,8 @@ import org.apache.spark.sql.types.{DecimalType, DoubleType} private[sql] class HiveSessionCatalog( - externalCatalog: HiveExternalCatalog, - globalTempViewManager: GlobalTempViewManager, + externalCatalogBuilder: () => HiveExternalCatalog, + globalTempViewManagerBuilder: () => GlobalTempViewManager, val metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, conf: SQLConf, @@ -48,8 +48,8 @@ private[sql] class HiveSessionCatalog( parser: ParserInterface, functionResourceLoader: FunctionResourceLoader) extends SessionCatalog( - externalCatalog, - globalTempViewManager, + externalCatalogBuilder, + globalTempViewManagerBuilder, functionRegistry, conf, hadoopConf, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 12c74368dd184..40b9bb51ca9a0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -42,8 +42,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - val client: HiveClient = externalCatalog.client - new HiveSessionResourceLoader(session, client) + new HiveSessionResourceLoader(session, () => externalCatalog.client) } /** @@ -51,8 +50,8 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session */ override protected lazy val catalog: HiveSessionCatalog = { val catalog = new HiveSessionCatalog( - externalCatalog, - session.sharedState.globalTempViewManager, + () => externalCatalog, + () => session.sharedState.globalTempViewManager, new HiveMetastoreCatalog(session), functionRegistry, conf, @@ -105,8 +104,9 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session class HiveSessionResourceLoader( session: SparkSession, - client: HiveClient) + clientBuilder: () => HiveClient) extends SessionResourceLoader(session) { + private lazy val client = clientBuilder() override def addJar(path: String): Unit = { client.addJar(path) super.addJar(path) From 707e6506d0dbdb598a6c99d666f3c66746113b67 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 2 Mar 2018 12:27:42 -0800 Subject: [PATCH 415/774] [SPARK-23097][SQL][SS] Migrate text socket source to V2 ## What changes were proposed in this pull request? This PR moves structured streaming text socket source to V2. Questions: do we need to remove old "socket" source? ## How was this patch tested? Unit test and manual verification. Author: jerryshao Closes #20382 from jerryshao/SPARK-23097. --- ...pache.spark.sql.sources.DataSourceRegister | 2 +- .../execution/datasources/DataSource.scala | 5 +- .../streaming/{ => sources}/socket.scala | 178 ++++++---- .../sql/streaming/DataStreamReader.scala | 21 +- .../streaming/TextSocketStreamSuite.scala | 231 ------------- .../sources/TextSocketStreamSuite.scala | 306 ++++++++++++++++++ 6 files changed, 434 insertions(+), 309 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{ => sources}/socket.scala (51%) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 0259c774bbf4a..1fe9c093af99f 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,6 +5,6 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.TextSocketSourceProvider org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 6e1b5727e3fd5..35fcff69b14d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -563,6 +564,7 @@ object DataSource extends Logging { val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat" val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName + val socket = classOf[TextSocketSourceProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -583,7 +585,8 @@ object DataSource extends Logging { "org.apache.spark.sql.execution.datasources.orc" -> nativeOrc, "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, - "com.databricks.spark.csv" -> csv + "com.databricks.spark.csv" -> csv, + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala similarity index 51% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 0b22cbc46e6bf..5aae46b463398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -15,27 +15,29 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.sql.Timestamp import java.text.SimpleDateFormat -import java.util.{Calendar, Locale} +import java.util.{Calendar, List => JList, Locale, Optional} import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.execution.streaming.LongOffset +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} -import org.apache.spark.unsafe.types.UTF8String - -object TextSocketSource { +object TextSocketMicroBatchReader { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: StructField("timestamp", TimestampType) :: Nil) @@ -43,12 +45,17 @@ object TextSocketSource { } /** - * A source that reads text lines through a TCP socket, designed only for tutorials and debugging. - * This source will *not* work in production applications due to multiple reasons, including no - * support for fault recovery and keeping all of the text read in memory forever. + * A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and + * debugging. This MicroBatchReader will *not* work in production applications due to multiple + * reasons, including no support for fault recovery. */ -class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlContext: SQLContext) - extends Source with Logging { +class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { + + private var startOffset: Offset = _ + private var endOffset: Offset = _ + + private val host: String = options.get("host").get() + private val port: Int = options.get("port").get().toInt @GuardedBy("this") private var socket: Socket = null @@ -61,16 +68,21 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo * Stored in a ListBuffer to facilitate removing committed batches. */ @GuardedBy("this") - protected val batches = new ListBuffer[(String, Timestamp)] + private val batches = new ListBuffer[(String, Timestamp)] @GuardedBy("this") - protected var currentOffset: LongOffset = new LongOffset(-1) + private var currentOffset: LongOffset = LongOffset(-1L) @GuardedBy("this") - protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) + private var lastOffsetCommitted: LongOffset = LongOffset(-1L) initialize() + /** This method is only used for unit test */ + private[sources] def getCurrentOffset(): LongOffset = synchronized { + currentOffset.copy() + } + private def initialize(): Unit = synchronized { socket = new Socket(host, port) val reader = new BufferedReader(new InputStreamReader(socket.getInputStream)) @@ -86,12 +98,12 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo logWarning(s"Stream closed by $host:$port") return } - TextSocketSource.this.synchronized { + TextSocketMicroBatchReader.this.synchronized { val newData = (line, Timestamp.valueOf( - TextSocketSource.DATE_FORMAT.format(Calendar.getInstance().getTime())) - ) - currentOffset = currentOffset + 1 + TextSocketMicroBatchReader.DATE_FORMAT.format(Calendar.getInstance().getTime())) + ) + currentOffset += 1 batches.append(newData) } } @@ -103,23 +115,37 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo readThread.start() } - /** Returns the schema of the data from this source */ - override def schema: StructType = if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP - else TextSocketSource.SCHEMA_REGULAR + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized { + startOffset = start.orElse(LongOffset(-1L)) + endOffset = end.orElse(currentOffset) + } + + override def getStartOffset(): Offset = { + Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set")) + } + + override def getEndOffset(): Offset = { + Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set")) + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } - override def getOffset: Option[Offset] = synchronized { - if (currentOffset.offset == -1) { - None + override def readSchema(): StructType = { + if (options.getBoolean("includeTimestamp", false)) { + TextSocketMicroBatchReader.SCHEMA_TIMESTAMP } else { - Some(currentOffset) + TextSocketMicroBatchReader.SCHEMA_REGULAR } } - /** Returns the data that is between the offsets (`start`, `end`]. */ - override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { - val startOrdinal = - start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1 - val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1 + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + assert(startOffset != null && endOffset != null, + "start offset and end offset should already be set before create read tasks.") + + val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1 + val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1 // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { @@ -128,10 +154,34 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo batches.slice(sliceStart, sliceEnd) } - val rdd = sqlContext.sparkContext - .parallelize(rawList) - .map { case (v, ts) => InternalRow(UTF8String.fromString(v), ts.getTime) } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) + assert(SparkSession.getActiveSession.isDefined) + val spark = SparkSession.getActiveSession.get + val numPartitions = spark.sparkContext.defaultParallelism + + val slices = Array.fill(numPartitions)(new ListBuffer[(String, Timestamp)]) + rawList.zipWithIndex.foreach { case (r, idx) => + slices(idx % numPartitions).append(r) + } + + (0 until numPartitions).map { i => + val slice = slices(i) + new DataReaderFactory[Row] { + override def createDataReader(): DataReader[Row] = new DataReader[Row] { + private var currentIdx = -1 + + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } + + override def get(): Row = { + Row(slice(currentIdx)._1, slice(currentIdx)._2) + } + + override def close(): Unit = {} + } + } + }.toList.asJava } override def commit(end: Offset): Unit = synchronized { @@ -164,54 +214,40 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo } } - override def toString: String = s"TextSocketSource[host: $host, port: $port]" + override def toString: String = s"TextSocket[host: $host, port: $port]" } -class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging { - private def parseIncludeTimestamp(params: Map[String, String]): Boolean = { - Try(params.getOrElse("includeTimestamp", "false").toBoolean) match { - case Success(bool) => bool - case Failure(_) => - throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") - } - } +class TextSocketSourceProvider extends DataSourceV2 + with MicroBatchReadSupport with DataSourceRegister with Logging { - /** Returns the name and schema of the source that can be used to continually read data. */ - override def sourceSchema( - sqlContext: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { + private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! " + "It does not support recovery.") - if (!parameters.contains("host")) { + if (!params.get("host").isPresent) { throw new AnalysisException("Set a host to read from with option(\"host\", ...).") } - if (!parameters.contains("port")) { + if (!params.get("port").isPresent) { throw new AnalysisException("Set a port to read from with option(\"port\", ...).") } - if (schema.nonEmpty) { - throw new AnalysisException("The socket source does not support a user-specified schema.") + Try { + params.get("includeTimestamp").orElse("false").toBoolean + } match { + case Success(_) => + case Failure(_) => + throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") } - - val sourceSchema = - if (parseIncludeTimestamp(parameters)) { - TextSocketSource.SCHEMA_TIMESTAMP - } else { - TextSocketSource.SCHEMA_REGULAR - } - ("textSocket", sourceSchema) } - override def createSource( - sqlContext: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - val host = parameters("host") - val port = parameters("port").toInt - new TextSocketSource(host, port, parseIncludeTimestamp(parameters), sqlContext) + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + checkParameters(options) + if (schema.isPresent) { + throw new AnalysisException("The socket source does not support a user-specified schema.") + } + + new TextSocketMicroBatchReader(options) } /** String that represents the format that this data source provider uses. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 61e22fac854f9..c393dcdfdd7e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -172,15 +173,25 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } ds match { case s: MicroBatchReadSupport => - val tempReader = s.createMicroBatchReader( - Optional.ofNullable(userSpecifiedSchema.orNull), - Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, - options) + var tempReader: MicroBatchReader = null + val schema = try { + tempReader = s.createMicroBatchReader( + Optional.ofNullable(userSpecifiedSchema.orNull), + Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) + tempReader.readSchema() + } finally { + // Stop tempReader to avoid side-effect thing + if (tempReader != null) { + tempReader.stop() + tempReader = null + } + } Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) + schema.toAttributes, v1Relation)(sparkSession)) case s: ContinuousReadSupport => val tempReader = s.createContinuousReader( Optional.ofNullable(userSpecifiedSchema.orNull), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala deleted file mode 100644 index ec11549073650..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.io.{IOException, OutputStreamWriter} -import java.net.ServerSocket -import java.sql.Timestamp -import java.util.concurrent.LinkedBlockingQueue - -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} - -class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { - import testImplicits._ - - override def afterEach() { - sqlContext.streams.active.foreach(_.stop()) - if (serverThread != null) { - serverThread.interrupt() - serverThread.join() - serverThread = null - } - if (source != null) { - source.stop() - source = null - } - } - - private var serverThread: ServerThread = null - private var source: Source = null - - test("basic usage") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString) - val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2 - assert(schema === StructType(StructField("value", StringType) :: Nil)) - - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val offset1 = source.getOffset.get - val batch1 = source.getBatch(None, offset1) - assert(batch1.as[String].collect().toSeq === Seq("hello")) - - serverThread.enqueue("world") - while (source.getOffset.get === offset1) { - Thread.sleep(10) - } - val offset2 = source.getOffset.get - val batch2 = source.getBatch(Some(offset1), offset2) - assert(batch2.as[String].collect().toSeq === Seq("world")) - - val both = source.getBatch(None, offset2) - assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world")) - } - - // Try stopping the source to make sure this does not block forever. - source.stop() - source = null - } - } - - test("timestamped usage") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString, - "includeTimestamp" -> "true") - val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2 - assert(schema === StructType(StructField("value", StringType) :: - StructField("timestamp", TimestampType) :: Nil)) - - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val offset1 = source.getOffset.get - val batch1 = source.getBatch(None, offset1) - val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq - assert(batch1Seq.map(_._1) === Seq("hello")) - val batch1Stamp = batch1Seq(0)._2 - - serverThread.enqueue("world") - while (source.getOffset.get === offset1) { - Thread.sleep(10) - } - val offset2 = source.getOffset.get - val batch2 = source.getBatch(Some(offset1), offset2) - val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq - assert(batch2Seq.map(_._1) === Seq("world")) - val batch2Stamp = batch2Seq(0)._2 - assert(!batch2Stamp.before(batch1Stamp)) - } - - // Try stopping the source to make sure this does not block forever. - source.stop() - source = null - } - } - - test("params not given") { - val provider = new TextSocketSourceProvider - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map()) - } - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost")) - } - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("port" -> "1234")) - } - } - - test("non-boolean includeTimestamp") { - val provider = new TextSocketSourceProvider - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost", - "port" -> "1234", "includeTimestamp" -> "fasle")) - } - } - - test("user-specified schema given") { - val provider = new TextSocketSourceProvider - val userSpecifiedSchema = StructType( - StructField("name", StringType) :: - StructField("area", StringType) :: Nil) - val exception = intercept[AnalysisException] { - provider.sourceSchema( - sqlContext, Some(userSpecifiedSchema), - "", - Map("host" -> "localhost", "port" -> "1234")) - } - assert(exception.getMessage.contains( - "socket source does not support a user-specified schema")) - } - - test("no server up") { - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> "0") - intercept[IOException] { - source = provider.createSource(sqlContext, "", None, "", parameters) - } - } - - test("input row metrics") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString) - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val batch = source.getBatch(None, source.getOffset.get).as[String] - batch.collect() - val numRowsMetric = - batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") - assert(numRowsMetric.nonEmpty) - assert(numRowsMetric.get.value === 1) - } - source.stop() - source = null - } - } - - private class ServerThread extends Thread with Logging { - private val serverSocket = new ServerSocket(0) - private val messageQueue = new LinkedBlockingQueue[String]() - - val port = serverSocket.getLocalPort - - override def run(): Unit = { - try { - val clientSocket = serverSocket.accept() - clientSocket.setTcpNoDelay(true) - val out = new OutputStreamWriter(clientSocket.getOutputStream) - while (true) { - val line = messageQueue.take() - out.write(line + "\n") - out.flush() - } - } catch { - case e: InterruptedException => - } finally { - serverSocket.close() - } - } - - def enqueue(line: String): Unit = { - messageQueue.put(line) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala new file mode 100644 index 0000000000000..a15a980bb92fd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io.IOException +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.ServerSocketChannel +import java.sql.Timestamp +import java.util.Optional +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.JavaConverters._ + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} + +class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { + + override def afterEach() { + sqlContext.streams.active.foreach(_.stop()) + if (serverThread != null) { + serverThread.interrupt() + serverThread.join() + serverThread = null + } + if (batchReader != null) { + batchReader.stop() + batchReader = null + } + } + + private var serverThread: ServerThread = null + private var batchReader: MicroBatchReader = null + + case class AddSocketData(data: String*) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + require( + query.nonEmpty, + "Cannot add data when there is no query for finding the active socket source") + + val sources = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: TextSocketMicroBatchReader, _) => source + } + if (sources.isEmpty) { + throw new Exception( + "Could not find socket source in the StreamExecution logical plan to add data to") + } else if (sources.size > 1) { + throw new Exception( + "Could not select the socket source in the StreamExecution logical plan as there" + + "are multiple socket sources:\n\t" + sources.mkString("\n\t")) + } + val socketSource = sources.head + + assert(serverThread != null && serverThread.port != 0) + val currOffset = socketSource.getCurrentOffset() + data.foreach(serverThread.enqueue) + + val newOffset = LongOffset(currOffset.offset + data.size) + (socketSource, newOffset) + } + + override def toString: String = s"AddSocketData(data = $data)" + } + + test("backward compatibility with old path") { + DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", + spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + assert(ds.isInstanceOf[TextSocketSourceProvider]) + case _ => + throw new IllegalStateException("Could not find socket source") + } + } + + test("basic usage") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + } + } + + test("timestamped usage") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val socket = spark + .readStream + .format("socket") + .options(Map( + "host" -> "localhost", + "port" -> serverThread.port.toString, + "includeTimestamp" -> "true")) + .load() + + assert(socket.schema === StructType(StructField("value", StringType) :: + StructField("timestamp", TimestampType) :: Nil)) + + var batch1Stamp: Timestamp = null + var batch2Stamp: Timestamp = null + + val curr = System.currentTimeMillis() + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswerRowsByFunc( + rows => { + assert(rows.size === 1) + assert(rows.head.getAs[String](0) === "hello") + batch1Stamp = rows.head.getAs[Timestamp](1) + Thread.sleep(10) + }, + true), + AddSocketData("world"), + CheckAnswerRowsByFunc( + rows => { + assert(rows.size === 1) + assert(rows.head.getAs[String](0) === "world") + batch2Stamp = rows.head.getAs[Timestamp](1) + }, + true), + StopStream + ) + + // Timestamp for rate stream is round to second which leads to milliseconds lost, that will + // make batch1stamp smaller than current timestamp if both of them are in the same second. + // Comparing by second to make sure the correct behavior. + assert(batch1Stamp.getTime >= curr / 1000 * 1000) + assert(!batch2Stamp.before(batch1Stamp)) + } + } + + test("params not given") { + val provider = new TextSocketSourceProvider + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map.empty[String, String].asJava)) + } + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map("host" -> "localhost").asJava)) + } + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map("port" -> "1234").asJava)) + } + } + + test("non-boolean includeTimestamp") { + val provider = new TextSocketSourceProvider + val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") + intercept[AnalysisException] { + val a = new DataSourceOptions(params.asJava) + provider.createMicroBatchReader(Optional.empty(), "", a) + } + } + + test("user-specified schema given") { + val provider = new TextSocketSourceProvider + val userSpecifiedSchema = StructType( + StructField("name", StringType) :: + StructField("area", StringType) :: Nil) + val params = Map("host" -> "localhost", "port" -> "1234") + val exception = intercept[AnalysisException] { + provider.createMicroBatchReader( + Optional.of(userSpecifiedSchema), "", new DataSourceOptions(params.asJava)) + } + assert(exception.getMessage.contains( + "socket source does not support a user-specified schema")) + } + + test("no server up") { + val provider = new TextSocketSourceProvider + val parameters = Map("host" -> "localhost", "port" -> "0") + intercept[IOException] { + batchReader = provider.createMicroBatchReader( + Optional.empty(), "", new DataSourceOptions(parameters.asJava)) + } + } + + test("input row metrics") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AssertOnQuery { q => + val numRowMetric = + q.lastExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") + numRowMetric.nonEmpty && numRowMetric.get.value == 1 + }, + StopStream + ) + } + } + + private class ServerThread extends Thread with Logging { + private val serverSocketChannel = ServerSocketChannel.open() + serverSocketChannel.bind(new InetSocketAddress(0)) + private val messageQueue = new LinkedBlockingQueue[String]() + + val port = serverSocketChannel.socket().getLocalPort + + override def run(): Unit = { + try { + while (true) { + val clientSocketChannel = serverSocketChannel.accept() + clientSocketChannel.configureBlocking(false) + clientSocketChannel.socket().setTcpNoDelay(true) + + // Check whether remote client is closed but still send data to this closed socket. + // This happens in DataStreamReader where a source will be created to get the schema. + var remoteIsClosed = false + var cnt = 0 + while (cnt < 3 && !remoteIsClosed) { + if (clientSocketChannel.read(ByteBuffer.allocate(1)) != -1) { + cnt += 1 + Thread.sleep(100) + } else { + remoteIsClosed = true + } + } + + if (remoteIsClosed) { + logInfo(s"remote client ${clientSocketChannel.socket()} is closed") + } else { + while (true) { + val line = messageQueue.take() + "\n" + clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) + } + } + } + } catch { + case e: InterruptedException => + } finally { + serverSocketChannel.close() + } + } + + def enqueue(line: String): Unit = { + messageQueue.put(line) + } + } +} From 487377e693af65b2ff3d6b874ca7326c1ff0076c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 2 Mar 2018 14:30:37 -0800 Subject: [PATCH 416/774] [SPARK-23570][SQL] Add Spark 2.3.0 in HiveExternalCatalogVersionsSuite ## What changes were proposed in this pull request? Add Spark 2.3.0 in HiveExternalCatalogVersionsSuite since Spark 2.3.0 is released for ensuring backward compatibility. ## How was this patch tested? N/A Author: gatorsmile Closes #20720 from gatorsmile/add2.3. --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index c13a750dbb270..6ca58e68d31eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1", "2.3.0") protected var spark: SparkSession = _ From 9e26473c0f29ee4281519104ac5e182a3bd4bf23 Mon Sep 17 00:00:00 2001 From: Alessandro Solimando <18898964+asolimando@users.noreply.github.com> Date: Fri, 2 Mar 2018 16:24:29 -0800 Subject: [PATCH 417/774] [SPARK-3159][ML] Add decision tree pruning ## What changes were proposed in this pull request? Added subtree pruning in the translation from LearningNode to Node: a learning node having a single prediction value for all the leaves in the subtree rooted at it is translated into a LeafNode, instead of a (redundant) InternalNode ## How was this patch tested? Added two unit tests under "mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala": - test("SPARK-3159 tree model redundancy - classification") - test("SPARK-3159 tree model redundancy - regression") 4 existing unit tests relying on the tree structure (existence of a specific redundant subtree) had to be adapted as the tested components in the output tree are now pruned (fixed by adding an extra _prune_ parameter which can be used to disable pruning for testing) Author: Alessandro Solimando <18898964+asolimando@users.noreply.github.com> Closes #20632 from asolimando/master. --- .../scala/org/apache/spark/ml/tree/Node.scala | 22 ++-- .../spark/ml/tree/impl/RandomForest.scala | 10 +- .../DecisionTreeClassifierSuite.scala | 38 ------- .../ml/tree/impl/RandomForestSuite.scala | 100 ++++++++++++++++-- .../spark/mllib/tree/DecisionTreeSuite.scala | 10 +- 5 files changed, 115 insertions(+), 65 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 07e98a142b10e..d30be452a436e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -19,8 +19,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{ImpurityStats, - InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} +import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} /** * Decision tree node interface. @@ -266,15 +265,23 @@ private[tree] class LearningNode( var isLeaf: Boolean, var stats: ImpurityStats) extends Serializable { + def toNode: Node = toNode(prune = true) + /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ - def toNode: Node = { - if (leftChild.nonEmpty) { - assert(rightChild.nonEmpty && split.nonEmpty && stats != null, + def toNode(prune: Boolean = true): Node = { + + if (!leftChild.isEmpty || !rightChild.isEmpty) { + assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") - new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, - leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) + (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match { + case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction => + new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + case (l, r) => + new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + l, r, split.get, stats.impurityCalculator) + } } else { if (stats.valid) { new LeafNode(stats.impurityCalculator.predict, stats.impurity, @@ -283,7 +290,6 @@ private[tree] class LearningNode( // Here we want to keep same behavior with the old mllib.DecisionTreeModel new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) } - } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index acfc6399c553b..8e514f11e78ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -92,6 +92,7 @@ private[spark] object RandomForest extends Logging { featureSubsetStrategy: String, seed: Long, instr: Option[Instrumentation[_]], + prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { val timer = new TimeTracker() @@ -223,22 +224,23 @@ private[spark] object RandomForest extends Logging { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures, + new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures, strategy.getNumClasses) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode, numFeatures, + new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures, strategy.getNumClasses) } } else { - topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures)) + topNodes.map(rootNode => + new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 98c879ece62d6..38b265d62611b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -280,44 +280,6 @@ class DecisionTreeClassifierSuite dt.fit(df) } - test("Use soft prediction for binary classification with ordered categorical features") { - // The following dataset is set up such that the best split is {1} vs. {0, 2}. - // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. - val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(1.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(1.0, Vectors.dense(2.0))) - val data = sc.parallelize(arr) - val df = TreeTests.setMetadata(data, Map(0 -> 3), 2) - - // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. - val dt = new DecisionTreeClassifier() - .setImpurity("gini") - .setMaxDepth(1) - .setMaxBins(3) - val model = dt.fit(df) - model.rootNode match { - case n: InternalNode => - n.split match { - case s: CategoricalSplit => - assert(s.leftCategories === Array(1.0)) - case other => - fail(s"All splits should be categorical, but got ${other.getClass.getName}: $other.") - } - case other => - fail(s"Root node should be an internal node, but got ${other.getClass.getName}: $other.") - } - } - test("Feature importance with toy data") { val dt = new DecisionTreeClassifier() .setImpurity("gini") diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index dbe2ea931fb9c..5f0d26eb5c058 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.tree.impl +import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.SparkFunSuite @@ -38,6 +39,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestSuite.mapToVec + private val seed = 42 + ///////////////////////////////////////////////////////////////////////////// // Tests for split calculation ///////////////////////////////////////////////////////////////////////////// @@ -320,10 +323,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.isLeaf === false) assert(topNode.stats === null) - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) + val nodesForGroup = Map(0 -> Array(topNode)) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new RandomForest.NodeIndexInfo(0, None) + )) val nodeStack = new mutable.ArrayStack[(Int, LearningNode)] RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) @@ -362,10 +365,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.isLeaf === false) assert(topNode.stats === null) - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) + val nodesForGroup = Map(0 -> Array(topNode)) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new RandomForest.NodeIndexInfo(0, None) + )) val nodeStack = new mutable.ArrayStack[(Int, LearningNode)] RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) @@ -407,7 +410,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 42, instr = None).head + seed = 42, instr = None, prune = false).head + model.rootNode match { case n: InternalNode => n.split match { case s: CategoricalSplit => @@ -631,13 +635,89 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } + + /////////////////////////////////////////////////////////////////////////////// + // Tests for pruning of redundant subtrees (generated by a split improving the + // impurity measure, but always leading to the same prediction). + /////////////////////////////////////////////////////////////////////////////// + + test("SPARK-3159 tree model redundancy - classification") { + // The following dataset is set up such that splitting over feature_1 for points having + // feature_0 = 0 improves the impurity measure, despite the prediction will always be 0 + // in both branches. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) + ) + val rdd = sc.parallelize(arr) + + val numClasses = 2 + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, + numClasses = numClasses, maxBins = 32) + + val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None).head + + val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false).head + + assert(prunedTree.numNodes === 5) + assert(unprunedTree.numNodes === 7) + + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) + } + + test("SPARK-3159 tree model redundancy - regression") { + // The following dataset is set up such that splitting over feature_0 for points having + // feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5 + // in both branches. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.5, Vectors.dense(1.0, 1.0)) + ) + val rdd = sc.parallelize(arr) + + val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4, + numClasses = 0, maxBins = 32) + + val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None).head + + val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false).head + + assert(prunedTree.numNodes === 3) + assert(unprunedTree.numNodes === 5) + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) + } } private object RandomForestSuite { - def mapToVec(map: Map[Int, Double]): Vector = { val size = (map.keys.toSeq :+ 0).max + 1 val (indices, values) = map.toSeq.sortBy(_._1).unzip Vectors.sparse(size, indices.toArray, values.toArray) } + + @tailrec + private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = { + if (nodes.isEmpty) { + acc + } + else { + nodes.head match { + case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc) + case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 441d0f7614bf6..bc59f3f4125fb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -363,10 +363,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { // if a split does not satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), - LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), - LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 0.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, @@ -541,7 +541,7 @@ object DecisionTreeSuite extends SparkFunSuite { Array[LabeledPoint] = { val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { - if (i < 1000) { + if (i < 1001) { arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) } else if (i < 2000) { arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)) From dea381dfaa73e0cfb9a833b79c741b15ae274f64 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Sat, 3 Mar 2018 09:10:48 +0800 Subject: [PATCH 418/774] [SPARK-23514][FOLLOW-UP] Remove more places using sparkContext.hadoopConfiguration directly ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/20679 I missed a few places in SQL tests. For hygiene, they should also use the sessionState interface where possible. ## How was this patch tested? Modified existing tests. Author: Juliusz Sompolski Closes #20718 from juliuszsompolski/SPARK-23514-followup. --- .../scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala | 2 +- .../org/apache/spark/sql/execution/command/DDLSuite.scala | 2 +- .../spark/sql/execution/datasources/FileIndexSuite.scala | 2 +- .../execution/datasources/parquet/ParquetCommitterSuite.scala | 2 +- .../datasources/parquet/ParquetFileFormatSuite.scala | 4 ++-- .../datasources/parquet/ParquetInteroperabilitySuite.scala | 2 +- .../sql/execution/datasources/parquet/ParquetQuerySuite.scala | 2 +- .../org/apache/spark/sql/streaming/FileStreamSinkSuite.scala | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index b5d4c558f0d3e..73e3df3b6202e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -124,7 +124,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo Seq("1").toDF("a").write.format(format).save(new Path(basePath, "second").toString) val thirdPath = new Path(basePath, "third") - val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = thirdPath.getFileSystem(spark.sessionState.newHadoopConf()) Seq("2").toDF("a").write.format(format).save(thirdPath.toString) val files = fs.listStatus(thirdPath).filter(_.isFile).map(_.getPath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index b800e6ff5b0ce..db9023b7ec8b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1052,7 +1052,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part2 = Map("a" -> "2", "b" -> "6") val root = new Path(catalog.getTableMetadata(tableIdent).location) - val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = root.getFileSystem(spark.sessionState.newHadoopConf()) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "a.csv")) // file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index b4616826e40b3..18bb4bfe661ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -59,7 +59,7 @@ class FileIndexSuite extends SharedSQLContext { require(!unqualifiedDirPath.toString.contains("file:")) require(!unqualifiedFilePath.toString.contains("file:")) - val fs = unqualifiedDirPath.getFileSystem(sparkContext.hadoopConfiguration) + val fs = unqualifiedDirPath.getFileSystem(spark.sessionState.newHadoopConf()) val qualifiedFilePath = fs.makeQualified(new Path(file.getCanonicalPath)) require(qualifiedFilePath.toString.startsWith("file:")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala index caa4f6d70c6a9..f3ecc5ced689f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -101,7 +101,7 @@ class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils if (check) { result = Some(MarkingFileOutput.checkMarker( destPath, - spark.sparkContext.hadoopConfiguration)) + spark.sessionState.newHadoopConf())) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala index ccb34355f1bac..3a0867fd2b78b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala @@ -29,7 +29,7 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo test("read parquet footers in parallel") { def testReadFooters(ignoreCorruptFiles: Boolean): Unit = { withTempDir { dir => - val fs = FileSystem.get(sparkContext.hadoopConfiguration) + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) val basePath = dir.getCanonicalPath val path1 = new Path(basePath, "first") @@ -44,7 +44,7 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo Seq(fs.listStatus(path1), fs.listStatus(path2), fs.listStatus(path3)).flatten val footers = ParquetFileFormat.readParquetFootersInParallel( - sparkContext.hadoopConfiguration, fileStatuses, ignoreCorruptFiles) + spark.sessionState.newHadoopConf(), fileStatuses, ignoreCorruptFiles) assert(footers.size == 2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index e3edafa9c84e1..fbd83a0fa425a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -163,7 +163,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS // Just to be defensive in case anything ever changes in parquet, this test checks // the assumption on column stats, and also the end-to-end behavior. - val hadoopConf = sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() val fs = FileSystem.get(hadoopConf) val parts = fs.listStatus(new Path(tableDir.getAbsolutePath), new PathFilter { override def accept(path: Path): Boolean = !path.getName.startsWith("_") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 55b0f729be8ce..e1f094d0a7af3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -819,7 +819,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val path = dir.getCanonicalPath spark.range(3).write.parquet(path) - val fs = FileSystem.get(sparkContext.hadoopConfiguration) + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) val files = fs.listFiles(new Path(path), true) while (files.hasNext) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index ba48bc1ce0c4d..31e5527d7366a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -353,7 +353,7 @@ class FileStreamSinkSuite extends StreamTest { } test("FileStreamSink.ancestorIsMetadataDirectory()") { - val hadoopConf = spark.sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() def assertAncestorIsMetadataDirectory(path: String): Unit = assert(FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) def assertAncestorIsNotMetadataDirectory(path: String): Unit = From 486f99eefead4e664a30a861eca65cab8568e70b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 2 Mar 2018 18:14:13 -0800 Subject: [PATCH 419/774] [SPARK-23541][SS] Allow Kafka source to read data with greater parallelism than the number of topic-partitions ## What changes were proposed in this pull request? Currently, when the Kafka source reads from Kafka, it generates as many tasks as the number of partitions in the topic(s) to be read. In some case, it may be beneficial to read the data with greater parallelism, that is, with more number partitions/tasks. That means, offset ranges must be divided up into smaller ranges such the number of records in partition ~= total records in batch / desired partitions. This would also balance out any data skews between topic-partitions. In this patch, I have added a new option called `minPartitions`, which allows the user to specify the desired level of parallelism. ## How was this patch tested? New tests in KafkaMicroBatchV2SourceSuite. Author: Tathagata Das Closes #20698 from tdas/SPARK-23541. --- .../sql/kafka010/KafkaMicroBatchReader.scala | 109 ++++++------- .../kafka010/KafkaOffsetRangeCalculator.scala | 105 +++++++++++++ .../sql/kafka010/KafkaSourceProvider.scala | 7 + .../apache/spark/sql/kafka010/package.scala | 24 +++ .../kafka010/KafkaMicroBatchSourceSuite.scala | 56 ++++++- .../KafkaOffsetRangeCalculatorSuite.scala | 147 ++++++++++++++++++ 6 files changed, 388 insertions(+), 60 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index fb647ca7e70dd..8a5f3a249b11c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -24,7 +24,6 @@ import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import org.apache.commons.io.IOUtils -import org.apache.kafka.common.TopicPartition import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging @@ -64,8 +63,6 @@ private[kafka010] class KafkaMicroBatchReader( failOnDataLoss: Boolean) extends MicroBatchReader with SupportsScanUnsafeRow with Logging { - type PartitionOffsetMap = Map[TopicPartition, Long] - private var startPartitionOffsets: PartitionOffsetMap = _ private var endPartitionOffsets: PartitionOffsetMap = _ @@ -76,6 +73,7 @@ private[kafka010] class KafkaMicroBatchReader( private val maxOffsetsPerTrigger = Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) + private val rangeCalculator = KafkaOffsetRangeCalculator(options) /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only * called in StreamExecutionThread. Otherwise, interrupting a thread while running @@ -106,15 +104,15 @@ private[kafka010] class KafkaMicroBatchReader( override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) - val newPartitionOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) - if (newPartitionOffsets.keySet != newPartitions) { + val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) + if (newPartitionInitialOffsets.keySet != newPartitions) { // We cannot get from offsets for some partitions. It means they got deleted. - val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet) + val deletedPartitions = newPartitions.diff(newPartitionInitialOffsets.keySet) reportDataLoss( s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed") } - logInfo(s"Partitions added: $newPartitionOffsets") - newPartitionOffsets.filter(_._2 != 0).foreach { case (p, o) => + logInfo(s"Partitions added: $newPartitionInitialOffsets") + newPartitionInitialOffsets.filter(_._2 != 0).foreach { case (p, o) => reportDataLoss( s"Added partition $p starts from $o instead of 0. Some data may have been missed") } @@ -125,46 +123,28 @@ private[kafka010] class KafkaMicroBatchReader( reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed") } - // Use the until partitions to calculate offset ranges to ignore partitions that have + // Use the end partitions to calculate offset ranges to ignore partitions that have // been deleted val topicPartitions = endPartitionOffsets.keySet.filter { tp => // Ignore partitions that we don't know the from offsets. - newPartitionOffsets.contains(tp) || startPartitionOffsets.contains(tp) + newPartitionInitialOffsets.contains(tp) || startPartitionOffsets.contains(tp) }.toSeq logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) - val sortedExecutors = getSortedExecutorList() - val numExecutors = sortedExecutors.length - logDebug("Sorted executors: " + sortedExecutors.mkString(", ")) - // Calculate offset ranges - val factories = topicPartitions.flatMap { tp => - val fromOffset = startPartitionOffsets.get(tp).getOrElse { - newPartitionOffsets.getOrElse( - tp, { - // This should not happen since newPartitionOffsets contains all partitions not in - // fromPartitionOffsets - throw new IllegalStateException(s"$tp doesn't have a from offset") - }) - } - val untilOffset = endPartitionOffsets(tp) - - if (untilOffset >= fromOffset) { - // This allows cached KafkaConsumers in the executors to be re-used to read the same - // partition in every batch. - val preferredLoc = if (numExecutors > 0) { - Some(sortedExecutors(Math.floorMod(tp.hashCode, numExecutors))) - } else None - val range = KafkaOffsetRange(tp, fromOffset, untilOffset) - Some( - new KafkaMicroBatchDataReaderFactory( - range, preferredLoc, executorKafkaParams, pollTimeoutMs, failOnDataLoss)) - } else { - reportDataLoss( - s"Partition $tp's offset was changed from " + - s"$fromOffset to $untilOffset, some data may have been missed") - None - } + val offsetRanges = rangeCalculator.getRanges( + fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets, + untilOffsets = endPartitionOffsets, + executorLocations = getSortedExecutorList()) + + // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions, + // that is, concurrent tasks will not read the same TopicPartitions. + val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size + + // Generate factories based on the offset ranges + val factories = offsetRanges.map { range => + new KafkaMicroBatchDataReaderFactory( + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava } @@ -320,28 +300,39 @@ private[kafka010] class KafkaMicroBatchReader( } /** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] class KafkaMicroBatchDataReaderFactory( - range: KafkaOffsetRange, - preferredLoc: Option[String], +private[kafka010] case class KafkaMicroBatchDataReaderFactory( + offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { + failOnDataLoss: Boolean, + reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] { - override def preferredLocations(): Array[String] = preferredLoc.toArray + override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( - range, executorKafkaParams, pollTimeoutMs, failOnDataLoss) + offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } /** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] class KafkaMicroBatchDataReader( +private[kafka010] case class KafkaMicroBatchDataReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends DataReader[UnsafeRow] with Logging { + failOnDataLoss: Boolean, + reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging { + + private val consumer = { + if (!reuseKafkaConsumer) { + // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. We + // uses `assign` here, hence we don't need to worry about the "group.id" conflicts. + CachedKafkaConsumer.createUncached( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } else { + CachedKafkaConsumer.getOrCreate( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } + } - private val consumer = CachedKafkaConsumer.getOrCreate( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) private val rangeToRead = resolveRange(offsetRange) private val converter = new KafkaRecordToUnsafeRowConverter @@ -369,9 +360,14 @@ private[kafka010] class KafkaMicroBatchDataReader( } override def close(): Unit = { - // Indicate that we're no longer using this consumer - CachedKafkaConsumer.releaseKafkaConsumer( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + if (!reuseKafkaConsumer) { + // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! + consumer.close() + } else { + // Indicate that we're no longer using this consumer + CachedKafkaConsumer.releaseKafkaConsumer( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } } private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = { @@ -392,12 +388,9 @@ private[kafka010] class KafkaMicroBatchDataReader( } else { range.untilOffset } - KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset) + KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset, None) } else { range } } } - -private[kafka010] case class KafkaOffsetRange( - topicPartition: TopicPartition, fromOffset: Long, untilOffset: Long) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala new file mode 100644 index 0000000000000..6631ae84167c8 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.sql.sources.v2.DataSourceOptions + + +/** + * Class to calculate offset ranges to process based on the the from and until offsets, and + * the configured `minPartitions`. + */ +private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int]) { + require(minPartitions.isEmpty || minPartitions.get > 0) + + import KafkaOffsetRangeCalculator._ + /** + * Calculate the offset ranges that we are going to process this batch. If `minPartitions` + * is not set or is set less than or equal the number of `topicPartitions` that we're going to + * consume, then we fall back to a 1-1 mapping of Spark tasks to Kafka partitions. If + * `numPartitions` is set higher than the number of our `topicPartitions`, then we will split up + * the read tasks of the skewed partitions to multiple Spark tasks. + * The number of Spark tasks will be *approximately* `numPartitions`. It can be less or more + * depending on rounding errors or Kafka partitions that didn't receive any new data. + */ + def getRanges( + fromOffsets: PartitionOffsetMap, + untilOffsets: PartitionOffsetMap, + executorLocations: Seq[String] = Seq.empty): Seq[KafkaOffsetRange] = { + val partitionsToRead = untilOffsets.keySet.intersect(fromOffsets.keySet) + + val offsetRanges = partitionsToRead.toSeq.map { tp => + KafkaOffsetRange(tp, fromOffsets(tp), untilOffsets(tp), preferredLoc = None) + }.filter(_.size > 0) + + // If minPartitions not set or there are enough partitions to satisfy minPartitions + if (minPartitions.isEmpty || offsetRanges.size > minPartitions.get) { + // Assign preferred executor locations to each range such that the same topic-partition is + // preferentially read from the same executor and the KafkaConsumer can be reused. + offsetRanges.map { range => + range.copy(preferredLoc = getLocation(range.topicPartition, executorLocations)) + } + } else { + + // Splits offset ranges with relatively large amount of data to smaller ones. + val totalSize = offsetRanges.map(_.size).sum + val idealRangeSize = totalSize.toDouble / minPartitions.get + + offsetRanges.flatMap { range => + // Split the current range into subranges as close to the ideal range size + val numSplitsInRange = math.round(range.size.toDouble / idealRangeSize).toInt + + (0 until numSplitsInRange).map { i => + val splitStart = range.fromOffset + range.size * (i.toDouble / numSplitsInRange) + val splitEnd = range.fromOffset + range.size * ((i.toDouble + 1) / numSplitsInRange) + KafkaOffsetRange( + range.topicPartition, splitStart.toLong, splitEnd.toLong, preferredLoc = None) + } + } + } + } + + private def getLocation(tp: TopicPartition, executorLocations: Seq[String]): Option[String] = { + def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b + + val numExecutors = executorLocations.length + if (numExecutors > 0) { + // This allows cached KafkaConsumers in the executors to be re-used to read the same + // partition in every batch. + Some(executorLocations(floorMod(tp.hashCode, numExecutors))) + } else None + } +} + +private[kafka010] object KafkaOffsetRangeCalculator { + + def apply(options: DataSourceOptions): KafkaOffsetRangeCalculator = { + val optionalValue = Option(options.get("minPartitions").orElse(null)).map(_.toInt) + new KafkaOffsetRangeCalculator(optionalValue) + } +} + +private[kafka010] case class KafkaOffsetRange( + topicPartition: TopicPartition, + fromOffset: Long, + untilOffset: Long, + preferredLoc: Option[String]) { + lazy val size: Long = untilOffset - fromOffset +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 0aa64a6a9cf90..36b9f0466566b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -348,6 +348,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister throw new IllegalArgumentException("Unknown option") } + // Validate minPartitions value if present + if (caseInsensitiveParams.contains(MIN_PARTITIONS_OPTION_KEY)) { + val p = caseInsensitiveParams(MIN_PARTITIONS_OPTION_KEY).toInt + if (p <= 0) throw new IllegalArgumentException("minPartitions must be positive") + } + // Validate user-specified Kafka options if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) { @@ -455,6 +461,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" + private val MIN_PARTITIONS_OPTION_KEY = "minpartitions" val TOPIC_OPTION_KEY = "topic" diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala new file mode 100644 index 0000000000000..43acd6a8d9473 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.kafka.common.TopicPartition + +package object kafka010 { // scalastyle:ignore + // ^^ scalastyle:ignore is for ignoring warnings about digits in package name + type PartitionOffsetMap = Map[TopicPartition, Long] +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 89c9ef4cc73b5..f2b3ff7615e74 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.kafka010 import java.io._ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} -import java.util.{Locale, Properties} +import java.util.{Locale, Optional, Properties} import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.io.Source import scala.util.Random @@ -34,15 +35,19 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.{Dataset, ForeachWriter} +import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.sql.types.StructType abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { @@ -642,6 +647,53 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { } ) } + + testWithUninterruptibleThread("minPartitions is supported") { + import testImplicits._ + + val topic = newTopic() + val tp = new TopicPartition(topic, 0) + testUtils.createTopic(topic, partitions = 1) + + def test( + minPartitions: String, + numPartitionsGenerated: Int, + reusesConsumers: Boolean): Unit = { + + SparkSession.setActiveSession(spark) + withTempDir { dir => + val provider = new KafkaSourceProvider() + val options = Map( + "kafka.bootstrap.servers" -> testUtils.brokerAddress, + "subscribe" -> topic + ) ++ Option(minPartitions).map { p => "minPartitions" -> p} + val reader = provider.createMicroBatchReader( + Optional.empty[StructType], dir.getAbsolutePath, new DataSourceOptions(options.asJava)) + reader.setOffsetRange( + Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), + Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) + ) + val factories = reader.createUnsafeRowReaderFactories().asScala + .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) + withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { + assert(factories.size == numPartitionsGenerated) + factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } + } + } + } + + // Test cases when minPartitions is used and not used + test(minPartitions = null, numPartitionsGenerated = 1, reusesConsumers = true) + test(minPartitions = "1", numPartitionsGenerated = 1, reusesConsumers = true) + test(minPartitions = "4", numPartitionsGenerated = 4, reusesConsumers = false) + + // Test illegal minPartitions values + intercept[IllegalArgumentException] { test(minPartitions = "a", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "1.0", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "0", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "-1", 1, true) } + } + } abstract class KafkaSourceSuiteBase extends KafkaSourceTest { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala new file mode 100644 index 0000000000000..2ccf3e291bea7 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import scala.collection.JavaConverters._ + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.sources.v2.DataSourceOptions + +class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { + + def testWithMinPartitions(name: String, minPartition: Int) + (f: KafkaOffsetRangeCalculator => Unit): Unit = { + val options = new DataSourceOptions(Map("minPartitions" -> minPartition.toString).asJava) + test(s"with minPartition = $minPartition: $name") { + f(KafkaOffsetRangeCalculator(options)) + } + } + + + test("with no minPartition: N TopicPartitions to N offset ranges") { + val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty()) + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 2)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 1), Seq.empty) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2), + executorLocations = Seq("location")) == + Seq(KafkaOffsetRange(tp1, 1, 2, Some("location")))) + } + + test("with no minPartition: empty ranges ignored") { + val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty()) + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 1)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + } + + testWithMinPartitions("N TopicPartitions to N offset ranges", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1, tp3 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 2, tp3 -> 2)) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp2, 1, 2, None), + KafkaOffsetRange(tp3, 1, 2, None))) + } + + testWithMinPartitions("1 TopicPartition to N offset ranges", 4) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 5)) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp1, 2, 3, None), + KafkaOffsetRange(tp1, 3, 4, None), + KafkaOffsetRange(tp1, 4, 5, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 5), + executorLocations = Seq("location")) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp1, 2, 3, None), + KafkaOffsetRange(tp1, 3, 4, None), + KafkaOffsetRange(tp1, 4, 5, None))) // location pref not set when minPartition is set + } + + testWithMinPartitions("N skewed TopicPartitions to M offset ranges", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 5, tp2 -> 21)) == + Seq( + KafkaOffsetRange(tp1, 1, 5, None), + KafkaOffsetRange(tp2, 1, 7, None), + KafkaOffsetRange(tp2, 7, 14, None), + KafkaOffsetRange(tp2, 14, 21, None))) + } + + testWithMinPartitions("range inexact multiple of minPartitions", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 11)) == + Seq( + KafkaOffsetRange(tp1, 1, 4, None), + KafkaOffsetRange(tp1, 4, 7, None), + KafkaOffsetRange(tp1, 7, 11, None))) + } + + testWithMinPartitions("empty ranges ignored", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1, tp3 -> 1), + untilOffsets = Map(tp1 -> 5, tp2 -> 21, tp3 -> 1)) == + Seq( + KafkaOffsetRange(tp1, 1, 5, None), + KafkaOffsetRange(tp2, 1, 7, None), + KafkaOffsetRange(tp2, 7, 14, None), + KafkaOffsetRange(tp2, 14, 21, None))) + } + + private val tp1 = new TopicPartition("t1", 1) + private val tp2 = new TopicPartition("t2", 1) + private val tp3 = new TopicPartition("t3", 1) +} From a89cdf55fa76fa23a524f0443e323498c3cc8664 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 5 Mar 2018 07:32:24 +0900 Subject: [PATCH 420/774] [SQL][MINOR] XPathDouble prettyPrint should say 'double' not 'float' ## What changes were proposed in this pull request? It looks like this was incorrectly copied from `XPathFloat` in the class above. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Eric Liang Closes #20730 from ericl/fix-typo-xpath. --- .../org/apache/spark/sql/catalyst/expressions/xml/xpath.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index d0185562c9cfc..aacf1a44e2ad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -160,7 +160,7 @@ case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract { """) // scalastyle:on line.size.limit case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract { - override def prettyName: String = "xpath_float" + override def prettyName: String = "xpath_double" override def dataType: DataType = DoubleType override def nullSafeEval(xml: Any, path: Any): Any = { From 7965c91d8a67c213ca5eebda5e46e7c49a8ba121 Mon Sep 17 00:00:00 2001 From: "Michael (Stu) Stewart" Date: Mon, 5 Mar 2018 13:36:42 +0900 Subject: [PATCH 421/774] [SPARK-23569][PYTHON] Allow pandas_udf to work with python3 style type-annotated functions ## What changes were proposed in this pull request? Check python version to determine whether to use `inspect.getargspec` or `inspect.getfullargspec` before applying `pandas_udf` core logic to a function. The former is python2.7 (deprecated in python3) and the latter is python3.x. The latter correctly accounts for type annotations, which are syntax errors in python2.x. ## How was this patch tested? Locally, on python 2.7 and 3.6. Author: Michael (Stu) Stewart Closes #20728 from mstewart141/pandas_udf_fix. --- python/pyspark/sql/tests.py | 18 ++++++++++++++++++ python/pyspark/sql/udf.py | 9 ++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 19653072ea316..fa3b7203e10ac 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4381,6 +4381,24 @@ def test_timestamp_dst(self): result = df.withColumn('time', foo_udf(df.time)) self.assertEquals(df.collect(), result.collect()) + @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.") + def test_type_annotation(self): + from pyspark.sql.functions import pandas_udf + # Regression test to check if type hints can be used. See SPARK-23569. + # Note that it throws an error during compilation in lower Python versions if 'exec' + # is not used. Also, note that we explicitly use another dictionary to avoid modifications + # in the current 'locals()'. + # + # Hyukjin: I think it's an ugly way to test issues about syntax specific in + # higher versions of Python, which we shouldn't encourage. This was the last resort + # I could come up with at that time. + _locals = {} + exec( + "import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col", + _locals) + df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id')) + self.assertEqual(df.first()[0], 0) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index e5b35fc60e167..b9b490874f4fb 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -42,10 +42,17 @@ def _create_udf(f, returnType, evalType): PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): import inspect + import sys from pyspark.sql.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() - argspec = inspect.getargspec(f) + + if sys.version_info[0] < 3: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. + argspec = inspect.getargspec(f) + else: + argspec = inspect.getfullargspec(f) if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \ argspec.varargs is None: From 269cd53590dd155aeb5269efc909a6e228f21e22 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 4 Mar 2018 21:22:30 -0800 Subject: [PATCH 422/774] [MINOR][DOCS] Fix a link in "Compatibility with Apache Hive" ## What changes were proposed in this pull request? This PR fixes a broken link as below: **Before:** 2018-03-05 12 23 58 **After:** 2018-03-05 12 23 20 Also see https://spark.apache.org/docs/2.3.0/sql-programming-guide.html#compatibility-with-apache-hive ## How was this patch tested? Manually tested. I checked the same instances in `docs` directory. Seems this is the only one. Author: hyukjinkwon Closes #20733 from HyukjinKwon/minor-link. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c37c338a134f3..4d0f015f401bb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2223,7 +2223,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses From 2ce37b50fc01558f49ad22f89c8659f50544ffec Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 5 Mar 2018 11:39:01 +0100 Subject: [PATCH 423/774] [SPARK-23546][SQL] Refactor stateless methods/values in CodegenContext ## What changes were proposed in this pull request? A current `CodegenContext` class has immutable value or method without mutable state, too. This refactoring moves them to `CodeGenerator` object class which can be accessed from anywhere without an instantiated `CodegenContext` in the program. ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #20700 from kiszk/SPARK-23546. --- .../catalyst/expressions/BoundAttribute.scala | 9 +- .../spark/sql/catalyst/expressions/Cast.scala | 35 +- .../sql/catalyst/expressions/Expression.scala | 16 +- .../MonotonicallyIncreasingID.scala | 8 +- .../sql/catalyst/expressions/ScalaUDF.scala | 7 +- .../expressions/SparkPartitionID.scala | 7 +- .../sql/catalyst/expressions/TimeWindow.scala | 4 +- .../sql/catalyst/expressions/arithmetic.scala | 51 +- .../expressions/bitwiseExpressions.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 458 +++++++++--------- .../expressions/codegen/CodegenFallback.scala | 7 +- .../codegen/GenerateMutableProjection.scala | 6 +- .../codegen/GenerateOrdering.scala | 4 +- .../codegen/GenerateSafeProjection.scala | 6 +- .../codegen/GenerateUnsafeProjection.scala | 11 +- .../expressions/collectionOperations.scala | 6 +- .../expressions/complexTypeCreator.scala | 4 +- .../expressions/complexTypeExtractors.scala | 15 +- .../expressions/conditionalExpressions.scala | 10 +- .../expressions/datetimeExpressions.scala | 18 +- .../spark/sql/catalyst/expressions/hash.scala | 25 +- .../catalyst/expressions/inputFileBlock.scala | 8 +- .../sql/catalyst/expressions/literals.scala | 8 +- .../expressions/mathExpressions.scala | 5 +- .../expressions/nullExpressions.scala | 22 +- .../expressions/objects/objects.scala | 99 ++-- .../sql/catalyst/expressions/predicates.scala | 14 +- .../expressions/randomExpressions.scala | 8 +- .../expressions/regexpExpressions.scala | 8 +- .../expressions/stringExpressions.scala | 39 +- .../expressions/CodeGenerationSuite.scala | 4 +- .../sql/execution/ColumnarBatchScan.scala | 13 +- .../spark/sql/execution/ExpandExec.scala | 5 +- .../spark/sql/execution/GenerateExec.scala | 8 +- .../apache/spark/sql/execution/SortExec.scala | 5 +- .../sql/execution/WholeStageCodegenExec.scala | 2 +- .../aggregate/HashAggregateExec.scala | 16 +- .../aggregate/HashMapGenerator.scala | 8 +- .../aggregate/RowBasedHashMapGenerator.scala | 8 +- .../VectorizedHashMapGenerator.scala | 11 +- .../execution/basicPhysicalOperators.scala | 10 +- .../columnar/GenerateColumnAccessor.scala | 2 +- .../joins/BroadcastHashJoinExec.scala | 5 +- .../execution/joins/SortMergeJoinExec.scala | 8 +- .../apache/spark/sql/execution/limit.scala | 7 +- 45 files changed, 535 insertions(+), 497 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 6a17a397b3ef2..89ffbb0016916 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ /** @@ -66,13 +66,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) ev.copy(code = oev.code) } else { assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.") - val javaType = ctx.javaType(dataType) - val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) + val javaType = CodeGenerator.javaType(dataType) + val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = s""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + |$javaType ${ev.value} = ${ev.isNull} ? + | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 79b051670e9e4..12330bfa55ab9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -669,7 +669,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = { s""" boolean $resultIsNull = $inputIsNull; - ${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)}; + ${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)}; if (!$inputIsNull) { ${cast(input, result, resultIsNull)} } @@ -685,7 +685,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val funcName = ctx.freshName("elementToString") val elementToStringFunc = ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${ctx.javaType(et)} element) { + |private UTF8String $funcName(${CodeGenerator.javaType(et)} element) { | UTF8String elementStr = null; | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} | return elementStr; @@ -697,13 +697,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$buffer.append("["); |if ($array.numElements() > 0) { | if (!$array.isNullAt(0)) { - | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")})); + | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, "0")})); | } | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { | $buffer.append(","); | if (!$array.isNullAt($loopIndex)) { | $buffer.append(" "); - | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)})); + | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, loopIndex)})); | } | } |} @@ -723,7 +723,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val dataToStringCode = castToStringCode(dataType, ctx) ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${ctx.javaType(dataType)} data) { + |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) { | UTF8String dataStr = null; | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} | return dataStr; @@ -734,23 +734,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val keyToStringFunc = dataToStringFunc("keyToString", kt) val valueToStringFunc = dataToStringFunc("valueToString", vt) val loopIndex = ctx.freshName("loopIndex") + val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0") + val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0") + val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex) + val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex) s""" |$buffer.append("["); |if ($map.numElements() > 0) { - | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")})); + | $buffer.append($keyToStringFunc($getMapFirstKey)); | $buffer.append(" ->"); | if (!$map.valueArray().isNullAt(0)) { | $buffer.append(" "); - | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")})); + | $buffer.append($valueToStringFunc($getMapFirstValue)); | } | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { | $buffer.append(", "); - | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)})); + | $buffer.append($keyToStringFunc($getMapKeyArray)); | $buffer.append(" ->"); | if (!$map.valueArray().isNullAt($loopIndex)) { | $buffer.append(" "); - | $buffer.append($valueToStringFunc( - | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)})); + | $buffer.append($valueToStringFunc($getMapValueArray)); | } | } |} @@ -773,7 +776,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String | ${if (i != 0) s"""$buffer.append(" ");""" else ""} | | // Append $i field into the string buffer - | ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")}; + | ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")}; | UTF8String $fieldStr = null; | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} | $buffer.append($fieldStr); @@ -1202,8 +1205,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $values[$j] = null; } else { boolean $fromElementNull = false; - ${ctx.javaType(fromType)} $fromElementPrim = - ${ctx.getValue(c, fromType, j)}; + ${CodeGenerator.javaType(fromType)} $fromElementPrim = + ${CodeGenerator.getValue(c, fromType, j)}; ${castCode(ctx, fromElementPrim, fromElementNull, toElementPrim, toElementNull, toType, elementCast)} if ($toElementNull) { @@ -1259,20 +1262,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val fromFieldNull = ctx.freshName("ffn") val toFieldPrim = ctx.freshName("tfp") val toFieldNull = ctx.freshName("tfn") - val fromType = ctx.javaType(from.fields(i).dataType) + val fromType = CodeGenerator.javaType(from.fields(i).dataType) s""" boolean $fromFieldNull = $tmpInput.isNullAt($i); if ($fromFieldNull) { $tmpResult.setNullAt($i); } else { $fromType $fromFieldPrim = - ${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)}; + ${CodeGenerator.getValue(tmpInput, from.fields(i).dataType, i.toString)}; ${castCode(ctx, fromFieldPrim, fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} if ($toFieldNull) { $tmpResult.setNullAt($i); } else { - ${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; + ${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4568714933095..ed90b185865a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,7 +119,7 @@ abstract class Expression extends TreeNode[Expression] { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { - val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull") + val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull eval.isNull = globalIsNull s"$globalIsNull = $localIsNull;" @@ -127,7 +127,7 @@ abstract class Expression extends TreeNode[Expression] { "" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val newValue = ctx.freshName("value") val funcName = ctx.freshName(nodeName) @@ -411,14 +411,14 @@ abstract class UnaryExpression extends Expression { ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { ev.copy(code = s""" boolean ${ev.isNull} = false; ${childGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } @@ -510,7 +510,7 @@ abstract class BinaryExpression extends Expression { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { @@ -518,7 +518,7 @@ abstract class BinaryExpression extends Expression { boolean ${ev.isNull} = false; ${leftGen.code} ${rightGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } @@ -654,7 +654,7 @@ abstract class TernaryExpression extends Expression { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { ev.copy(code = s""" @@ -662,7 +662,7 @@ abstract class TernaryExpression extends Expression { ${leftGen.code} ${midGen.code} ${rightGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 11fb579dfa88c..4523079060896 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, LongType} /** @@ -65,14 +65,14 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count") + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count") val partitionMaskTerm = "partitionMask" - ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm) + ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm) ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++;""", isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 989c02305620a..e869258469a97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1018,11 +1018,12 @@ case class ScalaUDF( val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}") val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})" val resultConverter = s"$convertersTerm[${children.length}]" + val boxedType = CodeGenerator.boxedType(dataType) val callFunc = s""" - |${ctx.boxedType(dataType)} $resultTerm = null; + |$boxedType $resultTerm = null; |try { - | $resultTerm = (${ctx.boxedType(dataType)})$resultConverter.apply($getFuncResult); + | $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult); |} catch (Exception e) { | throw new org.apache.spark.SparkException($errorMsgTerm, e); |} @@ -1035,7 +1036,7 @@ case class ScalaUDF( |$callFunc | |boolean ${ev.isNull} = $resultTerm == null; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${ev.isNull}) { | ${ev.value} = $resultTerm; |} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index a160b9b275290..cc6a769d032d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -44,8 +44,9 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = "partitionId" - ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm) + ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", + isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 9a9f579b37f58..6c4a3601c1730 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -22,7 +22,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -165,7 +165,7 @@ case class PreciseTimestampConversion( val eval = child.genCode(ctx) ev.copy(code = eval.code + s"""boolean ${ev.isNull} = ${eval.isNull}; - |${ctx.javaType(dataType)} ${ev.value} = ${eval.value}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } override def nullSafeEval(input: Any): Any = input diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8bb14598a6d7b..508bdd5050b54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -49,8 +49,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression // codegen would fail to compile if we just write (-($c)) // for example, we could not write --9223372036854775808L in code s""" - ${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval); - ${ev.value} = (${ctx.javaType(dt)})(-($originValue)); + ${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval); + ${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue)); """}) case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } @@ -107,7 +107,7 @@ case class Abs(child: Expression) case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))") + defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))") } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) @@ -129,7 +129,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") } @@ -167,7 +167,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") case _ => @@ -203,7 +203,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") case _ => @@ -278,7 +278,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } else { s"${eval2.value} == 0" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val divide = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { @@ -288,7 +288,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -299,7 +299,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -365,7 +365,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } else { s"${eval2.value} == 0" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val remainder = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { @@ -375,7 +375,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -386,7 +386,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -454,13 +454,13 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { s"${eval2.value} == 0" } val remainder = ctx.freshName("remainder") - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val result = dataType match { case DecimalType.Fixed(_, _) => val decimalAdd = "$plus" s""" - ${ctx.javaType(dataType)} $remainder = ${eval1.value}.remainder(${eval2.value}); + $javaType $remainder = ${eval1.value}.remainder(${eval2.value}); if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { ${ev.value}=($remainder.$decimalAdd(${eval2.value})).remainder(${eval2.value}); } else { @@ -470,17 +470,16 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => s""" - ${ctx.javaType(dataType)} $remainder = - (${ctx.javaType(dataType)})(${eval1.value} % ${eval2.value}); + $javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value}); if ($remainder < 0) { - ${ev.value}=(${ctx.javaType(dataType)})(($remainder + ${eval2.value}) % ${eval2.value}); + ${ev.value}=($javaType)(($remainder + ${eval2.value}) % ${eval2.value}); } else { ${ev.value}=$remainder; } """ case _ => s""" - ${ctx.javaType(dataType)} $remainder = ${eval1.value} % ${eval2.value}; + $javaType $remainder = ${eval1.value} % ${eval2.value}; if ($remainder < 0) { ${ev.value}=($remainder + ${eval2.value}) % ${eval2.value}; } else { @@ -493,7 +492,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -504,7 +503,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -602,7 +601,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -614,7 +613,7 @@ case class Least(children: Seq[Expression]) extends Expression { """.stripMargin ) - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "least", @@ -629,7 +628,7 @@ case class Least(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes """.stripMargin) } @@ -681,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -693,7 +692,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { """.stripMargin ) - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "greatest", @@ -708,7 +707,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 173481f06a716..cc24e397cc14a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -147,7 +147,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") + defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dataType)}) ~($c)") } protected override def nullSafeEval(input: Any): Any = not(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 60a6f50472504..793824b0b0a2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -59,6 +59,11 @@ import org.apache.spark.util.{ParentClassLoader, Utils} case class ExprCode(var code: String, var isNull: String, var value: String) object ExprCode { + def forNullValue(dataType: DataType): ExprCode = { + val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true) + ExprCode(code = "", isNull = "true", value = defaultValueLiteral) + } + def forNonNullValue(value: String): ExprCode = { ExprCode(code = "", isNull = "false", value = value) } @@ -105,6 +110,8 @@ private[codegen] case class NewFunctionSpec( */ class CodegenContext { + import CodeGenerator._ + /** * Holding a list of objects that could be used passed into generated class. */ @@ -196,11 +203,11 @@ class CodegenContext { /** * Returns the reference of next available slot in current compacted array. The size of each - * compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * compacted array is controlled by the constant `MUTABLESTATEARRAY_SIZE_LIMIT`. * Once reaching the threshold, new compacted array is created. */ def getNextSlot(): String = { - if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) { + if (currentIndex < MUTABLESTATEARRAY_SIZE_LIMIT) { val res = s"${arrayNames.last}[$currentIndex]" currentIndex += 1 res @@ -247,10 +254,10 @@ class CodegenContext { * are satisfied: * 1. forceInline is true * 2. its type is primitive type and the total number of the inlined mutable variables - * is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD` + * is less than `OUTER_CLASS_VARIABLES_THRESHOLD` * 3. its type is multi-dimensional array * When a variable is compacted into an array, the max size of the array for compaction - * is given by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * is given by `MUTABLESTATEARRAY_SIZE_LIMIT`. */ def addMutableState( javaType: String, @@ -261,7 +268,7 @@ class CodegenContext { // want to put a primitive type variable at outerClass for performance val canInlinePrimitive = isPrimitiveType(javaType) && - (inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) + (inlinedMutableStates.length < OUTER_CLASS_VARIABLES_THRESHOLD) if (forceInline || canInlinePrimitive || javaType.contains("[][]")) { val varName = if (useFreshName) freshName(variableName) else variableName val initCode = initFunc(varName) @@ -339,7 +346,7 @@ class CodegenContext { val length = if (index + 1 == numArrays) { mutableStateArrays.getCurrentIndex } else { - CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + MUTABLESTATEARRAY_SIZE_LIMIT } if (javaType.contains("[]")) { // initializer had an one-dimensional array variable @@ -468,7 +475,7 @@ class CodegenContext { inlineToOuterClass: Boolean): NewFunctionSpec = { val (className, classInstance) = if (inlineToOuterClass) { outerClassName -> "" - } else if (currClassSize > CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD) { + } else if (currClassSize > GENERATED_CLASS_SIZE_THRESHOLD) { val className = freshName("NestedClass") val classInstance = freshName("nestedClassInstance") @@ -537,14 +544,6 @@ class CodegenContext { extraClasses.append(code) } - final val JAVA_BOOLEAN = "boolean" - final val JAVA_BYTE = "byte" - final val JAVA_SHORT = "short" - final val JAVA_INT = "int" - final val JAVA_LONG = "long" - final val JAVA_FLOAT = "float" - final val JAVA_DOUBLE = "double" - /** * The map from a variable name to it's next ID. */ @@ -580,196 +579,6 @@ class CodegenContext { } } - /** - * Returns the specialized code to access a value from `inputRow` at `ordinal`. - */ - def getValue(input: String, dataType: DataType, ordinal: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" - case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" - case StringType => s"$input.getUTF8String($ordinal)" - case BinaryType => s"$input.getBinary($ordinal)" - case CalendarIntervalType => s"$input.getInterval($ordinal)" - case t: StructType => s"$input.getStruct($ordinal, ${t.size})" - case _: ArrayType => s"$input.getArray($ordinal)" - case _: MapType => s"$input.getMap($ordinal)" - case NullType => "null" - case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) - case _ => s"($jt)$input.get($ordinal, null)" - } - } - - /** - * Returns the code to update a column in Row for a given DataType. - */ - def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" - case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" - case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) - // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy - // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. - case StringType | _: StructType | _: ArrayType | _: MapType => - s"$row.update($ordinal, $value.copy())" - case _ => s"$row.update($ordinal, $value)" - } - } - - /** - * Update a column in MutableRow from ExprCode. - * - * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise - */ - def updateColumn( - row: String, - dataType: DataType, - ordinal: Int, - ev: ExprCode, - nullable: Boolean, - isVectorized: Boolean = false): String = { - if (nullable) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - if (!isVectorized && dataType.isInstanceOf[DecimalType]) { - s""" - if (!${ev.isNull}) { - ${setColumn(row, dataType, ordinal, ev.value)}; - } else { - ${setColumn(row, dataType, ordinal, "null")}; - } - """ - } else { - s""" - if (!${ev.isNull}) { - ${setColumn(row, dataType, ordinal, ev.value)}; - } else { - $row.setNullAt($ordinal); - } - """ - } - } else { - s"""${setColumn(row, dataType, ordinal, ev.value)};""" - } - } - - /** - * Returns the specialized code to set a given value in a column vector for a given `DataType`. - */ - def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => - s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" - case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" - case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" - case _ => - throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") - } - } - - /** - * Returns the specialized code to set a given value in a column vector for a given `DataType` - * that could potentially be nullable. - */ - def updateColumn( - vector: String, - rowId: String, - dataType: DataType, - ev: ExprCode, - nullable: Boolean): String = { - if (nullable) { - s""" - if (!${ev.isNull}) { - ${setValue(vector, rowId, dataType, ev.value)} - } else { - $vector.putNull($rowId); - } - """ - } else { - s"""${setValue(vector, rowId, dataType, ev.value)};""" - } - } - - /** - * Returns the specialized code to access a value from a column vector for a given `DataType`. - */ - def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { - if (dataType.isInstanceOf[StructType]) { - // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an - // `ordinal` parameter. - s"$vector.getStruct($rowId)" - } else { - getValue(vector, dataType, rowId) - } - } - - /** - * Returns the name used in accessor and setter for a Java primitive type. - */ - def primitiveTypeName(jt: String): String = jt match { - case JAVA_INT => "Int" - case _ => boxedType(jt) - } - - def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) - - /** - * Returns the Java type for a DataType. - */ - def javaType(dt: DataType): String = dt match { - case BooleanType => JAVA_BOOLEAN - case ByteType => JAVA_BYTE - case ShortType => JAVA_SHORT - case IntegerType | DateType => JAVA_INT - case LongType | TimestampType => JAVA_LONG - case FloatType => JAVA_FLOAT - case DoubleType => JAVA_DOUBLE - case dt: DecimalType => "Decimal" - case BinaryType => "byte[]" - case StringType => "UTF8String" - case CalendarIntervalType => "CalendarInterval" - case _: StructType => "InternalRow" - case _: ArrayType => "ArrayData" - case _: MapType => "MapData" - case udt: UserDefinedType[_] => javaType(udt.sqlType) - case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" - case ObjectType(cls) => cls.getName - case _ => "Object" - } - - /** - * Returns the boxed type in Java. - */ - def boxedType(jt: String): String = jt match { - case JAVA_BOOLEAN => "Boolean" - case JAVA_BYTE => "Byte" - case JAVA_SHORT => "Short" - case JAVA_INT => "Integer" - case JAVA_LONG => "Long" - case JAVA_FLOAT => "Float" - case JAVA_DOUBLE => "Double" - case other => other - } - - def boxedType(dt: DataType): String = boxedType(javaType(dt)) - - /** - * Returns the representation of default value for a given Java Type. - */ - def defaultValue(jt: String): String = jt match { - case JAVA_BOOLEAN => "false" - case JAVA_BYTE => "(byte)-1" - case JAVA_SHORT => "(short)-1" - case JAVA_INT => "-1" - case JAVA_LONG => "-1L" - case JAVA_FLOAT => "-1.0f" - case JAVA_DOUBLE => "-1.0" - case _ => "null" - } - - def defaultValue(dt: DataType): String = defaultValue(javaType(dt)) - /** * Generates code for equal expression in Java. */ @@ -812,6 +621,7 @@ class CodegenContext { val isNullB = freshName("isNullB") val compareFunc = freshName("compareArray") val minLength = freshName("minLength") + val jt = javaType(elementType) val funcCode: String = s""" public int $compareFunc(ArrayData a, ArrayData b) { @@ -833,8 +643,8 @@ class CodegenContext { } else if ($isNullB) { return 1; } else { - ${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")}; - ${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")}; + $jt $elementA = ${getValue("a", elementType, "i")}; + $jt $elementB = ${getValue("b", elementType, "i")}; int comp = ${genComp(elementType, elementA, elementB)}; if (comp != 0) { return comp; @@ -906,19 +716,6 @@ class CodegenContext { } } - /** - * List of java data types that have special accessors and setters in [[InternalRow]]. - */ - val primitiveTypes = - Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) - - /** - * Returns true if the Java type has a special accessor and setter in [[InternalRow]]. - */ - def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) - - def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) - /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow @@ -1089,7 +886,7 @@ class CodegenContext { // for performance reasons, the functions are prepended, instead of appended, // thus here they are in reversed order val orderedFunctions = innerClassFunctions.reverse - if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) { + if (orderedFunctions.size > MERGE_SPLIT_METHODS_THRESHOLD) { // Adding a new function to each inner class which contains the invocation of all the // ones which have been added to that inner class. For example, // private class NestedClass { @@ -1289,7 +1086,7 @@ class CodegenContext { * length less than a pre-defined constant. */ def isValidParamLength(paramLength: Int): Boolean = { - paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH } } @@ -1524,4 +1321,221 @@ object CodeGenerator extends Logging { result } }) + + /** + * Name of Java primitive data type + */ + final val JAVA_BOOLEAN = "boolean" + final val JAVA_BYTE = "byte" + final val JAVA_SHORT = "short" + final val JAVA_INT = "int" + final val JAVA_LONG = "long" + final val JAVA_FLOAT = "float" + final val JAVA_DOUBLE = "double" + + /** + * List of java primitive data types + */ + val primitiveTypes = + Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) + + /** + * Returns true if a Java type is Java primitive primitive type + */ + def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) + + def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) + + /** + * Returns the specialized code to access a value from `inputRow` at `ordinal`. + */ + def getValue(input: String, dataType: DataType, ordinal: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" + case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" + case StringType => s"$input.getUTF8String($ordinal)" + case BinaryType => s"$input.getBinary($ordinal)" + case CalendarIntervalType => s"$input.getInterval($ordinal)" + case t: StructType => s"$input.getStruct($ordinal, ${t.size})" + case _: ArrayType => s"$input.getArray($ordinal)" + case _: MapType => s"$input.getMap($ordinal)" + case NullType => "null" + case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) + case _ => s"($jt)$input.get($ordinal, null)" + } + } + + /** + * Returns the code to update a column in Row for a given DataType. + */ + def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" + case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" + case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) + // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy + // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. + case StringType | _: StructType | _: ArrayType | _: MapType => + s"$row.update($ordinal, $value.copy())" + case _ => s"$row.update($ordinal, $value)" + } + } + + /** + * Update a column in MutableRow from ExprCode. + * + * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise + */ + def updateColumn( + row: String, + dataType: DataType, + ordinal: Int, + ev: ExprCode, + nullable: Boolean, + isVectorized: Boolean = false): String = { + if (nullable) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + if (!isVectorized && dataType.isInstanceOf[DecimalType]) { + s""" + |if (!${ev.isNull}) { + | ${setColumn(row, dataType, ordinal, ev.value)}; + |} else { + | ${setColumn(row, dataType, ordinal, "null")}; + |} + """.stripMargin + } else { + s""" + |if (!${ev.isNull}) { + | ${setColumn(row, dataType, ordinal, ev.value)}; + |} else { + | $row.setNullAt($ordinal); + |} + """.stripMargin + } + } else { + s"""${setColumn(row, dataType, ordinal, ev.value)};""" + } + } + + /** + * Returns the specialized code to set a given value in a column vector for a given `DataType`. + */ + def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => + s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" + case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" + case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" + case _ => + throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") + } + } + + /** + * Returns the specialized code to set a given value in a column vector for a given `DataType` + * that could potentially be nullable. + */ + def updateColumn( + vector: String, + rowId: String, + dataType: DataType, + ev: ExprCode, + nullable: Boolean): String = { + if (nullable) { + s""" + |if (!${ev.isNull}) { + | ${setValue(vector, rowId, dataType, ev.value)} + |} else { + | $vector.putNull($rowId); + |} + """.stripMargin + } else { + s"""${setValue(vector, rowId, dataType, ev.value)};""" + } + } + + /** + * Returns the specialized code to access a value from a column vector for a given `DataType`. + */ + def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { + if (dataType.isInstanceOf[StructType]) { + // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an + // `ordinal` parameter. + s"$vector.getStruct($rowId)" + } else { + getValue(vector, dataType, rowId) + } + } + + /** + * Returns the name used in accessor and setter for a Java primitive type. + */ + def primitiveTypeName(jt: String): String = jt match { + case JAVA_INT => "Int" + case _ => boxedType(jt) + } + + def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) + + /** + * Returns the Java type for a DataType. + */ + def javaType(dt: DataType): String = dt match { + case BooleanType => JAVA_BOOLEAN + case ByteType => JAVA_BYTE + case ShortType => JAVA_SHORT + case IntegerType | DateType => JAVA_INT + case LongType | TimestampType => JAVA_LONG + case FloatType => JAVA_FLOAT + case DoubleType => JAVA_DOUBLE + case _: DecimalType => "Decimal" + case BinaryType => "byte[]" + case StringType => "UTF8String" + case CalendarIntervalType => "CalendarInterval" + case _: StructType => "InternalRow" + case _: ArrayType => "ArrayData" + case _: MapType => "MapData" + case udt: UserDefinedType[_] => javaType(udt.sqlType) + case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" + case ObjectType(cls) => cls.getName + case _ => "Object" + } + + /** + * Returns the boxed type in Java. + */ + def boxedType(jt: String): String = jt match { + case JAVA_BOOLEAN => "Boolean" + case JAVA_BYTE => "Byte" + case JAVA_SHORT => "Short" + case JAVA_INT => "Integer" + case JAVA_LONG => "Long" + case JAVA_FLOAT => "Float" + case JAVA_DOUBLE => "Double" + case other => other + } + + def boxedType(dt: DataType): String = boxedType(javaType(dt)) + + /** + * Returns the representation of default value for a given Java Type. + * @param jt the string name of the Java type + * @param typedNull if true, for null literals, return a typed (with a cast) version + */ + def defaultValue(jt: String, typedNull: Boolean): String = jt match { + case JAVA_BOOLEAN => "false" + case JAVA_BYTE => "(byte)-1" + case JAVA_SHORT => "(short)-1" + case JAVA_INT => "-1" + case JAVA_LONG => "-1L" + case JAVA_FLOAT => "-1.0f" + case JAVA_DOUBLE => "-1.0" + case _ => if (typedNull) s"(($jt)null)" else "null" + } + + def defaultValue(dt: DataType, typedNull: Boolean = false): String = + defaultValue(javaType(dt), typedNull) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 0322d1dd6a9ff..e12420bb5dfdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -44,20 +44,21 @@ trait CodegenFallback extends Expression { } val objectTerm = ctx.freshName("obj") val placeHolder = ctx.registerComment(this.toString) + val javaType = CodeGenerator.javaType(this.dataType) if (nullable) { ev.copy(code = s""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(this.dataType)}; if (!${ev.isNull}) { - ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; }""") } else { ev.copy(code = s""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); - ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; """, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index b53c0087e7e2d..d35fd8ecb4d63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -62,9 +62,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) - val value = ctx.addMutableState(ctx.javaType(e.dataType), "value") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value") if (e.nullable) { - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull") (s""" |${ev.code} |$isNull = ${ev.isNull}; @@ -84,7 +84,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(projectionCodes).map { case (e, (_, isNull, value, i)) => val ev = ExprCode("", isNull, value) - ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) + CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4a459571ed634..9a51be6ed5aeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -89,7 +89,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR s""" ${ctx.INPUT_ROW} = a; boolean $isNullA; - ${ctx.javaType(order.child.dataType)} $primitiveA; + ${CodeGenerator.javaType(order.child.dataType)} $primitiveA; { ${eval.code} $isNullA = ${eval.isNull}; @@ -97,7 +97,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } ${ctx.INPUT_ROW} = b; boolean $isNullB; - ${ctx.javaType(order.child.dataType)} $primitiveB; + ${CodeGenerator.javaType(order.child.dataType)} $primitiveB; { ${eval.code} $isNullB = ${eval.isNull}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 3dcbb518ba42a..f92f70ee71fef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -53,7 +53,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt) + val converter = convertToSafe(ctx, CodeGenerator.getValue(tmpInput, dt, i.toString), dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -90,7 +90,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val arrayClass = classOf[GenericArrayData].getName val elementConverter = convertToSafe( - ctx, ctx.getValue(tmpInput, elementType, index), elementType) + ctx, CodeGenerator.getValue(tmpInput, elementType, index), elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -153,7 +153,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] mutableRow.setNullAt($i); } else { ${converter.code} - ${ctx.setColumn("mutableRow", e.dataType, i, converter.value)}; + ${CodeGenerator.setColumn("mutableRow", e.dataType, i, converter.value)}; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 36ffa8dcdd2b6..22717f5954a45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,7 +52,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString)) + ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) } s""" @@ -195,16 +195,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case other => other } - val jt = ctx.javaType(et) + val jt = CodeGenerator.javaType(et) val elementOrOffsetSize = et match { case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8 - case _ if ctx.isPrimitiveType(jt) => et.defaultSize + case _ if CodeGenerator.isPrimitiveType(jt) => et.defaultSize case _ => 8 // we need 8 bytes to store offset and length } val tmpCursor = ctx.freshName("tmpCursor") - val element = ctx.getValue(tmpInput, et, index) + val element = CodeGenerator.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" @@ -235,7 +235,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$arrayWriter.write($index, $element);" } - val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" + val primitiveTypeName = + if (CodeGenerator.isPrimitiveType(jt)) CodeGenerator.primitiveTypeName(et) else "" s""" final ArrayData $tmpInput = $input; if ($tmpInput instanceof UnsafeArrayData) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4270b987d6de0..beb84694c44e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,7 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -54,7 +54,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType ev.copy(code = s""" boolean ${ev.isNull} = false; ${childGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : (${childGen.value}).numElements();""", isNull = "false") } } @@ -270,7 +270,7 @@ case class ArrayContains(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (arr, value) => { val i = ctx.freshName("i") - val getValue = ctx.getValue(arr, right.dataType, i) + val getValue = CodeGenerator.getValue(arr, right.dataType, i) s""" for (int $i = 0; $i < $arr.numElements(); $i ++) { if ($arr.isNullAt($i)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 047b80ac5289c..85facdad43db7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -90,7 +90,7 @@ private [sql] object GenArrayData { val arrayDataName = ctx.freshName("arrayData") val numElements = elementsCode.length - if (!ctx.isPrimitiveType(elementType)) { + if (!CodeGenerator.isPrimitiveType(elementType)) { val arrayName = ctx.freshName("arrayObject") val genericArrayClass = classOf[GenericArrayData].getName @@ -124,7 +124,7 @@ private [sql] object GenArrayData { ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - val primitiveValueTypeName = ctx.primitiveTypeName(elementType) + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { s"$arrayDataName.setNullAt($i);" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 7e53ca3908905..6cdad19168dce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -129,12 +129,12 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] if ($eval.isNullAt($ordinal)) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)}; } """ } else { s""" - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)}; """ } }) @@ -205,7 +205,7 @@ case class GetArrayStructFields( } else { final InternalRow $row = $eval.getStruct($j, $numFields); $nullSafeEval { - $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)}; + $values[$j] = ${CodeGenerator.getValue(row, field.dataType, ordinal.toString)}; } } } @@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) if ($index >= $eval1.numElements() || $index < 0$nullCheck) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval1, dataType, index)}; + ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; } """ }) @@ -327,6 +327,7 @@ case class GetMapValue(child: Expression, key: Expression) } else { "" } + val keyJavaType = CodeGenerator.javaType(keyType) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int $length = $eval1.numElements(); @@ -336,7 +337,7 @@ case class GetMapValue(child: Expression, key: Expression) int $index = 0; boolean $found = false; while ($index < $length && !$found) { - final ${ctx.javaType(keyType)} $key = ${ctx.getValue(keys, keyType, index)}; + final $keyJavaType $key = ${CodeGenerator.getValue(keys, keyType, index)}; if (${ctx.genEqual(keyType, key, eval2)}) { $found = true; } else { @@ -347,7 +348,7 @@ case class GetMapValue(child: Expression, key: Expression) if (!$found$nullCheck) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(values, dataType, index)}; + ${ev.value} = ${CodeGenerator.getValue(values, dataType, index)}; } """ }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index b444c3a7be92a..f4e9619bac59d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -69,7 +69,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi s""" |${condEval.code} |boolean ${ev.isNull} = false; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${condEval.isNull} && ${condEval.value}) { | ${trueEval.code} | ${ev.isNull} = ${trueEval.isNull}; @@ -191,7 +191,7 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value) + ev.value = ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value) // these blocks are meant to be inside a // do { @@ -244,10 +244,10 @@ case class CaseWhen( val codes = ctx.splitExpressionsWithCurrentInputs( expressions = allConditions, funcName = "caseWhen", - returnType = ctx.JAVA_BYTE, + returnType = CodeGenerator.JAVA_BYTE, makeSplitFunction = func => s""" - |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $func |} while (false); @@ -264,7 +264,7 @@ case class CaseWhen( ev.copy(code = s""" - |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $codes |} while (false); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 424871f2047e9..1ae4e5a2f716b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -673,18 +673,19 @@ abstract class UnixTime } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaType = CodeGenerator.javaType(dataType) left.dataType match { case StringType if right.foldable => val df = classOf[DateFormat].getName if (formatter == null) { - ExprCode("", "true", ctx.defaultValue(dataType)) + ExprCode.forNullValue(dataType) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { ${ev.value} = $formatterName.parse(${eval1.value}.toString()).getTime() / 1000L; @@ -713,7 +714,7 @@ abstract class UnixTime ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = ${eval1.value} / 1000000L; }""") @@ -724,7 +725,7 @@ abstract class UnixTime ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.daysToMillis(${eval1.value}, $tz) / 1000L; }""") @@ -819,7 +820,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ ev.copy(code = s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { ${ev.value} = UTF8String.fromString($formatterName.format( @@ -1344,18 +1345,19 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { : ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val javaType = CodeGenerator.javaType(dataType) if (format.foldable) { if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { val t = instant.genCode(ctx) val truncFuncStr = truncFunc(t.value, truncLevel.toString) ev.copy(code = s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.$truncFuncStr; }""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 055ebf6c0da54..b702422ed7a1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -278,7 +278,7 @@ abstract class HashExpression[E] extends Expression { } } - val hashResultType = ctx.javaType(dataType) + val hashResultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", @@ -307,9 +307,10 @@ abstract class HashExpression[E] extends Expression { ctx: CodegenContext): String = { val element = ctx.freshName("element") + val jt = CodeGenerator.javaType(elementType) ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") { s""" - final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; + final $jt $element = ${CodeGenerator.getValue(input, elementType, index)}; ${computeHash(element, elementType, result, ctx)} """ } @@ -407,7 +408,7 @@ abstract class HashExpression[E] extends Expression { val fieldsHash = fields.zipWithIndex.map { case (field, index) => nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) } - val hashResultType = ctx.javaType(dataType) + val hashResultType = CodeGenerator.javaType(dataType) ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", @@ -651,11 +652,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", - extraArguments = Seq(ctx.JAVA_INT -> ev.value), - returnType = ctx.JAVA_INT, + extraArguments = Seq(CodeGenerator.JAVA_INT -> ev.value), + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" - |${ctx.JAVA_INT} $childHash = 0; + |${CodeGenerator.JAVA_INT} $childHash = 0; |$body |return ${ev.value}; """.stripMargin, @@ -664,8 +665,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { ev.copy(code = s""" - |${ctx.JAVA_INT} ${ev.value} = $seed; - |${ctx.JAVA_INT} $childHash = 0; + |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; + |${CodeGenerator.JAVA_INT} $childHash = 0; |$codes """.stripMargin) } @@ -780,14 +781,14 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """.stripMargin } - s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( + s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result), - returnType = ctx.JAVA_INT, + arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result), + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" - |${ctx.JAVA_INT} $childResult = 0; + |${CodeGenerator.JAVA_INT} $childResult = 0; |$body |return $result; """.stripMargin, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 7a8edabed1757..07785e7448586 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,7 +42,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getInputFilePath();", isNull = "false") } } @@ -65,7 +65,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getStartOffset();", isNull = "false") } } @@ -88,7 +88,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getLength();", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index c1e65e34c2ea6..7395609a04ba5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -277,13 +277,9 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def eval(input: InternalRow): Any = value override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) if (value == null) { - val defaultValueLiteral = ctx.defaultValue(javaType) match { - case "null" => s"(($javaType)null)" - case lit => lit - } - ExprCode(code = "", isNull = "true", value = defaultValueLiteral) + ExprCode.forNullValue(dataType) } else { dataType match { case BooleanType | IntegerType | DateType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index d8dc0862f1141..2c2cf3d2e6227 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1128,15 +1128,16 @@ abstract class RoundBase(child: Expression, scale: Expression, }""" } + val javaType = CodeGenerator.javaType(dataType) if (scaleV == null) { // if scale is null, no need to eval its child at all ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { ev.copy(code = s""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { $evaluationCode }""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 470d5da041ea5..b35fa72e95d1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -87,14 +87,14 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """.stripMargin } - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "coalesce", returnType = resultType, makeSplitFunction = func => s""" - |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { | $func |} while (false); @@ -113,7 +113,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { | $codes |} while (false); @@ -234,7 +234,7 @@ case class IsNaN(child: Expression) extends UnaryExpression case DoubleType | FloatType => ev.copy(code = s""" ${eval.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false") } } @@ -281,7 +281,7 @@ case class NaNvl(left: Expression, right: Expression) ev.copy(code = s""" ${leftGen.code} boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${leftGen.isNull}) { ${ev.isNull} = true; } else { @@ -416,8 +416,8 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "atLeastNNonNulls", - extraArguments = (ctx.JAVA_INT, nonnull) :: Nil, - returnType = ctx.JAVA_INT, + extraArguments = (CodeGenerator.JAVA_INT, nonnull) :: Nil, + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" |do { @@ -436,11 +436,11 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate ev.copy(code = s""" - |${ctx.JAVA_INT} $nonnull = 0; + |${CodeGenerator.JAVA_INT} $nonnull = 0; |do { | $codes |} while (false); - |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; """.stripMargin, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 64da9bb9cdec1..80618af1e859f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -62,13 +62,13 @@ trait InvokeLike extends Expression with NonSQLExpression { def prepareArguments(ctx: CodegenContext): (String, String, String) = { val resultIsNull = if (needNullCheck) { - val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull") + val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") resultIsNull } else { "false" } val argValues = arguments.map { e => - val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue") + val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue") argValue } @@ -137,7 +137,7 @@ case class StaticInvoke( throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -151,7 +151,7 @@ case class StaticInvoke( } val evaluate = if (returnNullable) { - if (ctx.defaultValue(dataType) == "null") { + if (CodeGenerator.defaultValue(dataType) == "null") { s""" ${ev.value} = $callFunc; ${ev.isNull} = ${ev.value} == null; @@ -159,7 +159,7 @@ case class StaticInvoke( } else { val boxedResult = ctx.freshName("boxedResult") s""" - ${ctx.boxedType(dataType)} $boxedResult = $callFunc; + ${CodeGenerator.boxedType(dataType)} $boxedResult = $callFunc; ${ev.isNull} = $boxedResult == null; if (!${ev.isNull}) { ${ev.value} = $boxedResult; @@ -173,7 +173,7 @@ case class StaticInvoke( val code = s""" $argCode $prepareIsNull - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!$resultIsNull) { $evaluate } @@ -228,7 +228,7 @@ case class Invoke( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val obj = targetObject.genCode(ctx) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -255,11 +255,11 @@ case class Invoke( // If the function can return null, we do an extra check to make sure our null bit is still // set correctly. val assignResult = if (!returnNullable) { - s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;" + s"${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;" } else { s""" if ($funcResult != null) { - ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; + ${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult; } else { ${ev.isNull} = true; } @@ -275,7 +275,7 @@ case class Invoke( val code = s""" ${obj.code} boolean ${ev.isNull} = true; - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${obj.isNull}) { $argCode ${ev.isNull} = $resultIsNull; @@ -341,7 +341,7 @@ case class NewInstance( throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -358,7 +358,8 @@ case class NewInstance( val code = s""" $argCode ${outer.map(_.code).getOrElse("")} - final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall; + final $javaType ${ev.value} = ${ev.isNull} ? + ${CodeGenerator.defaultValue(dataType)} : $constructorCall; """ ev.copy(code = code) } @@ -385,15 +386,15 @@ case class UnwrapOption( throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val inputObject = child.genCode(ctx) val code = s""" ${inputObject.code} final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); - $javaType ${ev.value} = ${ev.isNull} ? - ${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get(); + $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : + (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); """ ev.copy(code = code) } @@ -546,7 +547,7 @@ case class MapObjects private( ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val elementJavaType = ctx.javaType(loopVarDataType) + val elementJavaType = CodeGenerator.javaType(loopVarDataType) ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) @@ -554,7 +555,7 @@ case class MapObjects private( val convertedArray = ctx.freshName("convertedArray") val loopIndex = ctx.freshName("loopIndex") - val convertedType = ctx.boxedType(lambdaFunction.dataType) + val convertedType = CodeGenerator.boxedType(lambdaFunction.dataType) // Because of the way Java defines nested arrays, we have to handle the syntax specially. // Specifically, we have to insert the [$dataLength] in between the type and any extra nested @@ -621,7 +622,7 @@ case class MapObjects private( ( s"${genInputData.value}.numElements()", "", - ctx.getValue(genInputData.value, et, loopIndex) + CodeGenerator.getValue(genInputData.value, et, loopIndex) ) case ObjectType(cls) if cls == classOf[Object] => val it = ctx.freshName("it") @@ -643,7 +644,8 @@ case class MapObjects private( } val loopNullCheck = if (loopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" case _ => s"$loopIsNull = $loopValue == null;" @@ -695,7 +697,7 @@ case class MapObjects private( val code = s""" ${genInputData.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { $determineCollectionType @@ -806,10 +808,10 @@ case class CatalystToExternalMap private( } val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] - val keyElementJavaType = ctx.javaType(mapType.keyType) + val keyElementJavaType = CodeGenerator.javaType(mapType.keyType) ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false) val genKeyFunction = keyLambdaFunction.genCode(ctx) - val valueElementJavaType = ctx.javaType(mapType.valueType) + val valueElementJavaType = CodeGenerator.javaType(mapType.valueType) ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true, useFreshName = false) val genValueFunction = valueLambdaFunction.genCode(ctx) @@ -825,10 +827,11 @@ case class CatalystToExternalMap private( val valueArray = ctx.freshName("valueArray") val getKeyArray = s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();" - val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) + val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) val getValueArray = s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();" - val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex) + val getValueLoopVar = CodeGenerator.getValue( + valueArray, inputDataType(mapType.valueType), loopIndex) // Make a copy of the data if it's unsafe-backed def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = @@ -844,7 +847,7 @@ case class CatalystToExternalMap private( val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) val valueLoopNullCheck = if (valueLoopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, useFreshName = false) s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" } else { @@ -873,7 +876,7 @@ case class CatalystToExternalMap private( val code = s""" ${genInputData.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { int $dataLength = $getLength; @@ -993,8 +996,8 @@ case class ExternalMapToCatalyst private( val entry = ctx.freshName("entry") val entries = ctx.freshName("entries") - val keyElementJavaType = ctx.javaType(keyType) - val valueElementJavaType = ctx.javaType(valueType) + val keyElementJavaType = CodeGenerator.javaType(keyType) + val valueElementJavaType = CodeGenerator.javaType(valueType) ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false) ctx.addMutableState(valueElementJavaType, value, forceInline = true, useFreshName = false) @@ -1009,8 +1012,8 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); - $key = (${ctx.boxedType(keyType)}) $entry.getKey(); - $value = (${ctx.boxedType(valueType)}) $entry.getValue(); + $key = (${CodeGenerator.boxedType(keyType)}) $entry.getKey(); + $value = (${CodeGenerator.boxedType(valueType)}) $entry.getValue(); """ defineEntries -> defineKeyValue @@ -1024,22 +1027,24 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); - $key = (${ctx.boxedType(keyType)}) $entry._1(); - $value = (${ctx.boxedType(valueType)}) $entry._2(); + $key = (${CodeGenerator.boxedType(keyType)}) $entry._1(); + $value = (${CodeGenerator.boxedType(valueType)}) $entry._2(); """ defineEntries -> defineKeyValue } val keyNullCheck = if (keyIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) s"$keyIsNull = $key == null;" } else { "" } val valueNullCheck = if (valueIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) s"$valueIsNull = $value == null;" } else { "" @@ -1047,12 +1052,12 @@ case class ExternalMapToCatalyst private( val arrayCls = classOf[GenericArrayData].getName val mapCls = classOf[ArrayBasedMapData].getName - val convertedKeyType = ctx.boxedType(keyConverter.dataType) - val convertedValueType = ctx.boxedType(valueConverter.dataType) + val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) + val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) val code = s""" ${inputMap.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); final Object[] $convertedKeys = new Object[$length]; @@ -1174,12 +1179,13 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) // Code to serialize. val input = child.genCode(ctx) - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val serialize = s"$serializer.serialize(${input.value}, null).array()" val code = s""" ${input.code} - final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $serialize; + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize; """ ev.copy(code = code, isNull = input.isNull) } @@ -1223,13 +1229,14 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B // Code to deserialize. val input = child.genCode(ctx) - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val deserialize = s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" val code = s""" ${input.code} - final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize; + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; """ ev.copy(code = code, isNull = input.isNull) } @@ -1254,7 +1261,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val instanceGen = beanInstance.genCode(ctx) val javaBeanInstance = ctx.freshName("javaBean") - val beanInstanceJavaType = ctx.javaType(beanInstance.dataType) + val beanInstanceJavaType = CodeGenerator.javaType(beanInstance.dataType) val initialize = setters.map { case (setterMethod, fieldValue) => @@ -1405,15 +1412,15 @@ case class ValidateExternalType(child: Expression, expected: DataType) case _: ArrayType => s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" case _ => - s"$obj instanceof ${ctx.boxedType(dataType)}" + s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } val code = s""" ${input.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { if ($typeCheck) { - ${ev.value} = (${ctx.boxedType(dataType)}) $obj; + ${ev.value} = (${CodeGenerator.boxedType(dataType)}) $obj; } else { throw new RuntimeException($obj.getClass().getName() + $errMsgField); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a6d41ea7d00d4..4b85d9adbe311 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -235,7 +235,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaDataType = ctx.javaType(value.dataType) + val javaDataType = CodeGenerator.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) // inTmpResult has 3 possible values: @@ -263,8 +263,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = listCode, funcName = "valueIn", - extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil, - returnType = ctx.JAVA_BYTE, + extraArguments = (javaDataType, valueArg) :: (CodeGenerator.JAVA_BYTE, tmpResult) :: Nil, + returnType = CodeGenerator.JAVA_BYTE, makeSplitFunction = body => s""" |do { @@ -348,8 +348,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with ev.copy(code = s""" |${childGen.code} - |${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; - |${ctx.JAVA_BOOLEAN} ${ev.value} = false; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; |if (!${ev.isNull}) { | ${ev.value} = $setTerm.contains(${childGen.value}); | $setIsNull @@ -505,7 +505,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (ctx.isPrimitiveType(left.dataType) + if (CodeGenerator.isPrimitiveType(left.dataType) && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType && left.dataType != DoubleType) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8bc936fcbfc31..6c9937dacc70b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -82,7 +82,8 @@ case class Rand(child: Expression) extends RDG { ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", + isNull = "false") } } @@ -116,7 +117,8 @@ case class Randn(child: Expression) extends RDG { ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", + isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index f3e8f6de58975..ad0c0791d895f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -126,7 +126,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); } @@ -134,7 +134,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } else { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) } } else { @@ -201,7 +201,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0); } @@ -209,7 +209,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } else { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) } } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index d7612e30b4c57..22fbb8998ed89 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -102,11 +102,11 @@ case class Concat(children: Seq[Expression]) extends Expression { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcat", - extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil) + extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil) ev.copy(s""" $initCode $codes - ${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args); + ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args); boolean ${ev.isNull} = ${ev.value} == null; """) } @@ -196,7 +196,7 @@ case class ConcatWs(children: Seq[Expression]) } else { val array = ctx.freshName("array") val varargNum = ctx.freshName("varargNum") - val idxInVararg = ctx.freshName("idxInVararg") + val idxVararg = ctx.freshName("idxInVararg") val evals = children.map(_.genCode(ctx)) val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) => @@ -206,7 +206,7 @@ case class ConcatWs(children: Seq[Expression]) if (eval.isNull == "true") { "" } else { - s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" + s"$array[$idxVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" }) case _: ArrayType => val size = ctx.freshName("n") @@ -222,7 +222,7 @@ case class ConcatWs(children: Seq[Expression]) if (!${eval.isNull}) { final int $size = ${eval.value}.numElements(); for (int j = 0; j < $size; j ++) { - $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")}; + $array[$idxVararg ++] = ${CodeGenerator.getValue(eval.value, StringType, "j")}; } } """) @@ -247,20 +247,20 @@ case class ConcatWs(children: Seq[Expression]) val varargBuilds = ctx.splitExpressionsWithCurrentInputs( expressions = varargBuild, funcName = "varargBuildsConcatWs", - extraArguments = ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, + extraArguments = ("UTF8String []", array) :: ("int", idxVararg) :: Nil, returnType = "int", makeSplitFunction = body => s""" |$body - |return $idxInVararg; + |return $idxVararg; """.stripMargin, - foldFunctions = _.map(funcCall => s"$idxInVararg = $funcCall;").mkString("\n")) + foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) ev.copy( s""" $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; - int $idxInVararg = 0; + int $idxVararg = 0; $varargCounts UTF8String[] $array = new UTF8String[$varargNum]; $varargBuilds @@ -333,7 +333,7 @@ case class Elt(children: Seq[Expression]) extends Expression { val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal") + val inputVal = ctx.addMutableState(CodeGenerator.javaType(dataType), "inputVal") val assignInputValue = inputs.zipWithIndex.map { case (eval, index) => s""" @@ -350,10 +350,10 @@ case class Elt(children: Seq[Expression]) extends Expression { expressions = assignInputValue, funcName = "eltFunc", extraArguments = ("int", indexVal) :: Nil, - returnType = ctx.JAVA_BOOLEAN, + returnType = CodeGenerator.JAVA_BOOLEAN, makeSplitFunction = body => s""" - |${ctx.JAVA_BOOLEAN} $indexMatched = false; + |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; |do { | $body |} while (false); @@ -372,12 +372,12 @@ case class Elt(children: Seq[Expression]) extends Expression { s""" |${index.code} |final int $indexVal = ${index.value}; - |${ctx.JAVA_BOOLEAN} $indexMatched = false; + |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; |$inputVal = null; |do { | $codes |} while (false); - |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal; + |final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } @@ -1410,10 +1410,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val numArgLists = argListGen.length val argListCode = argListGen.zipWithIndex.map { case(v, index) => val value = - if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { + if (CodeGenerator.boxedType(v._1) != CodeGenerator.javaType(v._1)) { // Java primitives get boxed in order to allow null values. - s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + - s"new ${ctx.boxedType(v._1)}(${v._2.value})" + s"(${v._2.isNull}) ? (${CodeGenerator.boxedType(v._1)}) null : " + + s"new ${CodeGenerator.boxedType(v._1)}(${v._2.value})" } else { s"(${v._2.isNull}) ? null : ${v._2.value}" } @@ -1434,7 +1434,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC ev.copy(code = s""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { $stringBuffer $sb = new $stringBuffer(); $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); @@ -2110,7 +2110,8 @@ case class FormatNumber(x: Expression, d: Expression) val usLocale = "US" val i = ctx.freshName("i") val dFormat = ctx.freshName("dFormat") - val lastDValue = ctx.addMutableState(ctx.JAVA_INT, "lastDValue", v => s"$v = -100;") + val lastDValue = + ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") val numberFormat = ctx.addMutableState(df, "numberFormat", v => s"""$v = new $df("", new $dfs($l.$usLocale));""") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 676ba3956ddc8..1e48c7b8df9da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -405,12 +405,12 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-18016: define mutable states by using an array") { val ctx1 = new CodegenContext for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) { - ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;") + ctx1.addMutableState(CodeGenerator.JAVA_INT, "i", v => s"$v = $i;") } assert(ctx1.inlinedMutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) // When the number of primitive type mutable states is over the threshold, others are // allocated into an array - assert(ctx1.arrayCompactedMutableStates.get(ctx1.JAVA_INT).get.arrayNames.size == 1) + assert(ctx1.arrayCompactedMutableStates.get(CodeGenerator.JAVA_INT).get.arrayNames.size == 1) assert(ctx1.mutableStateInitCode.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) val ctx2 = new CodegenContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 04f2619ed7541..392906a022903 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -49,15 +49,15 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { ordinal: String, dataType: DataType, nullable: Boolean): ExprCode = { - val javaType = ctx.javaType(dataType) - val value = ctx.getValueFromVector(columnVar, dataType, ordinal) + val javaType = CodeGenerator.javaType(dataType) + val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { s""" boolean $isNullVar = $columnVar.isNullAt($ordinal); - $javaType $valueVar = $isNullVar ? ${ctx.defaultValue(dataType)} : ($value); + $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); """ } else { s"$javaType $valueVar = $value;" @@ -85,12 +85,13 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") - val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") // init as scanTime = 0 + val scanTimeTotalNs = + ctx.addMutableState(CodeGenerator.JAVA_LONG, "scanTime") // init as scanTime = 0 val columnarBatchClz = classOf[ColumnarBatch].getName val batch = ctx.addMutableState(columnarBatchClz, "batch") - val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") // init as batchIdx = 0 + val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0 val columnVectorClzs = vectorTypes.getOrElse( Seq.fill(output.indices.size)(classOf[ColumnVector].getName)) val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index a7bd5ebf93ecd..12ae1ea4a7c13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -154,7 +154,8 @@ case class ExpandExec( val value = ctx.freshName("value") val code = s""" |boolean $isNull = true; - |${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)}; + |${CodeGenerator.javaType(firstExpr.dataType)} $value = + | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin ExprCode(code, isNull, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 0c2c4a1a9100d..384f0398a1ec0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} @@ -305,15 +305,15 @@ case class GenerateExec( nullable: Boolean, initialChecks: Seq[String]): ExprCode = { val value = ctx.freshName(name) - val javaType = ctx.javaType(dt) - val getter = ctx.getValue(source, dt, index) + val javaType = CodeGenerator.javaType(dt) + val getter = CodeGenerator.getValue(source, dt, index) val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)") if (checks.nonEmpty) { val isNull = ctx.freshName("isNull") val code = s""" |boolean $isNull = ${checks.mkString(" || ")}; - |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter; + |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin ExprCode(code, isNull, value) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index ac1c34d41c4f1..0dc16ba5ce281 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics @@ -133,7 +133,8 @@ case class SortExec( override def needStopCheck: Boolean = false override protected def doProduce(ctx: CodegenContext): String = { - val needToSort = ctx.addMutableState(ctx.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") + val needToSort = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index deb0a044c2fb2..f89e3fb0e536f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -234,7 +234,7 @@ trait CodegenSupport extends SparkPlan { variables.zipWithIndex.foreach { case (ev, i) => val paramName = ctx.freshName(s"expr_$i") - val paramType = ctx.javaType(attributes(i).dataType) + val paramType = CodeGenerator.javaType(attributes(i).dataType) arguments += ev.value parameters += s"$paramType $paramName" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index ce3c68810f3b6..1926e9373bc55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -178,7 +178,7 @@ case class HashAggregateExec( private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { - val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") + val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") // The generated function doesn't have input row in the code context. ctx.INPUT_ROW = null @@ -186,8 +186,8 @@ case class HashAggregateExec( val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) bufVars = initExpr.map { e => - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) val initVars = s""" @@ -532,7 +532,7 @@ case class HashAggregateExec( */ private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = { val isSupported = - (groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) || + (groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) || f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) && bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge) @@ -565,7 +565,7 @@ case class HashAggregateExec( } private def doProduceWithKeys(ctx: CodegenContext): String = { - val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") + val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else { @@ -757,7 +757,7 @@ case class HashAggregateExec( val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter") + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") (s"$countTerm < ${testFallbackStartsAt.get._1}", s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") } else { @@ -832,7 +832,7 @@ case class HashAggregateExec( } val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) } s""" |// common sub-expressions @@ -855,7 +855,7 @@ case class HashAggregateExec( } val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn( + CodeGenerator.updateColumn( fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 1c613b19c4ab1..6b60b414ffe5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ /** @@ -41,13 +41,13 @@ abstract class HashMapGenerator( val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key"))) val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value"))) val groupingKeySignature = - groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ") + groupingKeys.map(key => s"${CodeGenerator.javaType(key.dataType)} ${key.name}").mkString(", ") val buffVars: Seq[ExprCode] = { val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) initExpr.map { e => - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index fd25707dd4ca6..8617be88f3570 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} import org.apache.spark.sql.types._ /** @@ -114,7 +114,7 @@ class RowBasedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue("row", + s"""(${ctx.genEqual(key.dataType, CodeGenerator.getValue("row", key.dataType, ordinal.toString()), key.name)})""" }.mkString(" && ") } @@ -147,7 +147,7 @@ class RowBasedHashMapGenerator( case t: DecimalType => s"agg_rowWriter.write(${ordinal}, ${key.name}, ${t.precision}, ${t.scale})" case t: DataType => - if (!t.isInstanceOf[StringType] && !ctx.isPrimitiveType(t)) { + if (!t.isInstanceOf[StringType] && !CodeGenerator.isPrimitiveType(t)) { throw new IllegalArgumentException(s"cannot generate code for unsupported type: $t") } s"agg_rowWriter.write(${ordinal}, ${key.name})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 633eeac180974..7b3580cecc60d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -127,7 +127,8 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - val value = ctx.getValueFromVector(s"vectors[$ordinal]", key.dataType, "buckets[idx]") + val value = CodeGenerator.getValueFromVector(s"vectors[$ordinal]", key.dataType, + "buckets[idx]") s"(${ctx.genEqual(key.dataType, value, key.name)})" }.mkString(" && ") } @@ -182,14 +183,14 @@ class VectorizedHashMapGenerator( def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) + CodeGenerator.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) } } def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = { bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType, - buffVars(ordinal), nullable = true) + CodeGenerator.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", + key.dataType, buffVars(ordinal), nullable = true) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index a15a8d11aa2a0..4707022f74547 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, ExpressionCanonicalizer} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType @@ -364,8 +364,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) protected override def doProduce(ctx: CodegenContext): String = { val numOutput = metricTerm(ctx, "numOutputRows") - val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange") - val number = ctx.addMutableState(ctx.JAVA_LONG, "number") + val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange") + val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") val value = ctx.freshName("value") val ev = ExprCode("", "false", value) @@ -385,10 +385,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // the metrics. // Once number == batchEnd, it's time to progress to the next batch. - val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd") + val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd") // How many values should still be generated by this range operator. - val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo") + val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo") // How many values should be generated in the next batch. val nextBatchTodo = ctx.freshName("nextBatchTodo") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 4f28eeb725cbb..3b5655ba0582e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -91,7 +91,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera val accessorName = ctx.addMutableState(accessorCls, "accessor") val createCode = dt match { - case t if ctx.isPrimitiveType(dt) => + case t if CodeGenerator.isPrimitiveType(dt) => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" case NullType | StringType | BinaryType => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 1918fcc5482db..487d6a2383318 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} @@ -182,9 +182,10 @@ case class BroadcastHashJoinExec( // the variables are needed even there is no matched rows val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") + val javaType = CodeGenerator.javaType(a.dataType) val code = s""" |boolean $isNull = true; - |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; + |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; |if ($matched != null) { | ${ev.code} | $isNull = ${ev.isNull}; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 2de2f30eb05d3..5a511b30e4fd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, @@ -516,9 +516,9 @@ case class SortMergeJoinExec( ctx.INPUT_ROW = leftRow left.output.zipWithIndex.map { case (a, i) => val value = ctx.freshName("value") - val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) - val javaType = ctx.javaType(a.dataType) - val defaultValue = ctx.defaultValue(a.dataType) + val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString) + val javaType = CodeGenerator.javaType(a.dataType) + val defaultValue = CodeGenerator.defaultValue(a.dataType) if (a.nullable) { val isNull = ctx.freshName("isNull") val code = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index cccee63bc0680..66bcda8913738 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.util.Utils @@ -71,7 +71,8 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false + val stopEarly = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false ctx.addNewFunction("stopEarly", s""" @Override @@ -79,7 +80,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { return $stopEarly; } """, inlineToOuterClass = true) - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") // init as count = 0 + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0 s""" | if ($countTerm < $limit) { | $countTerm += 1; From 42cf48e20cd5e47e1b7557af9c71c4eea142f10f Mon Sep 17 00:00:00 2001 From: Ala Luszczak Date: Mon, 5 Mar 2018 14:33:12 +0100 Subject: [PATCH 424/774] [SPARK-23496][CORE] Locality of coalesced partitions can be severely skewed by the order of input partitions ## What changes were proposed in this pull request? The algorithm in `DefaultPartitionCoalescer.setupGroups` is responsible for picking preferred locations for coalesced partitions. It analyzes the preferred locations of input partitions. It starts by trying to create one partition for each unique location in the input. However, if the the requested number of coalesced partitions is higher that the number of unique locations, it has to pick duplicate locations. Previously, the duplicate locations would be picked by iterating over the input partitions in order, and copying their preferred locations to coalesced partitions. If the input partitions were clustered by location, this could result in severe skew. With the fix, instead of iterating over the list of input partitions in order, we pick them at random. It's not perfectly balanced, but it's much better. ## How was this patch tested? Unit test reproducing the behavior was added. Author: Ala Luszczak Closes #20664 from ala/SPARK-23496. --- .../org/apache/spark/rdd/CoalescedRDD.scala | 8 ++-- .../scala/org/apache/spark/rdd/RDDSuite.scala | 42 +++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 10451a324b0f4..94e7d0b38cba3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -266,17 +266,17 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10) numCreated += 1 } } - tries = 0 // if we don't have enough partition groups, create duplicates while (numCreated < targetLen) { - val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(tries) - tries += 1 + // Copy the preferred location from a random input partition. + // This helps in avoiding skew when the input partitions are clustered by preferred location. + val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs( + rnd.nextInt(partitionLocs.partsWithLocs.length)) val pgroup = new PartitionGroup(Some(nxt_replica)) groupArr += pgroup groupHash.getOrElseUpdate(nxt_replica, ArrayBuffer()) += pgroup addPartToPGroup(nxt_part, pgroup) numCreated += 1 - if (tries >= partitionLocs.partsWithLocs.length) tries = 0 } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index e994d724c462f..191c61250ce21 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -1129,6 +1129,35 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { }.collect() } + test("SPARK-23496: order of input partitions can result in severe skew in coalesce") { + val numInputPartitions = 100 + val numCoalescedPartitions = 50 + val locations = Array("locA", "locB") + + val inputRDD = sc.makeRDD(Range(0, numInputPartitions).toArray[Int], numInputPartitions) + assert(inputRDD.getNumPartitions == numInputPartitions) + + val locationPrefRDD = new LocationPrefRDD(inputRDD, { (p: Partition) => + if (p.index < numCoalescedPartitions) { + Seq(locations(0)) + } else { + Seq(locations(1)) + } + }) + val coalescedRDD = new CoalescedRDD(locationPrefRDD, numCoalescedPartitions) + + val numPartsPerLocation = coalescedRDD + .getPartitions + .map(coalescedRDD.getPreferredLocations(_).head) + .groupBy(identity) + .mapValues(_.size) + + // Make sure the coalesced partitions are distributed fairly evenly between the two locations. + // This should not become flaky since the DefaultPartitionsCoalescer uses a fixed seed. + assert(numPartsPerLocation(locations(0)) > 0.4 * numCoalescedPartitions) + assert(numPartsPerLocation(locations(1)) > 0.4 * numCoalescedPartitions) + } + // NOTE // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests // running after them and if they access sc those tests will fail as sc is already closed, because @@ -1210,3 +1239,16 @@ class SizeBasedCoalescer(val maxSize: Int) extends PartitionCoalescer with Seria groups.toArray } } + +/** Alters the preferred locations of the parent RDD using provided function. */ +class LocationPrefRDD[T: ClassTag]( + @transient var prev: RDD[T], + val locationPicker: Partition => Seq[String]) extends RDD[T](prev) { + override protected def getPartitions: Array[Partition] = prev.partitions + + override def compute(partition: Partition, context: TaskContext): Iterator[T] = + null.asInstanceOf[Iterator[T]] + + override def getPreferredLocations(partition: Partition): Seq[String] = + locationPicker(partition) +} From 5ff72ffcf495d2823f7f1186078d1cb261667c3d Mon Sep 17 00:00:00 2001 From: Anirudh Date: Mon, 5 Mar 2018 23:17:16 +0900 Subject: [PATCH 425/774] [SPARK-23566][MINOR][DOC] Argument name mismatch fixed Argument name mismatch fixed. ## What changes were proposed in this pull request? `col` changed to `new` in doc string to match the argument list. Patch file added: https://issues.apache.org/jira/browse/SPARK-23566 Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Anirudh Closes #20716 from animenon/master. --- python/pyspark/sql/dataframe.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index f37777e13ee12..9d8e85cde914f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -588,6 +588,8 @@ def coalesce(self, numPartitions): """ Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions. + :param numPartitions: int, to specify the target number of partitions + Similar to coalesce defined on an :class:`RDD`, this operation results in a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of the 100 new partitions will @@ -612,9 +614,10 @@ def repartition(self, numPartitions, *cols): Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The resulting DataFrame is hash partitioned. - ``numPartitions`` can be an int to specify the target number of partitions or a Column. - If it is a Column, it will be used as the first partitioning column. If not specified, - the default number of partitions is used. + :param numPartitions: + can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. .. versionchanged:: 1.6 Added optional arguments to specify the partitioning columns. Also made numPartitions @@ -673,9 +676,10 @@ def repartitionByRange(self, numPartitions, *cols): Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The resulting DataFrame is range partitioned. - ``numPartitions`` can be an int to specify the target number of partitions or a Column. - If it is a Column, it will be used as the first partitioning column. If not specified, - the default number of partitions is used. + :param numPartitions: + can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. At least one partition-by expression must be specified. When no explicit sort order is specified, "ascending nulls first" is assumed. @@ -892,6 +896,8 @@ def colRegex(self, colName): def alias(self, alias): """Returns a new :class:`DataFrame` with an alias set. + :param alias: string, an alias name to be set for the DataFrame. + >>> from pyspark.sql.functions import * >>> df_as1 = df.alias("df_as1") >>> df_as2 = df.alias("df_as2") @@ -1900,7 +1906,7 @@ def withColumnRenamed(self, existing, new): This is a no-op if schema doesn't contain the given column name. :param existing: string, name of the existing column to rename. - :param col: string, new name of the column. + :param new: string, new name of the column. >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] From a366b950b90650693ad0eb1e5b9a988ad028d845 Mon Sep 17 00:00:00 2001 From: Mihaly Toth Date: Mon, 5 Mar 2018 23:46:40 +0900 Subject: [PATCH 426/774] [SPARK-23329][SQL] Fix documentation of trigonometric functions ## What changes were proposed in this pull request? Provide more details in trigonometric function documentations. Referenced `java.lang.Math` for further details in the descriptions. ## How was this patch tested? Ran full build, checked generated documentation manually Author: Mihaly Toth Closes #20618 from misutoth/trigonometric-doc. --- R/pkg/R/functions.R | 34 ++-- python/pyspark/sql/functions.py | 62 ++++--- .../expressions/mathExpressions.scala | 99 ++++++++--- .../org/apache/spark/sql/functions.scala | 160 ++++++++++++------ 4 files changed, 248 insertions(+), 107 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9f7c6317cd924..29ee146ab14f9 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -278,8 +278,8 @@ setMethod("abs", }) #' @details -#' \code{acos}: Computes the cosine inverse of the given value; the returned angle is in -#' the range 0.0 through pi. +#' \code{acos}: Returns the inverse cosine of the given value, +#' as if computed by \code{java.lang.Math.acos()} #' #' @rdname column_math_functions #' @export @@ -334,8 +334,8 @@ setMethod("ascii", }) #' @details -#' \code{asin}: Computes the sine inverse of the given value; the returned angle is in -#' the range -pi/2 through pi/2. +#' \code{asin}: Returns the inverse sine of the given value, +#' as if computed by \code{java.lang.Math.asin()} #' #' @rdname column_math_functions #' @export @@ -349,8 +349,8 @@ setMethod("asin", }) #' @details -#' \code{atan}: Computes the tangent inverse of the given value; the returned angle is in the range -#' -pi/2 through pi/2. +#' \code{atan}: Returns the inverse tangent of the given value, +#' as if computed by \code{java.lang.Math.atan()} #' #' @rdname column_math_functions #' @export @@ -613,7 +613,8 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr }) #' @details -#' \code{cos}: Computes the cosine of the given value. Units in radians. +#' \code{cos}: Returns the cosine of the given value, +#' as if computed by \code{java.lang.Math.cos()}. Units in radians. #' #' @rdname column_math_functions #' @aliases cos cos,Column-method @@ -627,7 +628,8 @@ setMethod("cos", }) #' @details -#' \code{cosh}: Computes the hyperbolic cosine of the given value. +#' \code{cosh}: Returns the hyperbolic cosine of the given value, +#' as if computed by \code{java.lang.Math.cosh()}. #' #' @rdname column_math_functions #' @aliases cosh cosh,Column-method @@ -1463,7 +1465,8 @@ setMethod("sign", signature(x = "Column"), }) #' @details -#' \code{sin}: Computes the sine of the given value. Units in radians. +#' \code{sin}: Returns the sine of the given value, +#' as if computed by \code{java.lang.Math.sin()}. Units in radians. #' #' @rdname column_math_functions #' @aliases sin sin,Column-method @@ -1477,7 +1480,8 @@ setMethod("sin", }) #' @details -#' \code{sinh}: Computes the hyperbolic sine of the given value. +#' \code{sinh}: Returns the hyperbolic sine of the given value, +#' as if computed by \code{java.lang.Math.sinh()}. #' #' @rdname column_math_functions #' @aliases sinh sinh,Column-method @@ -1653,7 +1657,9 @@ setMethod("sumDistinct", }) #' @details -#' \code{tan}: Computes the tangent of the given value. Units in radians. +#' \code{tan}: Returns the tangent of the given value, +#' as if computed by \code{java.lang.Math.tan()}. +#' Units in radians. #' #' @rdname column_math_functions #' @aliases tan tan,Column-method @@ -1667,7 +1673,8 @@ setMethod("tan", }) #' @details -#' \code{tanh}: Computes the hyperbolic tangent of the given value. +#' \code{tanh}: Returns the hyperbolic tangent of the given value, +#' as if computed by \code{java.lang.Math.tanh()}. #' #' @rdname column_math_functions #' @aliases tanh tanh,Column-method @@ -1973,7 +1980,8 @@ setMethod("year", #' @details #' \code{atan2}: Returns the angle theta from the conversion of rectangular coordinates -#' (x, y) to polar coordinates (r, theta). Units in radians. +#' (x, y) to polar coordinates (r, theta), +#' as if computed by \code{java.lang.Math.atan2()}. Units in radians. #' #' @rdname column_math_functions #' @aliases atan2 atan2,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9bb9c323a5a60..b9c0c57262c5d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -106,18 +106,15 @@ def _(): _functions_1_4 = { # unary math functions - 'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' + - '0.0 through pi.', - 'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' + - '-pi/2 through pi/2.', - 'atan': 'Computes the tangent inverse of the given value; the returned angle is in the range' + - '-pi/2 through pi/2', + 'acos': ':return: inverse cosine of `col`, as if computed by `java.lang.Math.acos()`', + 'asin': ':return: inverse sine of `col`, as if computed by `java.lang.Math.asin()`', + 'atan': ':return: inverse tangent of `col`, as if computed by `java.lang.Math.atan()`', 'cbrt': 'Computes the cube-root of the given value.', 'ceil': 'Computes the ceiling of the given value.', - 'cos': """Computes the cosine of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'cosh': 'Computes the hyperbolic cosine of the given value.', + 'cos': """:param col: angle in radians + :return: cosine of the angle, as if computed by `java.lang.Math.cos()`.""", + 'cosh': """:param col: hyperbolic angle + :return: hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh()`""", 'exp': 'Computes the exponential of the given value.', 'expm1': 'Computes the exponential of the given value minus one.', 'floor': 'Computes the floor of the given value.', @@ -127,14 +124,16 @@ def _(): 'rint': 'Returns the double value that is closest in value to the argument and' + ' is equal to a mathematical integer.', 'signum': 'Computes the signum of the given value.', - 'sin': """Computes the sine of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'sinh': 'Computes the hyperbolic sine of the given value.', - 'tan': """Computes the tangent of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'tanh': 'Computes the hyperbolic tangent of the given value.', + 'sin': """:param col: angle in radians + :return: sine of the angle, as if computed by `java.lang.Math.sin()`""", + 'sinh': """:param col: hyperbolic angle + :return: hyperbolic sine of the given value, + as if computed by `java.lang.Math.sinh()`""", + 'tan': """:param col: angle in radians + :return: tangent of the given value, as if computed by `java.lang.Math.tan()`""", + 'tanh': """:param col: hyperbolic angle + :return: hyperbolic tangent of the given value, + as if computed by `java.lang.Math.tanh()`""", 'toDegrees': '.. note:: Deprecated in 2.1, use :func:`degrees` instead.', 'toRadians': '.. note:: Deprecated in 2.1, use :func:`radians` instead.', 'bitwiseNOT': 'Computes bitwise not.', @@ -173,16 +172,31 @@ def _(): _functions_2_1 = { # unary math functions - 'degrees': 'Converts an angle measured in radians to an approximately equivalent angle ' + - 'measured in degrees.', - 'radians': 'Converts an angle measured in degrees to an approximately equivalent angle ' + - 'measured in radians.', + 'degrees': """ + Converts an angle measured in radians to an approximately equivalent angle + measured in degrees. + :param col: angle in radians + :return: angle in degrees, as if computed by `java.lang.Math.toDegrees()` + """, + 'radians': """ + Converts an angle measured in degrees to an approximately equivalent angle + measured in radians. + :param col: angle in degrees + :return: angle in radians, as if computed by `java.lang.Math.toRadians()` + """, } # math functions that take two arguments as input _binary_mathfunctions = { - 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + - 'polar coordinates (r, theta). Units in radians.', + 'atan2': """ + :param col1: coordinate on y-axis + :param col2: coordinate on x-axis + :return: the `theta` component of the point + (`r`, `theta`) + in polar coordinates that corresponds to the point + (`x`, `y`) in Cartesian coordinates, + as if computed by `java.lang.Math.atan2()` + """, 'hypot': 'Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow.', 'pow': 'Returns the value of the first argument raised to the power of the second argument.', } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 2c2cf3d2e6227..bc4cfcec47425 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -168,9 +168,11 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse cosine (a.k.a. arccosine) of `expr` if -1<=`expr`<=1 or NaN otherwise.", + usage = """ + _FUNC_(expr) - Returns the inverse cosine (a.k.a. arc cosine) of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, examples = """ Examples: > SELECT _FUNC_(1); @@ -178,12 +180,13 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") > SELECT _FUNC_(2); NaN """) -// scalastyle:on line.size.limit case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse sine (a.k.a. arcsine) the arc sin of `expr` if -1<=`expr`<=1 or NaN otherwise.", + usage = """ + _FUNC_(expr) - Returns the inverse sine (a.k.a. arc sine) the arc sin of `expr`, + as if computed by `java.lang.Math._FUNC_`. + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -191,18 +194,18 @@ case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS" > SELECT _FUNC_(2); NaN """) -// scalastyle:on line.size.limit case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse tangent (a.k.a. arctangent).", + usage = """ + _FUNC_(expr) - Returns the inverse tangent (a.k.a. arc tangent) of `expr`, as if computed by + `java.lang.Math._FUNC_` + """, examples = """ Examples: > SELECT _FUNC_(0); 0.0 """) -// scalastyle:on line.size.limit case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") @ExpressionDescription( @@ -252,7 +255,14 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the cosine of `expr`.", + usage = """ + _FUNC_(expr) - Returns the cosine of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -261,7 +271,14 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic cosine of `expr`.", + usage = """ + _FUNC_(expr) - Returns the hyperbolic cosine of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -512,7 +529,11 @@ case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the sine of `expr`.", + usage = "_FUNC_(expr) - Returns the sine of `expr`, as if computed by `java.lang.Math._FUNC_`.", + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -521,7 +542,13 @@ case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "S case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic sine of `expr`.", + usage = """ + _FUNC_(expr) - Returns hyperbolic sine of `expr`, as if computed by `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -539,7 +566,13 @@ case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH" case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the tangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the tangent of `expr`, as if computed by `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -548,7 +581,13 @@ case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT" case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the cotangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the cotangent of `expr`, as if computed by `1/java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(1); @@ -562,7 +601,14 @@ case class Cot(child: Expression) } @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic tangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the hyperbolic tangent of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -572,6 +618,10 @@ case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH" @ExpressionDescription( usage = "_FUNC_(expr) - Converts radians to degrees.", + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(3.141592653589793); @@ -583,6 +633,10 @@ case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegre @ExpressionDescription( usage = "_FUNC_(expr) - Converts degrees to radians.", + arguments = """ + Arguments: + * expr - angle in degrees + """, examples = """ Examples: > SELECT _FUNC_(180); @@ -768,15 +822,22 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr1, expr2) - Returns the angle in radians between the positive x-axis of a plane and the point given by the coordinates (`expr1`, `expr2`).", + usage = """ + _FUNC_(exprY, exprX) - Returns the angle in radians between the positive x-axis of a plane + and the point given by the coordinates (`exprX`, `exprY`), as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * exprY - coordinate on y-axis + * exprX - coordinate on x-axis + """, examples = """ Examples: > SELECT _FUNC_(0, 0); 0.0 """) -// scalastyle:on line.size.limit case class Atan2(left: Expression, right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d54c02c3d06f..c9ca9a8996344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1313,8 +1313,7 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the cosine inverse of the given value; the returned angle is in the range - * 0.0 through pi. + * @return inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos` * * @group math_funcs * @since 1.4.0 @@ -1322,8 +1321,7 @@ object functions { def acos(e: Column): Column = withExpr { Acos(e.expr) } /** - * Computes the cosine inverse of the given column; the returned angle is in the range - * 0.0 through pi. + * @return inverse cosine of `columnName`, as if computed by `java.lang.Math.acos` * * @group math_funcs * @since 1.4.0 @@ -1331,8 +1329,7 @@ object functions { def acos(columnName: String): Column = acos(Column(columnName)) /** - * Computes the sine inverse of the given value; the returned angle is in the range - * -pi/2 through pi/2. + * @return inverse sine of `e` in radians, as if computed by `java.lang.Math.asin` * * @group math_funcs * @since 1.4.0 @@ -1340,8 +1337,7 @@ object functions { def asin(e: Column): Column = withExpr { Asin(e.expr) } /** - * Computes the sine inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2. + * @return inverse sine of `columnName`, as if computed by `java.lang.Math.asin` * * @group math_funcs * @since 1.4.0 @@ -1349,8 +1345,7 @@ object functions { def asin(columnName: String): Column = asin(Column(columnName)) /** - * Computes the tangent inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2 + * @return inverse tangent of `e`, as if computed by `java.lang.Math.atan` * * @group math_funcs * @since 1.4.0 @@ -1358,8 +1353,7 @@ object functions { def atan(e: Column): Column = withExpr { Atan(e.expr) } /** - * Computes the tangent inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2 + * @return inverse tangent of `columnName`, as if computed by `java.lang.Math.atan` * * @group math_funcs * @since 1.4.0 @@ -1367,77 +1361,117 @@ object functions { def atan(columnName: String): Column = atan(Column(columnName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). Units in radians. + * @param y coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Column): Column = withExpr { Atan2(l.expr, r.expr) } + def atan2(y: Column, x: Column): Column = withExpr { Atan2(y.expr, x.expr) } /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param y coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) + def atan2(y: Column, xName: String): Column = atan2(y, Column(xName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) + def atan2(yName: String, x: Column): Column = atan2(Column(yName), x) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, rightName: String): Column = - atan2(Column(leftName), Column(rightName)) + def atan2(yName: String, xName: String): Column = + atan2(Column(yName), Column(xName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param y coordinate on y-axis + * @param xValue coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Double): Column = atan2(l, lit(r)) + def atan2(y: Column, xValue: Double): Column = atan2(y, lit(xValue)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param xValue coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) + def atan2(yName: String, xValue: Double): Column = atan2(Column(yName), xValue) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yValue coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Double, r: Column): Column = atan2(lit(l), r) + def atan2(yValue: Double, x: Column): Column = atan2(lit(yValue), x) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yValue coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) + def atan2(yValue: Double, xName: String): Column = atan2(yValue, Column(xName)) /** * An expression that returns the string representation of the binary value of the given long @@ -1500,7 +1534,8 @@ object functions { } /** - * Computes the cosine of the given value. Units in radians. + * @param e angle in radians + * @return cosine of the angle, as if computed by `java.lang.Math.cos` * * @group math_funcs * @since 1.4.0 @@ -1508,7 +1543,8 @@ object functions { def cos(e: Column): Column = withExpr { Cos(e.expr) } /** - * Computes the cosine of the given column. + * @param columnName angle in radians + * @return cosine of the angle, as if computed by `java.lang.Math.cos` * * @group math_funcs * @since 1.4.0 @@ -1516,7 +1552,8 @@ object functions { def cos(columnName: String): Column = cos(Column(columnName)) /** - * Computes the hyperbolic cosine of the given value. + * @param e hyperbolic angle + * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh` * * @group math_funcs * @since 1.4.0 @@ -1524,7 +1561,8 @@ object functions { def cosh(e: Column): Column = withExpr { Cosh(e.expr) } /** - * Computes the hyperbolic cosine of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh` * * @group math_funcs * @since 1.4.0 @@ -1967,7 +2005,8 @@ object functions { def signum(columnName: String): Column = signum(Column(columnName)) /** - * Computes the sine of the given value. Units in radians. + * @param e angle in radians + * @return sine of the angle, as if computed by `java.lang.Math.sin` * * @group math_funcs * @since 1.4.0 @@ -1975,7 +2014,8 @@ object functions { def sin(e: Column): Column = withExpr { Sin(e.expr) } /** - * Computes the sine of the given column. + * @param columnName angle in radians + * @return sine of the angle, as if computed by `java.lang.Math.sin` * * @group math_funcs * @since 1.4.0 @@ -1983,7 +2023,8 @@ object functions { def sin(columnName: String): Column = sin(Column(columnName)) /** - * Computes the hyperbolic sine of the given value. + * @param e hyperbolic angle + * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh` * * @group math_funcs * @since 1.4.0 @@ -1991,7 +2032,8 @@ object functions { def sinh(e: Column): Column = withExpr { Sinh(e.expr) } /** - * Computes the hyperbolic sine of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh` * * @group math_funcs * @since 1.4.0 @@ -1999,7 +2041,8 @@ object functions { def sinh(columnName: String): Column = sinh(Column(columnName)) /** - * Computes the tangent of the given value. Units in radians. + * @param e angle in radians + * @return tangent of the given value, as if computed by `java.lang.Math.tan` * * @group math_funcs * @since 1.4.0 @@ -2007,7 +2050,8 @@ object functions { def tan(e: Column): Column = withExpr { Tan(e.expr) } /** - * Computes the tangent of the given column. + * @param columnName angle in radians + * @return tangent of the given value, as if computed by `java.lang.Math.tan` * * @group math_funcs * @since 1.4.0 @@ -2015,7 +2059,8 @@ object functions { def tan(columnName: String): Column = tan(Column(columnName)) /** - * Computes the hyperbolic tangent of the given value. + * @param e hyperbolic angle + * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh` * * @group math_funcs * @since 1.4.0 @@ -2023,7 +2068,8 @@ object functions { def tanh(e: Column): Column = withExpr { Tanh(e.expr) } /** - * Computes the hyperbolic tangent of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh` * * @group math_funcs * @since 1.4.0 @@ -2047,6 +2093,9 @@ object functions { /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * + * @param e angle in radians + * @return angle in degrees, as if computed by `java.lang.Math.toDegrees` + * * @group math_funcs * @since 2.1.0 */ @@ -2055,6 +2104,9 @@ object functions { /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * + * @param columnName angle in radians + * @return angle in degrees, as if computed by `java.lang.Math.toDegrees` + * * @group math_funcs * @since 2.1.0 */ @@ -2077,6 +2129,9 @@ object functions { /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. * + * @param e angle in degrees + * @return angle in radians, as if computed by `java.lang.Math.toRadians` + * * @group math_funcs * @since 2.1.0 */ @@ -2085,6 +2140,9 @@ object functions { /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. * + * @param columnName angle in degrees + * @return angle in radians, as if computed by `java.lang.Math.toRadians` + * * @group math_funcs * @since 2.1.0 */ @@ -2873,7 +2931,7 @@ object functions { * or equal to the `windowDuration`. Check * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration * identifiers. This duration is likewise absolute, and does not vary - * according to a calendar. + * according to a calendar. * @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start * window intervals. For example, in order to have hourly tumbling windows that * start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide @@ -2929,7 +2987,7 @@ object functions { * or equal to the `windowDuration`. Check * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration * identifiers. This duration is likewise absolute, and does not vary - * according to a calendar. + * according to a calendar. * * @group datetime_funcs * @since 2.0.0 From 947b4e6f09db6aa5d92409344b6e273e9faeb24e Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 5 Mar 2018 16:21:02 +0100 Subject: [PATCH 427/774] [SPARK-23510][DOC][FOLLOW-UP] Update spark.sql.hive.metastore.version ## What changes were proposed in this pull request? Update `spark.sql.hive.metastore.version` to 2.3.2, same as HiveUtils.scala: https://github.com/apache/spark/blob/ff1480189b827af0be38605d566a4ee71b4c36f6/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala#L63-L65 ## How was this patch tested? N/A Author: Yuming Wang Closes #20734 from wangyum/SPARK-23510-FOLLOW-UP. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4d0f015f401bb..01e2076555ee6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1214,7 +1214,7 @@ The following options can be used to configure the version of Hive that is used 1.2.1 Version of the Hive metastore. Available - options are 0.12.0 through 1.2.1. + options are 0.12.0 through 2.3.2. From 4586eada42d6a16bb78d1650d145531c51fa747f Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Mon, 5 Mar 2018 09:30:49 -0800 Subject: [PATCH 428/774] [SPARK-22430][R][DOCS] Unknown tag warnings when building R docs with Roxygen 6.0.1 ## What changes were proposed in this pull request? Removed export tag to get rid of unknown tag warnings ## How was this patch tested? Existing tests Author: Rekha Joshi Author: rjoshi2 Closes #20501 from rekhajoshm/SPARK-22430. --- R/pkg/R/DataFrame.R | 92 --------- R/pkg/R/SQLContext.R | 16 -- R/pkg/R/WindowSpec.R | 8 - R/pkg/R/broadcast.R | 3 - R/pkg/R/catalog.R | 18 -- R/pkg/R/column.R | 7 - R/pkg/R/context.R | 6 - R/pkg/R/functions.R | 181 ----------------- R/pkg/R/generics.R | 343 --------------------------------- R/pkg/R/group.R | 7 - R/pkg/R/install.R | 1 - R/pkg/R/jvm.R | 3 - R/pkg/R/mllib_classification.R | 20 -- R/pkg/R/mllib_clustering.R | 23 --- R/pkg/R/mllib_fpm.R | 6 - R/pkg/R/mllib_recommendation.R | 5 - R/pkg/R/mllib_regression.R | 17 -- R/pkg/R/mllib_stat.R | 4 - R/pkg/R/mllib_tree.R | 33 ---- R/pkg/R/mllib_utils.R | 3 - R/pkg/R/schema.R | 7 - R/pkg/R/sparkR.R | 7 - R/pkg/R/stats.R | 6 - R/pkg/R/streaming.R | 9 - R/pkg/R/utils.R | 1 - R/pkg/R/window.R | 4 - 26 files changed, 830 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 41c3c3a89fa72..c4852024c0f49 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -36,7 +36,6 @@ setOldClass("structType") #' @slot sdf A Java object reference to the backing Scala DataFrame #' @seealso \link{createDataFrame}, \link{read.json}, \link{table} #' @seealso \url{https://spark.apache.org/docs/latest/sparkr.html#sparkr-dataframes} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -77,7 +76,6 @@ setWriteMode <- function(write, mode) { write } -#' @export #' @param sdf A Java object reference to the backing Scala DataFrame #' @param isCached TRUE if the SparkDataFrame is cached #' @noRd @@ -97,7 +95,6 @@ dataFrame <- function(sdf, isCached = FALSE) { #' @rdname printSchema #' @name printSchema #' @aliases printSchema,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -123,7 +120,6 @@ setMethod("printSchema", #' @rdname schema #' @name schema #' @aliases schema,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -146,7 +142,6 @@ setMethod("schema", #' @aliases explain,SparkDataFrame-method #' @rdname explain #' @name explain -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -178,7 +173,6 @@ setMethod("explain", #' @rdname isLocal #' @name isLocal #' @aliases isLocal,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -209,7 +203,6 @@ setMethod("isLocal", #' @aliases showDF,SparkDataFrame-method #' @rdname showDF #' @name showDF -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -241,7 +234,6 @@ setMethod("showDF", #' @rdname show #' @aliases show,SparkDataFrame-method #' @name show -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -269,7 +261,6 @@ setMethod("show", "SparkDataFrame", #' @rdname dtypes #' @name dtypes #' @aliases dtypes,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -296,7 +287,6 @@ setMethod("dtypes", #' @rdname columns #' @name columns #' @aliases columns,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -388,7 +378,6 @@ setMethod("colnames<-", #' @aliases coltypes,SparkDataFrame-method #' @name coltypes #' @family SparkDataFrame functions -#' @export #' @examples #'\dontrun{ #' irisDF <- createDataFrame(iris) @@ -445,7 +434,6 @@ setMethod("coltypes", #' @rdname coltypes #' @name coltypes<- #' @aliases coltypes<-,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -494,7 +482,6 @@ setMethod("coltypes<-", #' @rdname createOrReplaceTempView #' @name createOrReplaceTempView #' @aliases createOrReplaceTempView,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -521,7 +508,6 @@ setMethod("createOrReplaceTempView", #' @rdname registerTempTable-deprecated #' @name registerTempTable #' @aliases registerTempTable,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -552,7 +538,6 @@ setMethod("registerTempTable", #' @rdname insertInto #' @name insertInto #' @aliases insertInto,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -580,7 +565,6 @@ setMethod("insertInto", #' @aliases cache,SparkDataFrame-method #' @rdname cache #' @name cache -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -611,7 +595,6 @@ setMethod("cache", #' @rdname persist #' @name persist #' @aliases persist,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -641,7 +624,6 @@ setMethod("persist", #' @rdname unpersist #' @aliases unpersist,SparkDataFrame-method #' @name unpersist -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -669,7 +651,6 @@ setMethod("unpersist", #' @rdname storageLevel #' @aliases storageLevel,SparkDataFrame-method #' @name storageLevel -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -707,7 +688,6 @@ setMethod("storageLevel", #' @name coalesce #' @aliases coalesce,SparkDataFrame-method #' @seealso \link{repartition} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -744,7 +724,6 @@ setMethod("coalesce", #' @name repartition #' @aliases repartition,SparkDataFrame-method #' @seealso \link{coalesce} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -793,7 +772,6 @@ setMethod("repartition", #' @rdname toJSON #' @name toJSON #' @aliases toJSON,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -826,7 +804,6 @@ setMethod("toJSON", #' @rdname write.json #' @name write.json #' @aliases write.json,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -858,7 +835,6 @@ setMethod("write.json", #' @aliases write.orc,SparkDataFrame,character-method #' @rdname write.orc #' @name write.orc -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -890,7 +866,6 @@ setMethod("write.orc", #' @rdname write.parquet #' @name write.parquet #' @aliases write.parquet,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -911,7 +886,6 @@ setMethod("write.parquet", #' @rdname write.parquet #' @name saveAsParquetFile #' @aliases saveAsParquetFile,SparkDataFrame,character-method -#' @export #' @note saveAsParquetFile since 1.4.0 setMethod("saveAsParquetFile", signature(x = "SparkDataFrame", path = "character"), @@ -936,7 +910,6 @@ setMethod("saveAsParquetFile", #' @aliases write.text,SparkDataFrame,character-method #' @rdname write.text #' @name write.text -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -963,7 +936,6 @@ setMethod("write.text", #' @aliases distinct,SparkDataFrame-method #' @rdname distinct #' @name distinct -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1004,7 +976,6 @@ setMethod("unique", #' @aliases sample,SparkDataFrame-method #' @rdname sample #' @name sample -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1061,7 +1032,6 @@ setMethod("sample_frac", #' @rdname nrow #' @name nrow #' @aliases count,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1094,7 +1064,6 @@ setMethod("nrow", #' @rdname ncol #' @name ncol #' @aliases ncol,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1118,7 +1087,6 @@ setMethod("ncol", #' @rdname dim #' @aliases dim,SparkDataFrame-method #' @name dim -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1144,7 +1112,6 @@ setMethod("dim", #' @rdname collect #' @aliases collect,SparkDataFrame-method #' @name collect -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1229,7 +1196,6 @@ setMethod("collect", #' @rdname limit #' @name limit #' @aliases limit,SparkDataFrame,numeric-method -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -1253,7 +1219,6 @@ setMethod("limit", #' @rdname take #' @name take #' @aliases take,SparkDataFrame,numeric-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1282,7 +1247,6 @@ setMethod("take", #' @aliases head,SparkDataFrame-method #' @rdname head #' @name head -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1307,7 +1271,6 @@ setMethod("head", #' @aliases first,SparkDataFrame-method #' @rdname first #' @name first -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1359,7 +1322,6 @@ setMethod("toRDD", #' @aliases groupBy,SparkDataFrame-method #' @rdname groupBy #' @name groupBy -#' @export #' @examples #' \dontrun{ #' # Compute the average for all numeric columns grouped by department. @@ -1401,7 +1363,6 @@ setMethod("group_by", #' @aliases agg,SparkDataFrame-method #' @rdname summarize #' @name agg -#' @export #' @note agg since 1.4.0 setMethod("agg", signature(x = "SparkDataFrame"), @@ -1460,7 +1421,6 @@ setClassUnion("characterOrstructType", c("character", "structType")) #' @aliases dapply,SparkDataFrame,function,characterOrstructType-method #' @name dapply #' @seealso \link{dapplyCollect} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -1519,7 +1479,6 @@ setMethod("dapply", #' @aliases dapplyCollect,SparkDataFrame,function-method #' @name dapplyCollect #' @seealso \link{dapply} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -1576,7 +1535,6 @@ setMethod("dapplyCollect", #' @rdname gapply #' @name gapply #' @seealso \link{gapplyCollect} -#' @export #' @examples #' #' \dontrun{ @@ -1673,7 +1631,6 @@ setMethod("gapply", #' @rdname gapplyCollect #' @name gapplyCollect #' @seealso \link{gapply} -#' @export #' @examples #' #' \dontrun{ @@ -1947,7 +1904,6 @@ setMethod("[", signature(x = "SparkDataFrame"), #' @param ... currently not used. #' @return A new SparkDataFrame containing only the rows that meet the condition with selected #' columns. -#' @export #' @family SparkDataFrame functions #' @aliases subset,SparkDataFrame-method #' @seealso \link{withColumn} @@ -1992,7 +1948,6 @@ setMethod("subset", signature(x = "SparkDataFrame"), #' If more than one column is assigned in \code{col}, \code{...} #' should be left empty. #' @return A new SparkDataFrame with selected columns. -#' @export #' @family SparkDataFrame functions #' @rdname select #' @aliases select,SparkDataFrame,character-method @@ -2024,7 +1979,6 @@ setMethod("select", signature(x = "SparkDataFrame", col = "character"), }) #' @rdname select -#' @export #' @aliases select,SparkDataFrame,Column-method #' @note select(SparkDataFrame, Column) since 1.4.0 setMethod("select", signature(x = "SparkDataFrame", col = "Column"), @@ -2037,7 +1991,6 @@ setMethod("select", signature(x = "SparkDataFrame", col = "Column"), }) #' @rdname select -#' @export #' @aliases select,SparkDataFrame,list-method #' @note select(SparkDataFrame, list) since 1.4.0 setMethod("select", @@ -2066,7 +2019,6 @@ setMethod("select", #' @aliases selectExpr,SparkDataFrame,character-method #' @rdname selectExpr #' @name selectExpr -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2098,7 +2050,6 @@ setMethod("selectExpr", #' @rdname withColumn #' @name withColumn #' @seealso \link{rename} \link{mutate} \link{subset} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2137,7 +2088,6 @@ setMethod("withColumn", #' @rdname mutate #' @name mutate #' @seealso \link{rename} \link{withColumn} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2208,7 +2158,6 @@ setMethod("mutate", }) #' @param _data a SparkDataFrame. -#' @export #' @rdname mutate #' @aliases transform,SparkDataFrame-method #' @name transform @@ -2232,7 +2181,6 @@ setMethod("transform", #' @name withColumnRenamed #' @aliases withColumnRenamed,SparkDataFrame,character,character-method #' @seealso \link{mutate} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2258,7 +2206,6 @@ setMethod("withColumnRenamed", #' @rdname rename #' @name rename #' @aliases rename,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2304,7 +2251,6 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @aliases arrange,SparkDataFrame,Column-method #' @rdname arrange #' @name arrange -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2335,7 +2281,6 @@ setMethod("arrange", #' @rdname arrange #' @name arrange #' @aliases arrange,SparkDataFrame,character-method -#' @export #' @note arrange(SparkDataFrame, character) since 1.4.0 setMethod("arrange", signature(x = "SparkDataFrame", col = "character"), @@ -2368,7 +2313,6 @@ setMethod("arrange", #' @rdname arrange #' @aliases orderBy,SparkDataFrame,characterOrColumn-method -#' @export #' @note orderBy(SparkDataFrame, characterOrColumn) since 1.4.0 setMethod("orderBy", signature(x = "SparkDataFrame", col = "characterOrColumn"), @@ -2389,7 +2333,6 @@ setMethod("orderBy", #' @rdname filter #' @name filter #' @family subsetting functions -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2432,7 +2375,6 @@ setMethod("where", #' @aliases dropDuplicates,SparkDataFrame-method #' @rdname dropDuplicates #' @name dropDuplicates -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2481,7 +2423,6 @@ setMethod("dropDuplicates", #' @rdname join #' @name join #' @seealso \link{merge} \link{crossJoin} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2533,7 +2474,6 @@ setMethod("join", #' @rdname crossJoin #' @name crossJoin #' @seealso \link{merge} \link{join} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2581,7 +2521,6 @@ setMethod("crossJoin", #' @aliases merge,SparkDataFrame,SparkDataFrame-method #' @rdname merge #' @seealso \link{join} \link{crossJoin} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2721,7 +2660,6 @@ genAliasesForIntersectedCols <- function(x, intersectedColNames, suffix) { #' @name union #' @aliases union,SparkDataFrame,SparkDataFrame-method #' @seealso \link{rbind} \link{unionByName} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2742,7 +2680,6 @@ setMethod("union", #' @rdname union #' @name unionAll #' @aliases unionAll,SparkDataFrame,SparkDataFrame-method -#' @export #' @note unionAll since 1.4.0 setMethod("unionAll", signature(x = "SparkDataFrame", y = "SparkDataFrame"), @@ -2769,7 +2706,6 @@ setMethod("unionAll", #' @name unionByName #' @aliases unionByName,SparkDataFrame,SparkDataFrame-method #' @seealso \link{rbind} \link{union} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2802,7 +2738,6 @@ setMethod("unionByName", #' @rdname rbind #' @name rbind #' @seealso \link{union} \link{unionByName} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2835,7 +2770,6 @@ setMethod("rbind", #' @aliases intersect,SparkDataFrame,SparkDataFrame-method #' @rdname intersect #' @name intersect -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2863,7 +2797,6 @@ setMethod("intersect", #' @aliases except,SparkDataFrame,SparkDataFrame-method #' @rdname except #' @name except -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2872,7 +2805,6 @@ setMethod("intersect", #' exceptDF <- except(df, df2) #' } #' @rdname except -#' @export #' @note except since 1.4.0 setMethod("except", signature(x = "SparkDataFrame", y = "SparkDataFrame"), @@ -2909,7 +2841,6 @@ setMethod("except", #' @aliases write.df,SparkDataFrame-method #' @rdname write.df #' @name write.df -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2944,7 +2875,6 @@ setMethod("write.df", #' @rdname write.df #' @name saveDF #' @aliases saveDF,SparkDataFrame,character-method -#' @export #' @note saveDF since 1.4.0 setMethod("saveDF", signature(df = "SparkDataFrame", path = "character"), @@ -2978,7 +2908,6 @@ setMethod("saveDF", #' @aliases saveAsTable,SparkDataFrame,character-method #' @rdname saveAsTable #' @name saveAsTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3015,7 +2944,6 @@ setMethod("saveAsTable", #' @aliases describe,SparkDataFrame,character-method describe,SparkDataFrame,ANY-method #' @rdname describe #' @name describe -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3071,7 +2999,6 @@ setMethod("describe", #' @rdname summary #' @name summary #' @aliases summary,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3117,7 +3044,6 @@ setMethod("summary", #' @rdname nafunctions #' @aliases dropna,SparkDataFrame-method #' @name dropna -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3148,7 +3074,6 @@ setMethod("dropna", #' @rdname nafunctions #' @name na.omit #' @aliases na.omit,SparkDataFrame-method -#' @export #' @note na.omit since 1.5.0 setMethod("na.omit", signature(object = "SparkDataFrame"), @@ -3168,7 +3093,6 @@ setMethod("na.omit", #' @rdname nafunctions #' @name fillna #' @aliases fillna,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3399,7 +3323,6 @@ setMethod("str", #' @rdname drop #' @name drop #' @aliases drop,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3427,7 +3350,6 @@ setMethod("drop", #' @name drop #' @rdname drop #' @aliases drop,ANY-method -#' @export setMethod("drop", signature(x = "ANY"), function(x) { @@ -3446,7 +3368,6 @@ setMethod("drop", #' @rdname histogram #' @aliases histogram,SparkDataFrame,characterOrColumn-method #' @family SparkDataFrame functions -#' @export #' @examples #' \dontrun{ #' @@ -3582,7 +3503,6 @@ setMethod("histogram", #' @rdname write.jdbc #' @name write.jdbc #' @aliases write.jdbc,SparkDataFrame,character,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3611,7 +3531,6 @@ setMethod("write.jdbc", #' @aliases randomSplit,SparkDataFrame,numeric-method #' @rdname randomSplit #' @name randomSplit -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3645,7 +3564,6 @@ setMethod("randomSplit", #' @aliases getNumPartitions,SparkDataFrame-method #' @rdname getNumPartitions #' @name getNumPartitions -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3672,7 +3590,6 @@ setMethod("getNumPartitions", #' @rdname isStreaming #' @name isStreaming #' @seealso \link{read.stream} \link{write.stream} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3726,7 +3643,6 @@ setMethod("isStreaming", #' @aliases write.stream,SparkDataFrame-method #' @rdname write.stream #' @name write.stream -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3819,7 +3735,6 @@ setMethod("write.stream", #' @rdname checkpoint #' @name checkpoint #' @seealso \link{setCheckpointDir} -#' @export #' @examples #'\dontrun{ #' setCheckpointDir("/checkpoint") @@ -3847,7 +3762,6 @@ setMethod("checkpoint", #' @aliases localCheckpoint,SparkDataFrame-method #' @rdname localCheckpoint #' @name localCheckpoint -#' @export #' @examples #'\dontrun{ #' df <- localCheckpoint(df) @@ -3874,7 +3788,6 @@ setMethod("localCheckpoint", #' @aliases cube,SparkDataFrame-method #' @rdname cube #' @name cube -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -3909,7 +3822,6 @@ setMethod("cube", #' @aliases rollup,SparkDataFrame-method #' @rdname rollup #' @name rollup -#' @export #' @examples #'\dontrun{ #' df <- createDataFrame(mtcars) @@ -3942,7 +3854,6 @@ setMethod("rollup", #' @aliases hint,SparkDataFrame,character-method #' @rdname hint #' @name hint -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -3966,7 +3877,6 @@ setMethod("hint", #' @family SparkDataFrame functions #' @rdname alias #' @name alias -#' @export #' @examples #' \dontrun{ #' df <- alias(createDataFrame(mtcars), "mtcars") @@ -3997,7 +3907,6 @@ setMethod("alias", #' @family SparkDataFrame functions #' @rdname broadcast #' @name broadcast -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -4041,7 +3950,6 @@ setMethod("broadcast", #' @family SparkDataFrame functions #' @rdname withWatermark #' @name withWatermark -#' @export #' @examples #' \dontrun{ #' sparkR.session() diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 9d0a2d5e074e4..ebec0ce3d1920 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -123,7 +123,6 @@ infer_type <- function(x) { #' @return a list of config values with keys as their names #' @rdname sparkR.conf #' @name sparkR.conf -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -163,7 +162,6 @@ sparkR.conf <- function(key, defaultValue) { #' @return a character string of the Spark version #' @rdname sparkR.version #' @name sparkR.version -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -191,7 +189,6 @@ getDefaultSqlSource <- function() { #' limited by length of the list or number of rows of the data.frame #' @return A SparkDataFrame. #' @rdname createDataFrame -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -294,7 +291,6 @@ createDataFrame <- function(x, ...) { #' @rdname createDataFrame #' @aliases createDataFrame -#' @export #' @method as.DataFrame default #' @note as.DataFrame since 1.6.0 as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPartitions = NULL) { @@ -304,7 +300,6 @@ as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPa #' @param ... additional argument(s). #' @rdname createDataFrame #' @aliases as.DataFrame -#' @export as.DataFrame <- function(data, ...) { dispatchFunc("as.DataFrame(data, schema = NULL)", data, ...) } @@ -342,7 +337,6 @@ setMethod("toDF", signature(x = "RDD"), #' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.json -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -371,7 +365,6 @@ read.json <- function(x, ...) { #' @rdname read.json #' @name jsonFile -#' @export #' @method jsonFile default #' @note jsonFile since 1.4.0 jsonFile.default <- function(path) { @@ -423,7 +416,6 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { #' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.orc -#' @export #' @name read.orc #' @note read.orc since 2.0.0 read.orc <- function(path, ...) { @@ -444,7 +436,6 @@ read.orc <- function(path, ...) { #' @param path path of file to read. A vector of multiple paths is allowed. #' @return SparkDataFrame #' @rdname read.parquet -#' @export #' @name read.parquet #' @method read.parquet default #' @note read.parquet since 1.6.0 @@ -466,7 +457,6 @@ read.parquet <- function(x, ...) { #' @param ... argument(s) passed to the method. #' @rdname read.parquet #' @name parquetFile -#' @export #' @method parquetFile default #' @note parquetFile since 1.4.0 parquetFile.default <- function(...) { @@ -490,7 +480,6 @@ parquetFile <- function(x, ...) { #' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.text -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -522,7 +511,6 @@ read.text <- function(x, ...) { #' @param sqlQuery A character vector containing the SQL query #' @return SparkDataFrame #' @rdname sql -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -556,7 +544,6 @@ sql <- function(x, ...) { #' @return SparkDataFrame #' @rdname tableToDF #' @name tableToDF -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -591,7 +578,6 @@ tableToDF <- function(tableName) { #' @rdname read.df #' @name read.df #' @seealso \link{read.json} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -681,7 +667,6 @@ loadDF <- function(x = NULL, ...) { #' @return SparkDataFrame #' @rdname read.jdbc #' @name read.jdbc -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -734,7 +719,6 @@ read.jdbc <- function(url, tableName, #' @rdname read.stream #' @name read.stream #' @seealso \link{write.stream} -#' @export #' @examples #'\dontrun{ #' sparkR.session() diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R index debc7cbde55e7..ee7f4adf726e6 100644 --- a/R/pkg/R/WindowSpec.R +++ b/R/pkg/R/WindowSpec.R @@ -28,7 +28,6 @@ NULL #' @seealso \link{windowPartitionBy}, \link{windowOrderBy} #' #' @param sws A Java object reference to the backing Scala WindowSpec -#' @export #' @note WindowSpec since 2.0.0 setClass("WindowSpec", slots = list(sws = "jobj")) @@ -44,7 +43,6 @@ windowSpec <- function(sws) { } #' @rdname show -#' @export #' @note show(WindowSpec) since 2.0.0 setMethod("show", "WindowSpec", function(object) { @@ -63,7 +61,6 @@ setMethod("show", "WindowSpec", #' @name partitionBy #' @aliases partitionBy,WindowSpec-method #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' partitionBy(ws, "col1", "col2") @@ -97,7 +94,6 @@ setMethod("partitionBy", #' @aliases orderBy,WindowSpec,character-method #' @family windowspec_method #' @seealso See \link{arrange} for use in sorting a SparkDataFrame -#' @export #' @examples #' \dontrun{ #' orderBy(ws, "col1", "col2") @@ -113,7 +109,6 @@ setMethod("orderBy", #' @rdname orderBy #' @name orderBy #' @aliases orderBy,WindowSpec,Column-method -#' @export #' @note orderBy(WindowSpec, Column) since 2.0.0 setMethod("orderBy", signature(x = "WindowSpec", col = "Column"), @@ -142,7 +137,6 @@ setMethod("orderBy", #' @aliases rowsBetween,WindowSpec,numeric,numeric-method #' @name rowsBetween #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' rowsBetween(ws, 0, 3) @@ -174,7 +168,6 @@ setMethod("rowsBetween", #' @aliases rangeBetween,WindowSpec,numeric,numeric-method #' @name rangeBetween #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' rangeBetween(ws, 0, 3) @@ -202,7 +195,6 @@ setMethod("rangeBetween", #' @name over #' @aliases over,Column,WindowSpec-method #' @family colum_func -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 398dffc4ab1b4..282f8a6857738 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -32,14 +32,12 @@ # @seealso broadcast # # @param id Id of the backing Spark broadcast variable -# @export setClass("Broadcast", slots = list(id = "character")) # @rdname broadcast-class # @param value Value of the broadcast variable # @param jBroadcastRef reference to the backing Java broadcast object # @param objName name of broadcasted object -# @export Broadcast <- function(id, value, jBroadcastRef, objName) { .broadcastValues[[id]] <- value .broadcastNames[[as.character(objName)]] <- jBroadcastRef @@ -73,7 +71,6 @@ setMethod("value", # @param bcastId The id of broadcast variable to set # @param value The value to be set -# @export setBroadcastValue <- function(bcastId, value) { bcastIdStr <- as.character(bcastId) .broadcastValues[[bcastIdStr]] <- value diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R index e59a7024333ac..baf4d861fcf86 100644 --- a/R/pkg/R/catalog.R +++ b/R/pkg/R/catalog.R @@ -34,7 +34,6 @@ #' @return A SparkDataFrame. #' @rdname createExternalTable-deprecated #' @seealso \link{createTable} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -71,7 +70,6 @@ createExternalTable <- function(x, ...) { #' @return A SparkDataFrame. #' @rdname createTable #' @seealso \link{createExternalTable} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -110,7 +108,6 @@ createTable <- function(tableName, path = NULL, source = NULL, schema = NULL, .. #' identifier is provided, it refers to a table in the current database. #' @return SparkDataFrame #' @rdname cacheTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -140,7 +137,6 @@ cacheTable <- function(x, ...) { #' identifier is provided, it refers to a table in the current database. #' @return SparkDataFrame #' @rdname uncacheTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -167,7 +163,6 @@ uncacheTable <- function(x, ...) { #' Removes all cached tables from the in-memory cache. #' #' @rdname clearCache -#' @export #' @examples #' \dontrun{ #' clearCache() @@ -193,7 +188,6 @@ clearCache <- function() { #' @param tableName The name of the SparkSQL table to be dropped. #' @seealso \link{dropTempView} #' @rdname dropTempTable-deprecated -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -225,7 +219,6 @@ dropTempTable <- function(x, ...) { #' @return TRUE if the view is dropped successfully, FALSE otherwise. #' @rdname dropTempView #' @name dropTempView -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -251,7 +244,6 @@ dropTempView <- function(viewName) { #' @return a SparkDataFrame #' @rdname tables #' @seealso \link{listTables} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -276,7 +268,6 @@ tables <- function(x, ...) { #' @param databaseName (optional) name of the database #' @return a list of table names #' @rdname tableNames -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -304,7 +295,6 @@ tableNames <- function(x, ...) { #' @return name of the current default database. #' @rdname currentDatabase #' @name currentDatabase -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -324,7 +314,6 @@ currentDatabase <- function() { #' @param databaseName name of the database #' @rdname setCurrentDatabase #' @name setCurrentDatabase -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -347,7 +336,6 @@ setCurrentDatabase <- function(databaseName) { #' @return a SparkDataFrame of the list of databases. #' @rdname listDatabases #' @name listDatabases -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -370,7 +358,6 @@ listDatabases <- function() { #' @rdname listTables #' @name listTables #' @seealso \link{tables} -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -403,7 +390,6 @@ listTables <- function(databaseName = NULL) { #' @return a SparkDataFrame of the list of column descriptions. #' @rdname listColumns #' @name listColumns -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -433,7 +419,6 @@ listColumns <- function(tableName, databaseName = NULL) { #' @return a SparkDataFrame of the list of function descriptions. #' @rdname listFunctions #' @name listFunctions -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -463,7 +448,6 @@ listFunctions <- function(databaseName = NULL) { #' identifier is provided, it refers to a table in the current database. #' @rdname recoverPartitions #' @name recoverPartitions -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -490,7 +474,6 @@ recoverPartitions <- function(tableName) { #' identifier is provided, it refers to a table in the current database. #' @rdname refreshTable #' @name refreshTable -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -512,7 +495,6 @@ refreshTable <- function(tableName) { #' @param path the path of the data source. #' @rdname refreshByPath #' @name refreshByPath -#' @export #' @examples #' \dontrun{ #' sparkR.session() diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 3095adb918b67..9727efc354f10 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -29,7 +29,6 @@ setOldClass("jobj") #' @rdname column #' #' @slot jc reference to JVM SparkDataFrame column -#' @export #' @note Column since 1.4.0 setClass("Column", slots = list(jc = "jobj")) @@ -56,7 +55,6 @@ setMethod("column", #' @rdname show #' @name show #' @aliases show,Column-method -#' @export #' @note show(Column) since 1.4.0 setMethod("show", "Column", function(object) { @@ -134,7 +132,6 @@ createMethods() #' @name alias #' @aliases alias,Column-method #' @family colum_func -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -270,7 +267,6 @@ setMethod("cast", #' @name %in% #' @aliases %in%,Column-method #' @return A matched values as a result of comparing with given values. -#' @export #' @examples #' \dontrun{ #' filter(df, "age in (10, 30)") @@ -296,7 +292,6 @@ setMethod("%in%", #' @name otherwise #' @family colum_func #' @aliases otherwise,Column-method -#' @export #' @note otherwise since 1.5.0 setMethod("otherwise", signature(x = "Column", value = "ANY"), @@ -318,7 +313,6 @@ setMethod("otherwise", #' @rdname eq_null_safe #' @name %<=>% #' @aliases %<=>%,Column-method -#' @export #' @examples #' \dontrun{ #' df1 <- createDataFrame(data.frame( @@ -348,7 +342,6 @@ setMethod("%<=>%", #' @rdname not #' @name not #' @aliases !,Column-method -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame(x = c(-1, 0, 1))) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 443c2ff8f9ace..8ec727dd042bc 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -308,7 +308,6 @@ setCheckpointDirSC <- function(sc, dirName) { #' @rdname spark.addFile #' @param path The path of the file to be added #' @param recursive Whether to add files recursively from the path. Default is FALSE. -#' @export #' @examples #'\dontrun{ #' spark.addFile("~/myfile") @@ -323,7 +322,6 @@ spark.addFile <- function(path, recursive = FALSE) { #' #' @rdname spark.getSparkFilesRootDirectory #' @return the root directory that contains files added through spark.addFile -#' @export #' @examples #'\dontrun{ #' spark.getSparkFilesRootDirectory() @@ -344,7 +342,6 @@ spark.getSparkFilesRootDirectory <- function() { # nolint #' @rdname spark.getSparkFiles #' @param fileName The name of the file added through spark.addFile #' @return the absolute path of a file added through spark.addFile. -#' @export #' @examples #'\dontrun{ #' spark.getSparkFiles("myfile") @@ -391,7 +388,6 @@ spark.getSparkFiles <- function(fileName) { #' @param list the list of elements #' @param func a function that takes one argument. #' @return a list of results (the exact type being determined by the function) -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -412,7 +408,6 @@ spark.lapply <- function(list, func) { #' #' @rdname setLogLevel #' @param level New log level -#' @export #' @examples #'\dontrun{ #' setLogLevel("ERROR") @@ -431,7 +426,6 @@ setLogLevel <- function(level) { #' @rdname setCheckpointDir #' @param directory Directory path to checkpoint to #' @seealso \link{checkpoint} -#' @export #' @examples #'\dontrun{ #' setCheckpointDir("/checkpoint") diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 29ee146ab14f9..a527426b19674 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -244,7 +244,6 @@ NULL #' If the parameter is a Column, it is returned unchanged. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases lit lit,ANY-method #' @examples #' @@ -267,7 +266,6 @@ setMethod("lit", signature("ANY"), #' \code{abs}: Computes the absolute value. #' #' @rdname column_math_functions -#' @export #' @aliases abs abs,Column-method #' @note abs since 1.5.0 setMethod("abs", @@ -282,7 +280,6 @@ setMethod("abs", #' as if computed by \code{java.lang.Math.acos()} #' #' @rdname column_math_functions -#' @export #' @aliases acos acos,Column-method #' @note acos since 1.5.0 setMethod("acos", @@ -296,7 +293,6 @@ setMethod("acos", #' \code{approxCountDistinct}: Returns the approximate number of distinct items in a group. #' #' @rdname column_aggregate_functions -#' @export #' @aliases approxCountDistinct approxCountDistinct,Column-method #' @examples #' @@ -319,7 +315,6 @@ setMethod("approxCountDistinct", #' and returns the result as an int column. #' #' @rdname column_string_functions -#' @export #' @aliases ascii ascii,Column-method #' @examples #' @@ -338,7 +333,6 @@ setMethod("ascii", #' as if computed by \code{java.lang.Math.asin()} #' #' @rdname column_math_functions -#' @export #' @aliases asin asin,Column-method #' @note asin since 1.5.0 setMethod("asin", @@ -353,7 +347,6 @@ setMethod("asin", #' as if computed by \code{java.lang.Math.atan()} #' #' @rdname column_math_functions -#' @export #' @aliases atan atan,Column-method #' @note atan since 1.5.0 setMethod("atan", @@ -370,7 +363,6 @@ setMethod("atan", #' @rdname avg #' @name avg #' @family aggregate functions -#' @export #' @aliases avg,Column-method #' @examples \dontrun{avg(df$c)} #' @note avg since 1.4.0 @@ -386,7 +378,6 @@ setMethod("avg", #' a string column. This is the reverse of unbase64. #' #' @rdname column_string_functions -#' @export #' @aliases base64 base64,Column-method #' @examples #' @@ -410,7 +401,6 @@ setMethod("base64", #' of the given long column. For example, bin("12") returns "1100". #' #' @rdname column_math_functions -#' @export #' @aliases bin bin,Column-method #' @note bin since 1.5.0 setMethod("bin", @@ -424,7 +414,6 @@ setMethod("bin", #' \code{bitwiseNOT}: Computes bitwise NOT. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases bitwiseNOT bitwiseNOT,Column-method #' @examples #' @@ -442,7 +431,6 @@ setMethod("bitwiseNOT", #' \code{cbrt}: Computes the cube-root of the given value. #' #' @rdname column_math_functions -#' @export #' @aliases cbrt cbrt,Column-method #' @note cbrt since 1.4.0 setMethod("cbrt", @@ -456,7 +444,6 @@ setMethod("cbrt", #' \code{ceil}: Computes the ceiling of the given value. #' #' @rdname column_math_functions -#' @export #' @aliases ceil ceil,Column-method #' @note ceil since 1.5.0 setMethod("ceil", @@ -471,7 +458,6 @@ setMethod("ceil", #' #' @rdname column_math_functions #' @aliases ceiling ceiling,Column-method -#' @export #' @note ceiling since 1.5.0 setMethod("ceiling", signature(x = "Column"), @@ -483,7 +469,6 @@ setMethod("ceiling", #' \code{coalesce}: Returns the first column that is not NA, or NA if all inputs are. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases coalesce,Column-method #' @note coalesce(Column) since 2.1.1 setMethod("coalesce", @@ -514,7 +499,6 @@ col <- function(x) { #' @rdname column #' @name column #' @family non-aggregate functions -#' @export #' @aliases column,character-method #' @examples \dontrun{column("name")} #' @note column since 1.6.0 @@ -533,7 +517,6 @@ setMethod("column", #' @rdname corr #' @name corr #' @family aggregate functions -#' @export #' @aliases corr,Column-method #' @examples #' \dontrun{ @@ -557,7 +540,6 @@ setMethod("corr", signature(x = "Column"), #' @rdname cov #' @name cov #' @family aggregate functions -#' @export #' @aliases cov,characterOrColumn-method #' @examples #' \dontrun{ @@ -598,7 +580,6 @@ setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterO #' #' @rdname cov #' @name covar_pop -#' @export #' @aliases covar_pop,characterOrColumn,characterOrColumn-method #' @note covar_pop since 2.0.0 setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), @@ -618,7 +599,6 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr #' #' @rdname column_math_functions #' @aliases cos cos,Column-method -#' @export #' @note cos since 1.5.0 setMethod("cos", signature(x = "Column"), @@ -633,7 +613,6 @@ setMethod("cos", #' #' @rdname column_math_functions #' @aliases cosh cosh,Column-method -#' @export #' @note cosh since 1.5.0 setMethod("cosh", signature(x = "Column"), @@ -651,7 +630,6 @@ setMethod("cosh", #' @name count #' @family aggregate functions #' @aliases count,Column-method -#' @export #' @examples \dontrun{count(df$c)} #' @note count since 1.4.0 setMethod("count", @@ -667,7 +645,6 @@ setMethod("count", #' #' @rdname column_misc_functions #' @aliases crc32 crc32,Column-method -#' @export #' @note crc32 since 1.5.0 setMethod("crc32", signature(x = "Column"), @@ -682,7 +659,6 @@ setMethod("crc32", #' #' @rdname column_misc_functions #' @aliases hash hash,Column-method -#' @export #' @note hash since 2.0.0 setMethod("hash", signature(x = "Column"), @@ -701,7 +677,6 @@ setMethod("hash", #' #' @rdname column_datetime_functions #' @aliases dayofmonth dayofmonth,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -723,7 +698,6 @@ setMethod("dayofmonth", #' #' @rdname column_datetime_functions #' @aliases dayofweek dayofweek,Column-method -#' @export #' @note dayofweek since 2.3.0 setMethod("dayofweek", signature(x = "Column"), @@ -738,7 +712,6 @@ setMethod("dayofweek", #' #' @rdname column_datetime_functions #' @aliases dayofyear dayofyear,Column-method -#' @export #' @note dayofyear since 1.5.0 setMethod("dayofyear", signature(x = "Column"), @@ -756,7 +729,6 @@ setMethod("dayofyear", #' #' @rdname column_string_functions #' @aliases decode decode,Column,character-method -#' @export #' @note decode since 1.6.0 setMethod("decode", signature(x = "Column", charset = "character"), @@ -771,7 +743,6 @@ setMethod("decode", #' #' @rdname column_string_functions #' @aliases encode encode,Column,character-method -#' @export #' @note encode since 1.6.0 setMethod("encode", signature(x = "Column", charset = "character"), @@ -785,7 +756,6 @@ setMethod("encode", #' #' @rdname column_math_functions #' @aliases exp exp,Column-method -#' @export #' @note exp since 1.5.0 setMethod("exp", signature(x = "Column"), @@ -799,7 +769,6 @@ setMethod("exp", #' #' @rdname column_math_functions #' @aliases expm1 expm1,Column-method -#' @export #' @note expm1 since 1.5.0 setMethod("expm1", signature(x = "Column"), @@ -813,7 +782,6 @@ setMethod("expm1", #' #' @rdname column_math_functions #' @aliases factorial factorial,Column-method -#' @export #' @note factorial since 1.5.0 setMethod("factorial", signature(x = "Column"), @@ -836,7 +804,6 @@ setMethod("factorial", #' @name first #' @aliases first,characterOrColumn-method #' @family aggregate functions -#' @export #' @examples #' \dontrun{ #' first(df$c) @@ -860,7 +827,6 @@ setMethod("first", #' #' @rdname column_math_functions #' @aliases floor floor,Column-method -#' @export #' @note floor since 1.5.0 setMethod("floor", signature(x = "Column"), @@ -874,7 +840,6 @@ setMethod("floor", #' #' @rdname column_math_functions #' @aliases hex hex,Column-method -#' @export #' @note hex since 1.5.0 setMethod("hex", signature(x = "Column"), @@ -888,7 +853,6 @@ setMethod("hex", #' #' @rdname column_datetime_functions #' @aliases hour hour,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -911,7 +875,6 @@ setMethod("hour", #' #' @rdname column_string_functions #' @aliases initcap initcap,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -946,7 +909,6 @@ setMethod("isnan", #' #' @rdname column_nonaggregate_functions #' @aliases is.nan is.nan,Column-method -#' @export #' @note is.nan since 2.0.0 setMethod("is.nan", signature(x = "Column"), @@ -959,7 +921,6 @@ setMethod("is.nan", #' #' @rdname column_aggregate_functions #' @aliases kurtosis kurtosis,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -988,7 +949,6 @@ setMethod("kurtosis", #' @name last #' @aliases last,characterOrColumn-method #' @family aggregate functions -#' @export #' @examples #' \dontrun{ #' last(df$c) @@ -1014,7 +974,6 @@ setMethod("last", #' #' @rdname column_datetime_functions #' @aliases last_day last_day,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1034,7 +993,6 @@ setMethod("last_day", #' #' @rdname column_string_functions #' @aliases length length,Column-method -#' @export #' @note length since 1.5.0 setMethod("length", signature(x = "Column"), @@ -1048,7 +1006,6 @@ setMethod("length", #' #' @rdname column_math_functions #' @aliases log log,Column-method -#' @export #' @note log since 1.5.0 setMethod("log", signature(x = "Column"), @@ -1062,7 +1019,6 @@ setMethod("log", #' #' @rdname column_math_functions #' @aliases log10 log10,Column-method -#' @export #' @note log10 since 1.5.0 setMethod("log10", signature(x = "Column"), @@ -1076,7 +1032,6 @@ setMethod("log10", #' #' @rdname column_math_functions #' @aliases log1p log1p,Column-method -#' @export #' @note log1p since 1.5.0 setMethod("log1p", signature(x = "Column"), @@ -1090,7 +1045,6 @@ setMethod("log1p", #' #' @rdname column_math_functions #' @aliases log2 log2,Column-method -#' @export #' @note log2 since 1.5.0 setMethod("log2", signature(x = "Column"), @@ -1104,7 +1058,6 @@ setMethod("log2", #' #' @rdname column_string_functions #' @aliases lower lower,Column-method -#' @export #' @note lower since 1.4.0 setMethod("lower", signature(x = "Column"), @@ -1119,7 +1072,6 @@ setMethod("lower", #' #' @rdname column_string_functions #' @aliases ltrim ltrim,Column,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -1143,7 +1095,6 @@ setMethod("ltrim", #' @param trimString a character string to trim with #' @rdname column_string_functions #' @aliases ltrim,Column,character-method -#' @export #' @note ltrim(Column, character) since 2.3.0 setMethod("ltrim", signature(x = "Column", trimString = "character"), @@ -1171,7 +1122,6 @@ setMethod("max", #' #' @rdname column_misc_functions #' @aliases md5 md5,Column-method -#' @export #' @note md5 since 1.5.0 setMethod("md5", signature(x = "Column"), @@ -1185,7 +1135,6 @@ setMethod("md5", #' #' @rdname column_aggregate_functions #' @aliases mean mean,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1211,7 +1160,6 @@ setMethod("mean", #' #' @rdname column_aggregate_functions #' @aliases min min,Column-method -#' @export #' @note min since 1.5.0 setMethod("min", signature(x = "Column"), @@ -1225,7 +1173,6 @@ setMethod("min", #' #' @rdname column_datetime_functions #' @aliases minute minute,Column-method -#' @export #' @note minute since 1.5.0 setMethod("minute", signature(x = "Column"), @@ -1248,7 +1195,6 @@ setMethod("minute", #' #' @rdname column_nonaggregate_functions #' @aliases monotonically_increasing_id monotonically_increasing_id,missing-method -#' @export #' @examples #' #' \dontrun{head(select(df, monotonically_increasing_id()))} @@ -1264,7 +1210,6 @@ setMethod("monotonically_increasing_id", #' #' @rdname column_datetime_functions #' @aliases month month,Column-method -#' @export #' @note month since 1.5.0 setMethod("month", signature(x = "Column"), @@ -1278,7 +1223,6 @@ setMethod("month", #' #' @rdname column_nonaggregate_functions #' @aliases negate negate,Column-method -#' @export #' @note negate since 1.5.0 setMethod("negate", signature(x = "Column"), @@ -1292,7 +1236,6 @@ setMethod("negate", #' #' @rdname column_datetime_functions #' @aliases quarter quarter,Column-method -#' @export #' @note quarter since 1.5.0 setMethod("quarter", signature(x = "Column"), @@ -1306,7 +1249,6 @@ setMethod("quarter", #' #' @rdname column_string_functions #' @aliases reverse reverse,Column-method -#' @export #' @note reverse since 1.5.0 setMethod("reverse", signature(x = "Column"), @@ -1321,7 +1263,6 @@ setMethod("reverse", #' #' @rdname column_math_functions #' @aliases rint rint,Column-method -#' @export #' @note rint since 1.5.0 setMethod("rint", signature(x = "Column"), @@ -1336,7 +1277,6 @@ setMethod("rint", #' #' @rdname column_math_functions #' @aliases round round,Column-method -#' @export #' @note round since 1.5.0 setMethod("round", signature(x = "Column"), @@ -1356,7 +1296,6 @@ setMethod("round", #' to the left of the decimal point when \code{scale} < 0. #' @rdname column_math_functions #' @aliases bround bround,Column-method -#' @export #' @note bround since 2.0.0 setMethod("bround", signature(x = "Column"), @@ -1371,7 +1310,6 @@ setMethod("bround", #' #' @rdname column_string_functions #' @aliases rtrim rtrim,Column,missing-method -#' @export #' @note rtrim since 1.5.0 setMethod("rtrim", signature(x = "Column", trimString = "missing"), @@ -1382,7 +1320,6 @@ setMethod("rtrim", #' @rdname column_string_functions #' @aliases rtrim,Column,character-method -#' @export #' @note rtrim(Column, character) since 2.3.0 setMethod("rtrim", signature(x = "Column", trimString = "character"), @@ -1396,7 +1333,6 @@ setMethod("rtrim", #' #' @rdname column_aggregate_functions #' @aliases sd sd,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1414,7 +1350,6 @@ setMethod("sd", #' #' @rdname column_datetime_functions #' @aliases second second,Column-method -#' @export #' @note second since 1.5.0 setMethod("second", signature(x = "Column"), @@ -1429,7 +1364,6 @@ setMethod("second", #' #' @rdname column_misc_functions #' @aliases sha1 sha1,Column-method -#' @export #' @note sha1 since 1.5.0 setMethod("sha1", signature(x = "Column"), @@ -1443,7 +1377,6 @@ setMethod("sha1", #' #' @rdname column_math_functions #' @aliases signum signum,Column-method -#' @export #' @note signum since 1.5.0 setMethod("signum", signature(x = "Column"), @@ -1457,7 +1390,6 @@ setMethod("signum", #' #' @rdname column_math_functions #' @aliases sign sign,Column-method -#' @export #' @note sign since 1.5.0 setMethod("sign", signature(x = "Column"), function(x) { @@ -1470,7 +1402,6 @@ setMethod("sign", signature(x = "Column"), #' #' @rdname column_math_functions #' @aliases sin sin,Column-method -#' @export #' @note sin since 1.5.0 setMethod("sin", signature(x = "Column"), @@ -1485,7 +1416,6 @@ setMethod("sin", #' #' @rdname column_math_functions #' @aliases sinh sinh,Column-method -#' @export #' @note sinh since 1.5.0 setMethod("sinh", signature(x = "Column"), @@ -1499,7 +1429,6 @@ setMethod("sinh", #' #' @rdname column_aggregate_functions #' @aliases skewness skewness,Column-method -#' @export #' @note skewness since 1.6.0 setMethod("skewness", signature(x = "Column"), @@ -1513,7 +1442,6 @@ setMethod("skewness", #' #' @rdname column_string_functions #' @aliases soundex soundex,Column-method -#' @export #' @note soundex since 1.5.0 setMethod("soundex", signature(x = "Column"), @@ -1530,7 +1458,6 @@ setMethod("soundex", #' #' @rdname column_nonaggregate_functions #' @aliases spark_partition_id spark_partition_id,missing-method -#' @export #' @examples #' #' \dontrun{head(select(df, spark_partition_id()))} @@ -1560,7 +1487,6 @@ setMethod("stddev", #' #' @rdname column_aggregate_functions #' @aliases stddev_pop stddev_pop,Column-method -#' @export #' @note stddev_pop since 1.6.0 setMethod("stddev_pop", signature(x = "Column"), @@ -1574,7 +1500,6 @@ setMethod("stddev_pop", #' #' @rdname column_aggregate_functions #' @aliases stddev_samp stddev_samp,Column-method -#' @export #' @note stddev_samp since 1.6.0 setMethod("stddev_samp", signature(x = "Column"), @@ -1588,7 +1513,6 @@ setMethod("stddev_samp", #' #' @rdname column_nonaggregate_functions #' @aliases struct struct,characterOrColumn-method -#' @export #' @examples #' #' \dontrun{ @@ -1614,7 +1538,6 @@ setMethod("struct", #' #' @rdname column_math_functions #' @aliases sqrt sqrt,Column-method -#' @export #' @note sqrt since 1.5.0 setMethod("sqrt", signature(x = "Column"), @@ -1628,7 +1551,6 @@ setMethod("sqrt", #' #' @rdname column_aggregate_functions #' @aliases sum sum,Column-method -#' @export #' @note sum since 1.5.0 setMethod("sum", signature(x = "Column"), @@ -1642,7 +1564,6 @@ setMethod("sum", #' #' @rdname column_aggregate_functions #' @aliases sumDistinct sumDistinct,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1663,7 +1584,6 @@ setMethod("sumDistinct", #' #' @rdname column_math_functions #' @aliases tan tan,Column-method -#' @export #' @note tan since 1.5.0 setMethod("tan", signature(x = "Column"), @@ -1678,7 +1598,6 @@ setMethod("tan", #' #' @rdname column_math_functions #' @aliases tanh tanh,Column-method -#' @export #' @note tanh since 1.5.0 setMethod("tanh", signature(x = "Column"), @@ -1693,7 +1612,6 @@ setMethod("tanh", #' #' @rdname column_math_functions #' @aliases toDegrees toDegrees,Column-method -#' @export #' @note toDegrees since 1.4.0 setMethod("toDegrees", signature(x = "Column"), @@ -1708,7 +1626,6 @@ setMethod("toDegrees", #' #' @rdname column_math_functions #' @aliases toRadians toRadians,Column-method -#' @export #' @note toRadians since 1.4.0 setMethod("toRadians", signature(x = "Column"), @@ -1728,7 +1645,6 @@ setMethod("toRadians", #' #' @rdname column_datetime_functions #' @aliases to_date to_date,Column,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -1749,7 +1665,6 @@ setMethod("to_date", #' @rdname column_datetime_functions #' @aliases to_date,Column,character-method -#' @export #' @note to_date(Column, character) since 2.2.0 setMethod("to_date", signature(x = "Column", format = "character"), @@ -1765,7 +1680,6 @@ setMethod("to_date", #' #' @rdname column_collection_functions #' @aliases to_json to_json,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1803,7 +1717,6 @@ setMethod("to_json", signature(x = "Column"), #' #' @rdname column_datetime_functions #' @aliases to_timestamp to_timestamp,Column,missing-method -#' @export #' @note to_timestamp(Column) since 2.2.0 setMethod("to_timestamp", signature(x = "Column", format = "missing"), @@ -1814,7 +1727,6 @@ setMethod("to_timestamp", #' @rdname column_datetime_functions #' @aliases to_timestamp,Column,character-method -#' @export #' @note to_timestamp(Column, character) since 2.2.0 setMethod("to_timestamp", signature(x = "Column", format = "character"), @@ -1829,7 +1741,6 @@ setMethod("to_timestamp", #' #' @rdname column_string_functions #' @aliases trim trim,Column,missing-method -#' @export #' @note trim since 1.5.0 setMethod("trim", signature(x = "Column", trimString = "missing"), @@ -1840,7 +1751,6 @@ setMethod("trim", #' @rdname column_string_functions #' @aliases trim,Column,character-method -#' @export #' @note trim(Column, character) since 2.3.0 setMethod("trim", signature(x = "Column", trimString = "character"), @@ -1855,7 +1765,6 @@ setMethod("trim", #' #' @rdname column_string_functions #' @aliases unbase64 unbase64,Column-method -#' @export #' @note unbase64 since 1.5.0 setMethod("unbase64", signature(x = "Column"), @@ -1870,7 +1779,6 @@ setMethod("unbase64", #' #' @rdname column_math_functions #' @aliases unhex unhex,Column-method -#' @export #' @note unhex since 1.5.0 setMethod("unhex", signature(x = "Column"), @@ -1884,7 +1792,6 @@ setMethod("unhex", #' #' @rdname column_string_functions #' @aliases upper upper,Column-method -#' @export #' @note upper since 1.4.0 setMethod("upper", signature(x = "Column"), @@ -1898,7 +1805,6 @@ setMethod("upper", #' #' @rdname column_aggregate_functions #' @aliases var var,Column-method -#' @export #' @examples #' #'\dontrun{ @@ -1913,7 +1819,6 @@ setMethod("var", #' @rdname column_aggregate_functions #' @aliases variance variance,Column-method -#' @export #' @note variance since 1.6.0 setMethod("variance", signature(x = "Column"), @@ -1927,7 +1832,6 @@ setMethod("variance", #' #' @rdname column_aggregate_functions #' @aliases var_pop var_pop,Column-method -#' @export #' @note var_pop since 1.5.0 setMethod("var_pop", signature(x = "Column"), @@ -1941,7 +1845,6 @@ setMethod("var_pop", #' #' @rdname column_aggregate_functions #' @aliases var_samp var_samp,Column-method -#' @export #' @note var_samp since 1.6.0 setMethod("var_samp", signature(x = "Column"), @@ -1955,7 +1858,6 @@ setMethod("var_samp", #' #' @rdname column_datetime_functions #' @aliases weekofyear weekofyear,Column-method -#' @export #' @note weekofyear since 1.5.0 setMethod("weekofyear", signature(x = "Column"), @@ -1969,7 +1871,6 @@ setMethod("weekofyear", #' #' @rdname column_datetime_functions #' @aliases year year,Column-method -#' @export #' @note year since 1.5.0 setMethod("year", signature(x = "Column"), @@ -1985,7 +1886,6 @@ setMethod("year", #' #' @rdname column_math_functions #' @aliases atan2 atan2,Column-method -#' @export #' @note atan2 since 1.5.0 setMethod("atan2", signature(y = "Column"), function(y, x) { @@ -2001,7 +1901,6 @@ setMethod("atan2", signature(y = "Column"), #' #' @rdname column_datetime_diff_functions #' @aliases datediff datediff,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2025,7 +1924,6 @@ setMethod("datediff", signature(y = "Column"), #' #' @rdname column_math_functions #' @aliases hypot hypot,Column-method -#' @export #' @note hypot since 1.4.0 setMethod("hypot", signature(y = "Column"), function(y, x) { @@ -2041,7 +1939,6 @@ setMethod("hypot", signature(y = "Column"), #' #' @rdname column_string_functions #' @aliases levenshtein levenshtein,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2064,7 +1961,6 @@ setMethod("levenshtein", signature(y = "Column"), #' #' @rdname column_datetime_diff_functions #' @aliases months_between months_between,Column-method -#' @export #' @note months_between since 1.5.0 setMethod("months_between", signature(y = "Column"), function(y, x) { @@ -2082,7 +1978,6 @@ setMethod("months_between", signature(y = "Column"), #' #' @rdname column_nonaggregate_functions #' @aliases nanvl nanvl,Column-method -#' @export #' @note nanvl since 1.5.0 setMethod("nanvl", signature(y = "Column"), function(y, x) { @@ -2099,7 +1994,6 @@ setMethod("nanvl", signature(y = "Column"), #' #' @rdname column_math_functions #' @aliases pmod pmod,Column-method -#' @export #' @note pmod since 1.5.0 setMethod("pmod", signature(y = "Column"), function(y, x) { @@ -2114,7 +2008,6 @@ setMethod("pmod", signature(y = "Column"), #' #' @rdname column_aggregate_functions #' @aliases approxCountDistinct,Column-method -#' @export #' @note approxCountDistinct(Column, numeric) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), @@ -2128,7 +2021,6 @@ setMethod("approxCountDistinct", #' #' @rdname column_aggregate_functions #' @aliases countDistinct countDistinct,Column-method -#' @export #' @note countDistinct since 1.4.0 setMethod("countDistinct", signature(x = "Column"), @@ -2148,7 +2040,6 @@ setMethod("countDistinct", #' #' @rdname column_string_functions #' @aliases concat concat,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2177,7 +2068,6 @@ setMethod("concat", #' #' @rdname column_nonaggregate_functions #' @aliases greatest greatest,Column-method -#' @export #' @note greatest since 1.5.0 setMethod("greatest", signature(x = "Column"), @@ -2197,7 +2087,6 @@ setMethod("greatest", #' #' @rdname column_nonaggregate_functions #' @aliases least least,Column-method -#' @export #' @note least since 1.5.0 setMethod("least", signature(x = "Column"), @@ -2216,7 +2105,6 @@ setMethod("least", #' #' @rdname column_aggregate_functions #' @aliases n_distinct n_distinct,Column-method -#' @export #' @note n_distinct since 1.4.0 setMethod("n_distinct", signature(x = "Column"), function(x, ...) { @@ -2226,7 +2114,6 @@ setMethod("n_distinct", signature(x = "Column"), #' @rdname count #' @name n #' @aliases n,Column-method -#' @export #' @examples \dontrun{n(df$c)} #' @note n since 1.4.0 setMethod("n", signature(x = "Column"), @@ -2245,7 +2132,6 @@ setMethod("n", signature(x = "Column"), #' @rdname column_datetime_diff_functions #' #' @aliases date_format date_format,Column,character-method -#' @export #' @note date_format since 1.5.0 setMethod("date_format", signature(y = "Column", x = "character"), function(y, x) { @@ -2263,7 +2149,6 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @param as.json.array indicating if input string is JSON array of objects or a single object. #' @aliases from_json from_json,Column,characterOrstructType-method -#' @export #' @examples #' #' \dontrun{ @@ -2306,7 +2191,6 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructType") #' @rdname column_datetime_diff_functions #' #' @aliases from_utc_timestamp from_utc_timestamp,Column,character-method -#' @export #' @examples #' #' \dontrun{ @@ -2328,7 +2212,6 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' #' @rdname column_string_functions #' @aliases instr instr,Column,character-method -#' @export #' @examples #' #' \dontrun{ @@ -2351,7 +2234,6 @@ setMethod("instr", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases next_day next_day,Column,character-method -#' @export #' @note next_day since 1.5.0 setMethod("next_day", signature(y = "Column", x = "character"), function(y, x) { @@ -2366,7 +2248,6 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases to_utc_timestamp to_utc_timestamp,Column,character-method -#' @export #' @note to_utc_timestamp since 1.5.0 setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { @@ -2379,7 +2260,6 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases add_months add_months,Column,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2400,7 +2280,6 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), #' #' @rdname column_datetime_diff_functions #' @aliases date_add date_add,Column,numeric-method -#' @export #' @note date_add since 1.5.0 setMethod("date_add", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2414,7 +2293,6 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), #' @rdname column_datetime_diff_functions #' #' @aliases date_sub date_sub,Column,numeric-method -#' @export #' @note date_sub since 1.5.0 setMethod("date_sub", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2431,7 +2309,6 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' #' @rdname column_string_functions #' @aliases format_number format_number,Column,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2454,7 +2331,6 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), #' #' @rdname column_misc_functions #' @aliases sha2 sha2,Column,numeric-method -#' @export #' @note sha2 since 1.5.0 setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2468,7 +2344,6 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftLeft shiftLeft,Column,numeric-method -#' @export #' @note shiftLeft since 1.5.0 setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2484,7 +2359,6 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftRight shiftRight,Column,numeric-method -#' @export #' @note shiftRight since 1.5.0 setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2500,7 +2374,6 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftRightUnsigned shiftRightUnsigned,Column,numeric-method -#' @export #' @note shiftRightUnsigned since 1.5.0 setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2517,7 +2390,6 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @param sep separator to use. #' @rdname column_string_functions #' @aliases concat_ws concat_ws,character,Column-method -#' @export #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -2533,7 +2405,6 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), #' @param toBase base to convert to. #' @rdname column_math_functions #' @aliases conv conv,Column,numeric,numeric-method -#' @export #' @note conv since 1.5.0 setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { @@ -2551,7 +2422,6 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri #' #' @rdname column_nonaggregate_functions #' @aliases expr expr,character-method -#' @export #' @note expr since 1.5.0 setMethod("expr", signature(x = "character"), function(x) { @@ -2566,7 +2436,6 @@ setMethod("expr", signature(x = "character"), #' @param format a character object of format strings. #' @rdname column_string_functions #' @aliases format_string format_string,character,Column-method -#' @export #' @note format_string since 1.5.0 setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { @@ -2587,7 +2456,6 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' @rdname column_datetime_functions #' #' @aliases from_unixtime from_unixtime,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2629,7 +2497,6 @@ setMethod("from_unixtime", signature(x = "Column"), #' \code{startTime} as \code{"15 minutes"}. #' @rdname column_datetime_functions #' @aliases window window,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2680,7 +2547,6 @@ setMethod("window", signature(x = "Column"), #' @param pos start position of search. #' @rdname column_string_functions #' @aliases locate locate,character,Column-method -#' @export #' @note locate since 1.5.0 setMethod("locate", signature(substr = "character", str = "Column"), function(substr, str, pos = 1) { @@ -2697,7 +2563,6 @@ setMethod("locate", signature(substr = "character", str = "Column"), #' @param pad a character string to be padded with. #' @rdname column_string_functions #' @aliases lpad lpad,Column,numeric,character-method -#' @export #' @note lpad since 1.5.0 setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2714,7 +2579,6 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. #' @aliases rand rand,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -2729,7 +2593,6 @@ setMethod("rand", signature(seed = "missing"), #' @rdname column_nonaggregate_functions #' @aliases rand,numeric-method -#' @export #' @note rand(numeric) since 1.5.0 setMethod("rand", signature(seed = "numeric"), function(seed) { @@ -2743,7 +2606,6 @@ setMethod("rand", signature(seed = "numeric"), #' #' @rdname column_nonaggregate_functions #' @aliases randn randn,missing-method -#' @export #' @note randn since 1.5.0 setMethod("randn", signature(seed = "missing"), function(seed) { @@ -2753,7 +2615,6 @@ setMethod("randn", signature(seed = "missing"), #' @rdname column_nonaggregate_functions #' @aliases randn,numeric-method -#' @export #' @note randn(numeric) since 1.5.0 setMethod("randn", signature(seed = "numeric"), function(seed) { @@ -2770,7 +2631,6 @@ setMethod("randn", signature(seed = "numeric"), #' @param idx a group index. #' @rdname column_string_functions #' @aliases regexp_extract regexp_extract,Column,character,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2799,7 +2659,6 @@ setMethod("regexp_extract", #' @param replacement a character string that a matched \code{pattern} is replaced with. #' @rdname column_string_functions #' @aliases regexp_replace regexp_replace,Column,character,character-method -#' @export #' @note regexp_replace since 1.5.0 setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), @@ -2815,7 +2674,6 @@ setMethod("regexp_replace", #' #' @rdname column_string_functions #' @aliases rpad rpad,Column,numeric,character-method -#' @export #' @note rpad since 1.5.0 setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2838,7 +2696,6 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #' counting from the right. #' @rdname column_string_functions #' @aliases substring_index substring_index,Column,character,numeric-method -#' @export #' @note substring_index since 1.5.0 setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), @@ -2861,7 +2718,6 @@ setMethod("substring_index", #' at the same location, if any. #' @rdname column_string_functions #' @aliases translate translate,Column,character,character-method -#' @export #' @note translate since 1.5.0 setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), @@ -2876,7 +2732,6 @@ setMethod("translate", #' #' @rdname column_datetime_functions #' @aliases unix_timestamp unix_timestamp,missing,missing-method -#' @export #' @note unix_timestamp since 1.5.0 setMethod("unix_timestamp", signature(x = "missing", format = "missing"), function(x, format) { @@ -2886,7 +2741,6 @@ setMethod("unix_timestamp", signature(x = "missing", format = "missing"), #' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,missing-method -#' @export #' @note unix_timestamp(Column) since 1.5.0 setMethod("unix_timestamp", signature(x = "Column", format = "missing"), function(x, format) { @@ -2896,7 +2750,6 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"), #' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,character-method -#' @export #' @note unix_timestamp(Column, character) since 1.5.0 setMethod("unix_timestamp", signature(x = "Column", format = "character"), function(x, format = "yyyy-MM-dd HH:mm:ss") { @@ -2912,7 +2765,6 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), #' @param condition the condition to test on. Must be a Column expression. #' @param value result expression. #' @aliases when when,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2941,7 +2793,6 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' @param yes return values for \code{TRUE} elements of test. #' @param no return values for \code{FALSE} elements of test. #' @aliases ifelse ifelse,Column-method -#' @export #' @note ifelse since 1.5.0 setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), @@ -2967,7 +2818,6 @@ setMethod("ifelse", #' #' @rdname column_window_functions #' @aliases cume_dist cume_dist,missing-method -#' @export #' @note cume_dist since 1.6.0 setMethod("cume_dist", signature("missing"), @@ -2988,7 +2838,6 @@ setMethod("cume_dist", #' #' @rdname column_window_functions #' @aliases dense_rank dense_rank,missing-method -#' @export #' @note dense_rank since 1.6.0 setMethod("dense_rank", signature("missing"), @@ -3005,7 +2854,6 @@ setMethod("dense_rank", #' #' @rdname column_window_functions #' @aliases lag lag,characterOrColumn-method -#' @export #' @note lag since 1.6.0 setMethod("lag", signature(x = "characterOrColumn"), @@ -3030,7 +2878,6 @@ setMethod("lag", #' #' @rdname column_window_functions #' @aliases lead lead,characterOrColumn,numeric-method -#' @export #' @note lead since 1.6.0 setMethod("lead", signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), @@ -3054,7 +2901,6 @@ setMethod("lead", #' #' @rdname column_window_functions #' @aliases ntile ntile,numeric-method -#' @export #' @note ntile since 1.6.0 setMethod("ntile", signature(x = "numeric"), @@ -3072,7 +2918,6 @@ setMethod("ntile", #' #' @rdname column_window_functions #' @aliases percent_rank percent_rank,missing-method -#' @export #' @note percent_rank since 1.6.0 setMethod("percent_rank", signature("missing"), @@ -3093,7 +2938,6 @@ setMethod("percent_rank", #' #' @rdname column_window_functions #' @aliases rank rank,missing-method -#' @export #' @note rank since 1.6.0 setMethod("rank", signature(x = "missing"), @@ -3104,7 +2948,6 @@ setMethod("rank", #' @rdname column_window_functions #' @aliases rank,ANY-method -#' @export setMethod("rank", signature(x = "ANY"), function(x, ...) { @@ -3118,7 +2961,6 @@ setMethod("rank", #' #' @rdname column_window_functions #' @aliases row_number row_number,missing-method -#' @export #' @note row_number since 1.6.0 setMethod("row_number", signature("missing"), @@ -3136,7 +2978,6 @@ setMethod("row_number", #' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method -#' @export #' @note array_contains since 1.6.0 setMethod("array_contains", signature(x = "Column", value = "ANY"), @@ -3150,7 +2991,6 @@ setMethod("array_contains", #' #' @rdname column_collection_functions #' @aliases map_keys map_keys,Column-method -#' @export #' @note map_keys since 2.3.0 setMethod("map_keys", signature(x = "Column"), @@ -3164,7 +3004,6 @@ setMethod("map_keys", #' #' @rdname column_collection_functions #' @aliases map_values map_values,Column-method -#' @export #' @note map_values since 2.3.0 setMethod("map_values", signature(x = "Column"), @@ -3178,7 +3017,6 @@ setMethod("map_values", #' #' @rdname column_collection_functions #' @aliases explode explode,Column-method -#' @export #' @note explode since 1.5.0 setMethod("explode", signature(x = "Column"), @@ -3192,7 +3030,6 @@ setMethod("explode", #' #' @rdname column_collection_functions #' @aliases size size,Column-method -#' @export #' @note size since 1.5.0 setMethod("size", signature(x = "Column"), @@ -3210,7 +3047,6 @@ setMethod("size", #' TRUE, sorting is in ascending order. #' FALSE, sorting is in descending order. #' @aliases sort_array sort_array,Column-method -#' @export #' @note sort_array since 1.6.0 setMethod("sort_array", signature(x = "Column"), @@ -3225,7 +3061,6 @@ setMethod("sort_array", #' #' @rdname column_collection_functions #' @aliases posexplode posexplode,Column-method -#' @export #' @note posexplode since 2.1.0 setMethod("posexplode", signature(x = "Column"), @@ -3240,7 +3075,6 @@ setMethod("posexplode", #' #' @rdname column_nonaggregate_functions #' @aliases create_array create_array,Column-method -#' @export #' @note create_array since 2.3.0 setMethod("create_array", signature(x = "Column"), @@ -3261,7 +3095,6 @@ setMethod("create_array", #' #' @rdname column_nonaggregate_functions #' @aliases create_map create_map,Column-method -#' @export #' @note create_map since 2.3.0 setMethod("create_map", signature(x = "Column"), @@ -3279,7 +3112,6 @@ setMethod("create_map", #' #' @rdname column_aggregate_functions #' @aliases collect_list collect_list,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3299,7 +3131,6 @@ setMethod("collect_list", #' #' @rdname column_aggregate_functions #' @aliases collect_set collect_set,Column-method -#' @export #' @note collect_set since 2.3.0 setMethod("collect_set", signature(x = "Column"), @@ -3314,7 +3145,6 @@ setMethod("collect_set", #' #' @rdname column_string_functions #' @aliases split_string split_string,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3337,7 +3167,6 @@ setMethod("split_string", #' @param n number of repetitions. #' @rdname column_string_functions #' @aliases repeat_string repeat_string,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3360,7 +3189,6 @@ setMethod("repeat_string", #' #' @rdname column_collection_functions #' @aliases explode_outer explode_outer,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3385,7 +3213,6 @@ setMethod("explode_outer", #' #' @rdname column_collection_functions #' @aliases posexplode_outer posexplode_outer,Column-method -#' @export #' @note posexplode_outer since 2.3.0 setMethod("posexplode_outer", signature(x = "Column"), @@ -3406,7 +3233,6 @@ setMethod("posexplode_outer", #' @name not #' @aliases not,Column-method #' @family non-aggregate functions -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame( @@ -3434,7 +3260,6 @@ setMethod("not", #' #' @rdname column_aggregate_functions #' @aliases grouping_bit grouping_bit,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3467,7 +3292,6 @@ setMethod("grouping_bit", #' #' @rdname column_aggregate_functions #' @aliases grouping_id grouping_id,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3502,7 +3326,6 @@ setMethod("grouping_id", #' #' @rdname column_nonaggregate_functions #' @aliases input_file_name input_file_name,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -3520,7 +3343,6 @@ setMethod("input_file_name", signature("missing"), #' #' @rdname column_datetime_functions #' @aliases trunc trunc,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3540,7 +3362,6 @@ setMethod("trunc", #' #' @rdname column_datetime_functions #' @aliases date_trunc date_trunc,character,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3559,7 +3380,6 @@ setMethod("date_trunc", #' #' @rdname column_datetime_functions #' @aliases current_date current_date,missing-method -#' @export #' @examples #' \dontrun{ #' head(select(df, current_date(), current_timestamp()))} @@ -3576,7 +3396,6 @@ setMethod("current_date", #' #' @rdname column_datetime_functions #' @aliases current_timestamp current_timestamp,missing-method -#' @export #' @note current_timestamp since 2.3.0 setMethod("current_timestamp", signature("missing"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e0dde3339fabc..6fba4b6c761dd 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -19,7 +19,6 @@ # @rdname aggregateRDD # @seealso reduce -# @export setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) @@ -27,21 +26,17 @@ setGeneric("cacheRDD", function(x) { standardGeneric("cacheRDD") }) # @rdname coalesce # @seealso repartition -# @export setGeneric("coalesceRDD", function(x, numPartitions, ...) { standardGeneric("coalesceRDD") }) # @rdname checkpoint-methods -# @export setGeneric("checkpointRDD", function(x) { standardGeneric("checkpointRDD") }) setGeneric("collectRDD", function(x, ...) { standardGeneric("collectRDD") }) # @rdname collect-methods -# @export setGeneric("collectAsMap", function(x) { standardGeneric("collectAsMap") }) # @rdname collect-methods -# @export setGeneric("collectPartition", function(x, partitionId) { standardGeneric("collectPartition") @@ -52,19 +47,15 @@ setGeneric("countRDD", function(x) { standardGeneric("countRDD") }) setGeneric("lengthRDD", function(x) { standardGeneric("lengthRDD") }) # @rdname countByValue -# @export setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) # @rdname crosstab -# @export setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) # @rdname freqItems -# @export setGeneric("freqItems", function(x, cols, support = 0.01) { standardGeneric("freqItems") }) # @rdname approxQuantile -# @export setGeneric("approxQuantile", function(x, cols, probabilities, relativeError) { standardGeneric("approxQuantile") @@ -73,18 +64,15 @@ setGeneric("approxQuantile", setGeneric("distinctRDD", function(x, numPartitions = 1) { standardGeneric("distinctRDD") }) # @rdname filterRDD -# @export setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) setGeneric("firstRDD", function(x, ...) { standardGeneric("firstRDD") }) # @rdname flatMap -# @export setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) # @rdname fold # @seealso reduce -# @export setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) @@ -95,17 +83,14 @@ setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachParti setGeneric("getJRDD", function(rdd, ...) { standardGeneric("getJRDD") }) # @rdname glom -# @export setGeneric("glom", function(x) { standardGeneric("glom") }) # @rdname histogram -# @export setGeneric("histogram", function(df, col, nbins=10) { standardGeneric("histogram") }) setGeneric("joinRDD", function(x, y, ...) { standardGeneric("joinRDD") }) # @rdname keyBy -# @export setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) @@ -123,47 +108,37 @@ setGeneric("mapPartitionsWithIndex", function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) # @rdname maximum -# @export setGeneric("maximum", function(x) { standardGeneric("maximum") }) # @rdname minimum -# @export setGeneric("minimum", function(x) { standardGeneric("minimum") }) # @rdname sumRDD -# @export setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) # @rdname name -# @export setGeneric("name", function(x) { standardGeneric("name") }) # @rdname getNumPartitionsRDD -# @export setGeneric("getNumPartitionsRDD", function(x) { standardGeneric("getNumPartitionsRDD") }) # @rdname getNumPartitions -# @export setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) setGeneric("persistRDD", function(x, newLevel) { standardGeneric("persistRDD") }) # @rdname pipeRDD -# @export setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")}) # @rdname pivot -# @export setGeneric("pivot", function(x, colname, values = list()) { standardGeneric("pivot") }) # @rdname reduce -# @export setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) setGeneric("repartitionRDD", function(x, ...) { standardGeneric("repartitionRDD") }) # @rdname sampleRDD -# @export setGeneric("sampleRDD", function(x, withReplacement, fraction, seed) { standardGeneric("sampleRDD") @@ -171,21 +146,17 @@ setGeneric("sampleRDD", # @rdname saveAsObjectFile # @seealso objectFile -# @export setGeneric("saveAsObjectFile", function(x, path) { standardGeneric("saveAsObjectFile") }) # @rdname saveAsTextFile -# @export setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile") }) # @rdname setName -# @export setGeneric("setName", function(x, name) { standardGeneric("setName") }) setGeneric("showRDD", function(object, ...) { standardGeneric("showRDD") }) # @rdname sortBy -# @export setGeneric("sortBy", function(x, func, ascending = TRUE, numPartitions = 1) { standardGeneric("sortBy") @@ -194,88 +165,71 @@ setGeneric("sortBy", setGeneric("takeRDD", function(x, num) { standardGeneric("takeRDD") }) # @rdname takeOrdered -# @export setGeneric("takeOrdered", function(x, num) { standardGeneric("takeOrdered") }) # @rdname takeSample -# @export setGeneric("takeSample", function(x, withReplacement, num, seed) { standardGeneric("takeSample") }) # @rdname top -# @export setGeneric("top", function(x, num) { standardGeneric("top") }) # @rdname unionRDD -# @export setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") }) setGeneric("unpersistRDD", function(x, ...) { standardGeneric("unpersistRDD") }) # @rdname zipRDD -# @export setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) # @rdname zipRDD -# @export setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, signature = "...") # @rdname zipWithIndex # @seealso zipWithUniqueId -# @export setGeneric("zipWithIndex", function(x) { standardGeneric("zipWithIndex") }) # @rdname zipWithUniqueId # @seealso zipWithIndex -# @export setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") }) ############ Binary Functions ############# # @rdname cartesian -# @export setGeneric("cartesian", function(x, other) { standardGeneric("cartesian") }) # @rdname countByKey -# @export setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) # @rdname flatMapValues -# @export setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) # @rdname intersection -# @export setGeneric("intersection", function(x, other, numPartitions = 1) { standardGeneric("intersection") }) # @rdname keys -# @export setGeneric("keys", function(x) { standardGeneric("keys") }) # @rdname lookup -# @export setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) # @rdname mapValues -# @export setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) # @rdname sampleByKey -# @export setGeneric("sampleByKey", function(x, withReplacement, fractions, seed) { standardGeneric("sampleByKey") }) # @rdname values -# @export setGeneric("values", function(x) { standardGeneric("values") }) @@ -283,14 +237,12 @@ setGeneric("values", function(x) { standardGeneric("values") }) # @rdname aggregateByKey # @seealso foldByKey, combineByKey -# @export setGeneric("aggregateByKey", function(x, zeroValue, seqOp, combOp, numPartitions) { standardGeneric("aggregateByKey") }) # @rdname cogroup -# @export setGeneric("cogroup", function(..., numPartitions) { standardGeneric("cogroup") @@ -299,7 +251,6 @@ setGeneric("cogroup", # @rdname combineByKey # @seealso groupByKey, reduceByKey -# @export setGeneric("combineByKey", function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { standardGeneric("combineByKey") @@ -307,64 +258,53 @@ setGeneric("combineByKey", # @rdname foldByKey # @seealso aggregateByKey, combineByKey -# @export setGeneric("foldByKey", function(x, zeroValue, func, numPartitions) { standardGeneric("foldByKey") }) # @rdname join-methods -# @export setGeneric("fullOuterJoin", function(x, y, numPartitions) { standardGeneric("fullOuterJoin") }) # @rdname groupByKey # @seealso reduceByKey -# @export setGeneric("groupByKey", function(x, numPartitions) { standardGeneric("groupByKey") }) # @rdname join-methods -# @export setGeneric("join", function(x, y, ...) { standardGeneric("join") }) # @rdname join-methods -# @export setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") }) setGeneric("partitionByRDD", function(x, ...) { standardGeneric("partitionByRDD") }) # @rdname reduceByKey # @seealso groupByKey -# @export setGeneric("reduceByKey", function(x, combineFunc, numPartitions) { standardGeneric("reduceByKey")}) # @rdname reduceByKeyLocally # @seealso reduceByKey -# @export setGeneric("reduceByKeyLocally", function(x, combineFunc) { standardGeneric("reduceByKeyLocally") }) # @rdname join-methods -# @export setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("rightOuterJoin") }) # @rdname sortByKey -# @export setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1) { standardGeneric("sortByKey") }) # @rdname subtract -# @export setGeneric("subtract", function(x, other, numPartitions = 1) { standardGeneric("subtract") }) # @rdname subtractByKey -# @export setGeneric("subtractByKey", function(x, other, numPartitions = 1) { standardGeneric("subtractByKey") @@ -374,7 +314,6 @@ setGeneric("subtractByKey", ################### Broadcast Variable Methods ################# # @rdname broadcast -# @export setGeneric("value", function(bcast) { standardGeneric("value") }) @@ -384,7 +323,6 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @param ... further arguments to be passed to or from other methods. #' @return A SparkDataFrame. #' @rdname summarize -#' @export setGeneric("agg", function(x, ...) { standardGeneric("agg") }) #' alias @@ -399,11 +337,9 @@ setGeneric("agg", function(x, ...) { standardGeneric("agg") }) NULL #' @rdname arrange -#' @export setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) #' @rdname as.data.frame -#' @export setGeneric("as.data.frame", function(x, row.names = NULL, optional = FALSE, ...) { standardGeneric("as.data.frame") @@ -411,52 +347,41 @@ setGeneric("as.data.frame", # Do not document the generic because of signature changes across R versions #' @noRd -#' @export setGeneric("attach") #' @rdname cache -#' @export setGeneric("cache", function(x) { standardGeneric("cache") }) #' @rdname checkpoint -#' @export setGeneric("checkpoint", function(x, eager = TRUE) { standardGeneric("checkpoint") }) #' @rdname coalesce #' @param x a SparkDataFrame. #' @param ... additional argument(s). -#' @export setGeneric("coalesce", function(x, ...) { standardGeneric("coalesce") }) #' @rdname collect -#' @export setGeneric("collect", function(x, ...) { standardGeneric("collect") }) #' @param do.NULL currently not used. #' @param prefix currently not used. #' @rdname columns -#' @export setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") }) #' @rdname columns -#' @export setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) #' @rdname coltypes -#' @export setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) #' @rdname coltypes -#' @export setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) #' @rdname columns -#' @export setGeneric("columns", function(x) {standardGeneric("columns") }) #' @param x a GroupedData or Column. #' @rdname count -#' @export setGeneric("count", function(x) { standardGeneric("count") }) #' @rdname cov @@ -464,7 +389,6 @@ setGeneric("count", function(x) { standardGeneric("count") }) #' @param ... additional argument(s). If \code{x} is a Column, a Column #' should be provided. If \code{x} is a SparkDataFrame, two column names should #' be provided. -#' @export setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @rdname corr @@ -472,294 +396,229 @@ setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @param ... additional argument(s). If \code{x} is a Column, a Column #' should be provided. If \code{x} is a SparkDataFrame, two column names should #' be provided. -#' @export setGeneric("corr", function(x, ...) {standardGeneric("corr") }) #' @rdname cov -#' @export setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") }) #' @rdname cov -#' @export setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") }) #' @rdname createOrReplaceTempView -#' @export setGeneric("createOrReplaceTempView", function(x, viewName) { standardGeneric("createOrReplaceTempView") }) # @rdname crossJoin -# @export setGeneric("crossJoin", function(x, y) { standardGeneric("crossJoin") }) #' @rdname cube -#' @export setGeneric("cube", function(x, ...) { standardGeneric("cube") }) #' @rdname dapply -#' @export setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") }) #' @rdname dapplyCollect -#' @export setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") }) #' @param x a SparkDataFrame or GroupedData. #' @param ... additional argument(s) passed to the method. #' @rdname gapply -#' @export setGeneric("gapply", function(x, ...) { standardGeneric("gapply") }) #' @param x a SparkDataFrame or GroupedData. #' @param ... additional argument(s) passed to the method. #' @rdname gapplyCollect -#' @export setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect") }) # @rdname getNumPartitions -# @export setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") }) #' @rdname describe -#' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) #' @rdname distinct -#' @export setGeneric("distinct", function(x) { standardGeneric("distinct") }) #' @rdname drop -#' @export setGeneric("drop", function(x, ...) { standardGeneric("drop") }) #' @rdname dropDuplicates -#' @export setGeneric("dropDuplicates", function(x, ...) { standardGeneric("dropDuplicates") }) #' @rdname nafunctions -#' @export setGeneric("dropna", function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { standardGeneric("dropna") }) #' @rdname nafunctions -#' @export setGeneric("na.omit", function(object, ...) { standardGeneric("na.omit") }) #' @rdname dtypes -#' @export setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) #' @rdname explain -#' @export #' @param x a SparkDataFrame or a StreamingQuery. #' @param extended Logical. If extended is FALSE, prints only the physical plan. #' @param ... further arguments to be passed to or from other methods. setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @rdname except -#' @export setGeneric("except", function(x, y) { standardGeneric("except") }) #' @rdname nafunctions -#' @export setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) #' @rdname filter -#' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) #' @rdname first -#' @export setGeneric("first", function(x, ...) { standardGeneric("first") }) #' @rdname groupBy -#' @export setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) #' @rdname groupBy -#' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) #' @rdname hint -#' @export setGeneric("hint", function(x, name, ...) { standardGeneric("hint") }) #' @rdname insertInto -#' @export setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) #' @rdname intersect -#' @export setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) #' @rdname isLocal -#' @export setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) #' @rdname isStreaming -#' @export setGeneric("isStreaming", function(x) { standardGeneric("isStreaming") }) #' @rdname limit -#' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) #' @rdname localCheckpoint -#' @export setGeneric("localCheckpoint", function(x, eager = TRUE) { standardGeneric("localCheckpoint") }) #' @rdname merge -#' @export setGeneric("merge") #' @rdname mutate -#' @export setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) #' @rdname orderBy -#' @export setGeneric("orderBy", function(x, col, ...) { standardGeneric("orderBy") }) #' @rdname persist -#' @export setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) #' @rdname printSchema -#' @export setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) #' @rdname registerTempTable-deprecated -#' @export setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) #' @rdname rename -#' @export setGeneric("rename", function(x, ...) { standardGeneric("rename") }) #' @rdname repartition -#' @export setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) #' @rdname sample -#' @export setGeneric("sample", function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample") }) #' @rdname rollup -#' @export setGeneric("rollup", function(x, ...) { standardGeneric("rollup") }) #' @rdname sample -#' @export setGeneric("sample_frac", function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample_frac") }) #' @rdname sampleBy -#' @export setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") }) #' @rdname saveAsTable -#' @export setGeneric("saveAsTable", function(df, tableName, source = NULL, mode = "error", ...) { standardGeneric("saveAsTable") }) -#' @export setGeneric("str") #' @rdname take -#' @export setGeneric("take", function(x, num) { standardGeneric("take") }) #' @rdname mutate -#' @export setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) #' @rdname write.df -#' @export setGeneric("write.df", function(df, path = NULL, source = NULL, mode = "error", ...) { standardGeneric("write.df") }) #' @rdname write.df -#' @export setGeneric("saveDF", function(df, path, source = NULL, mode = "error", ...) { standardGeneric("saveDF") }) #' @rdname write.jdbc -#' @export setGeneric("write.jdbc", function(x, url, tableName, mode = "error", ...) { standardGeneric("write.jdbc") }) #' @rdname write.json -#' @export setGeneric("write.json", function(x, path, ...) { standardGeneric("write.json") }) #' @rdname write.orc -#' @export setGeneric("write.orc", function(x, path, ...) { standardGeneric("write.orc") }) #' @rdname write.parquet -#' @export setGeneric("write.parquet", function(x, path, ...) { standardGeneric("write.parquet") }) #' @rdname write.parquet -#' @export setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) #' @rdname write.stream -#' @export setGeneric("write.stream", function(df, source = NULL, outputMode = NULL, ...) { standardGeneric("write.stream") }) #' @rdname write.text -#' @export setGeneric("write.text", function(x, path, ...) { standardGeneric("write.text") }) #' @rdname schema -#' @export setGeneric("schema", function(x) { standardGeneric("schema") }) #' @rdname select -#' @export setGeneric("select", function(x, col, ...) { standardGeneric("select") }) #' @rdname selectExpr -#' @export setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") }) #' @rdname showDF -#' @export setGeneric("showDF", function(x, ...) { standardGeneric("showDF") }) # @rdname storageLevel -# @export setGeneric("storageLevel", function(x) { standardGeneric("storageLevel") }) #' @rdname subset -#' @export setGeneric("subset", function(x, ...) { standardGeneric("subset") }) #' @rdname summarize -#' @export setGeneric("summarize", function(x, ...) { standardGeneric("summarize") }) #' @rdname summary -#' @export setGeneric("summary", function(object, ...) { standardGeneric("summary") }) setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) @@ -767,830 +626,660 @@ setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) #' @rdname union -#' @export setGeneric("union", function(x, y) { standardGeneric("union") }) #' @rdname union -#' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) #' @rdname unionByName -#' @export setGeneric("unionByName", function(x, y) { standardGeneric("unionByName") }) #' @rdname unpersist -#' @export setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) #' @rdname filter -#' @export setGeneric("where", function(x, condition) { standardGeneric("where") }) #' @rdname with -#' @export setGeneric("with") #' @rdname withColumn -#' @export setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) #' @rdname rename -#' @export setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) #' @rdname withWatermark -#' @export setGeneric("withWatermark", function(x, eventTime, delayThreshold) { standardGeneric("withWatermark") }) #' @rdname write.df -#' @export setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") }) #' @rdname randomSplit -#' @export setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") }) #' @rdname broadcast -#' @export setGeneric("broadcast", function(x) { standardGeneric("broadcast") }) ###################### Column Methods ########################## #' @rdname columnfunctions -#' @export setGeneric("asc", function(x) { standardGeneric("asc") }) #' @rdname between -#' @export setGeneric("between", function(x, bounds) { standardGeneric("between") }) #' @rdname cast -#' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) #' @rdname columnfunctions #' @param x a Column object. #' @param ... additional argument(s). -#' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) #' @rdname columnfunctions -#' @export setGeneric("desc", function(x) { standardGeneric("desc") }) #' @rdname endsWith -#' @export setGeneric("endsWith", function(x, suffix) { standardGeneric("endsWith") }) #' @rdname columnfunctions -#' @export setGeneric("getField", function(x, ...) { standardGeneric("getField") }) #' @rdname columnfunctions -#' @export setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) #' @rdname columnfunctions -#' @export setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) #' @rdname columnfunctions -#' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) #' @rdname columnfunctions -#' @export setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) #' @rdname columnfunctions -#' @export setGeneric("like", function(x, ...) { standardGeneric("like") }) #' @rdname columnfunctions -#' @export setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) #' @rdname startsWith -#' @export setGeneric("startsWith", function(x, prefix) { standardGeneric("startsWith") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("when", function(condition, value) { standardGeneric("when") }) #' @rdname otherwise -#' @export setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) #' @rdname over -#' @export setGeneric("over", function(x, window) { standardGeneric("over") }) #' @rdname eq_null_safe -#' @export setGeneric("%<=>%", function(x, value) { standardGeneric("%<=>%") }) ###################### WindowSpec Methods ########################## #' @rdname partitionBy -#' @export setGeneric("partitionBy", function(x, ...) { standardGeneric("partitionBy") }) #' @rdname rowsBetween -#' @export setGeneric("rowsBetween", function(x, start, end) { standardGeneric("rowsBetween") }) #' @rdname rangeBetween -#' @export setGeneric("rangeBetween", function(x, start, end) { standardGeneric("rangeBetween") }) #' @rdname windowPartitionBy -#' @export setGeneric("windowPartitionBy", function(col, ...) { standardGeneric("windowPartitionBy") }) #' @rdname windowOrderBy -#' @export setGeneric("windowOrderBy", function(col, ...) { standardGeneric("windowOrderBy") }) ###################### Expression Function Methods ########################## #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) #' @param x Column to compute on or a GroupedData object. #' @param ... additional argument(s) when \code{x} is a GroupedData object. #' @rdname avg -#' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("base64", function(x) { standardGeneric("base64") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("bin", function(x) { standardGeneric("bin") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("bround", function(x, ...) { standardGeneric("bround") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("ceil", function(x) { standardGeneric("ceil") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("collect_list", function(x) { standardGeneric("collect_list") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @rdname column -#' @export setGeneric("column", function(x) { standardGeneric("column") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("crc32", function(x) { standardGeneric("crc32") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("hash", function(x, ...) { standardGeneric("hash") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("current_date", function(x = "missing") { standardGeneric("current_date") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("current_timestamp", function(x = "missing") { standardGeneric("current_timestamp") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_add", function(y, x) { standardGeneric("date_add") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("date_trunc", function(format, x) { standardGeneric("date_trunc") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofweek", function(x) { standardGeneric("dayofweek") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("explode", function(x) { standardGeneric("explode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("hex", function(x) { standardGeneric("hex") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("hour", function(x) { standardGeneric("hour") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("initcap", function(x) { standardGeneric("initcap") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("input_file_name", function(x = "missing") { standardGeneric("input_file_name") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("instr", function(y, x) { standardGeneric("instr") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("isnan", function(x) { standardGeneric("isnan") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @rdname last -#' @export setGeneric("last", function(x, ...) { standardGeneric("last") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("last_day", function(x) { standardGeneric("last_day") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("least", function(x, ...) { standardGeneric("least") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("lit", function(x) { standardGeneric("lit") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("locate", function(substr, str, ...) { standardGeneric("locate") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("lower", function(x) { standardGeneric("lower") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("map_values", function(x) { standardGeneric("map_values") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("md5", function(x) { standardGeneric("md5") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("minute", function(x) { standardGeneric("minute") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("monotonically_increasing_id", function(x = "missing") { standardGeneric("monotonically_increasing_id") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("month", function(x) { standardGeneric("month") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) #' @rdname count -#' @export setGeneric("n", function(x) { standardGeneric("n") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("negate", function(x) { standardGeneric("negate") }) #' @rdname not -#' @export setGeneric("not", function(x) { standardGeneric("not") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("quarter", function(x) { standardGeneric("quarter") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("rand", function(seed) { standardGeneric("rand") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("randn", function(seed) { standardGeneric("randn") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("rank", function(x, ...) { standardGeneric("rank") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("regexp_replace", function(x, pattern, replacement) { standardGeneric("regexp_replace") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("rint", function(x) { standardGeneric("rint") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("rtrim", function(x, trimString) { standardGeneric("rtrim") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("second", function(x) { standardGeneric("second") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("sha1", function(x) { standardGeneric("sha1") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("signum", function(x) { standardGeneric("signum") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("size", function(x) { standardGeneric("size") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("soundex", function(x) { standardGeneric("soundex") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev", function(x) { standardGeneric("stddev") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("struct", function(x, ...) { standardGeneric("struct") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("to_json", function(x, ...) { standardGeneric("to_json") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("trim", function(x, trimString) { standardGeneric("trim") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("unhex", function(x) { standardGeneric("unhex") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("upper", function(x) { standardGeneric("upper") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var", function(x, y = NULL, na.rm = FALSE, use) { standardGeneric("var") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("variance", function(x) { standardGeneric("variance") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var_pop", function(x) { standardGeneric("var_pop") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("year", function(x) { standardGeneric("year") }) @@ -1598,142 +1287,110 @@ setGeneric("year", function(x) { standardGeneric("year") }) ###################### Spark.ML Methods ########################## #' @rdname fitted -#' @export setGeneric("fitted") # Do not carry stats::glm usage and param here, and do not document the generic -#' @export #' @noRd setGeneric("glm") #' @param object a fitted ML model object. #' @param ... additional argument(s) passed to the method. #' @rdname predict -#' @export setGeneric("predict", function(object, ...) { standardGeneric("predict") }) #' @rdname rbind -#' @export setGeneric("rbind", signature = "...") #' @rdname spark.als -#' @export setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) #' @rdname spark.bisectingKmeans -#' @export setGeneric("spark.bisectingKmeans", function(data, formula, ...) { standardGeneric("spark.bisectingKmeans") }) #' @rdname spark.gaussianMixture -#' @export setGeneric("spark.gaussianMixture", function(data, formula, ...) { standardGeneric("spark.gaussianMixture") }) #' @rdname spark.gbt -#' @export setGeneric("spark.gbt", function(data, formula, ...) { standardGeneric("spark.gbt") }) #' @rdname spark.glm -#' @export setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) #' @rdname spark.isoreg -#' @export setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) #' @rdname spark.kmeans -#' @export setGeneric("spark.kmeans", function(data, formula, ...) { standardGeneric("spark.kmeans") }) #' @rdname spark.kstest -#' @export setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) #' @rdname spark.lda -#' @export setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") }) #' @rdname spark.logit -#' @export setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") }) #' @rdname spark.mlp -#' @export setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.mlp") }) #' @rdname spark.naiveBayes -#' @export setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") }) #' @rdname spark.decisionTree -#' @export setGeneric("spark.decisionTree", function(data, formula, ...) { standardGeneric("spark.decisionTree") }) #' @rdname spark.randomForest -#' @export setGeneric("spark.randomForest", function(data, formula, ...) { standardGeneric("spark.randomForest") }) #' @rdname spark.survreg -#' @export setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") }) #' @rdname spark.svmLinear -#' @export setGeneric("spark.svmLinear", function(data, formula, ...) { standardGeneric("spark.svmLinear") }) #' @rdname spark.lda -#' @export setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark.posterior") }) #' @rdname spark.lda -#' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.fpGrowth", function(data, ...) { standardGeneric("spark.fpGrowth") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.freqItemsets", function(object) { standardGeneric("spark.freqItemsets") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.associationRules", function(object) { standardGeneric("spark.associationRules") }) #' @param object a fitted ML model object. #' @param path the directory where the model is saved. #' @param ... additional argument(s) passed to the method. #' @rdname write.ml -#' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) ###################### Streaming Methods ########################## #' @rdname awaitTermination -#' @export setGeneric("awaitTermination", function(x, timeout = NULL) { standardGeneric("awaitTermination") }) #' @rdname isActive -#' @export setGeneric("isActive", function(x) { standardGeneric("isActive") }) #' @rdname lastProgress -#' @export setGeneric("lastProgress", function(x) { standardGeneric("lastProgress") }) #' @rdname queryName -#' @export setGeneric("queryName", function(x) { standardGeneric("queryName") }) #' @rdname status -#' @export setGeneric("status", function(x) { standardGeneric("status") }) #' @rdname stopQuery -#' @export setGeneric("stopQuery", function(x) { standardGeneric("stopQuery") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 54ef9f07d6fae..f751b952f3915 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -30,7 +30,6 @@ setOldClass("jobj") #' @seealso groupBy #' #' @param sgd A Java object reference to the backing Scala GroupedData -#' @export #' @note GroupedData since 1.4.0 setClass("GroupedData", slots = list(sgd = "jobj")) @@ -48,7 +47,6 @@ groupedData <- function(sgd) { #' @rdname show #' @aliases show,GroupedData-method -#' @export #' @note show(GroupedData) since 1.4.0 setMethod("show", "GroupedData", function(object) { @@ -63,7 +61,6 @@ setMethod("show", "GroupedData", #' @return A SparkDataFrame. #' @rdname count #' @aliases count,GroupedData-method -#' @export #' @examples #' \dontrun{ #' count(groupBy(df, "name")) @@ -87,7 +84,6 @@ setMethod("count", #' @aliases agg,GroupedData-method #' @name agg #' @family agg_funcs -#' @export #' @examples #' \dontrun{ #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' @@ -150,7 +146,6 @@ methods <- c("avg", "max", "mean", "min", "sum") #' @rdname pivot #' @aliases pivot,GroupedData,character-method #' @name pivot -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame( @@ -202,7 +197,6 @@ createMethods() #' @rdname gapply #' @aliases gapply,GroupedData-method #' @name gapply -#' @export #' @note gapply(GroupedData) since 2.0.0 setMethod("gapply", signature(x = "GroupedData"), @@ -216,7 +210,6 @@ setMethod("gapply", #' @rdname gapplyCollect #' @aliases gapplyCollect,GroupedData-method #' @name gapplyCollect -#' @export #' @note gapplyCollect(GroupedData) since 2.0.0 setMethod("gapplyCollect", signature(x = "GroupedData"), diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index 04dc7562e5346..6d1edf6b6f3cf 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -58,7 +58,6 @@ #' @rdname install.spark #' @name install.spark #' @aliases install.spark -#' @export #' @examples #'\dontrun{ #' install.spark() diff --git a/R/pkg/R/jvm.R b/R/pkg/R/jvm.R index bb5c77544a3da..9a1b26b0fa3c5 100644 --- a/R/pkg/R/jvm.R +++ b/R/pkg/R/jvm.R @@ -35,7 +35,6 @@ #' @param ... parameters to pass to the Java method. #' @return the return value of the Java method. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. -#' @export #' @seealso \link{sparkR.callJStatic}, \link{sparkR.newJObject} #' @rdname sparkR.callJMethod #' @examples @@ -69,7 +68,6 @@ sparkR.callJMethod <- function(x, methodName, ...) { #' @param ... parameters to pass to the Java method. #' @return the return value of the Java method. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. -#' @export #' @seealso \link{sparkR.callJMethod}, \link{sparkR.newJObject} #' @rdname sparkR.callJStatic #' @examples @@ -100,7 +98,6 @@ sparkR.callJStatic <- function(x, methodName, ...) { #' @param ... arguments to be passed to the constructor. #' @return the object created. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. -#' @export #' @seealso \link{sparkR.callJMethod}, \link{sparkR.callJStatic} #' @rdname sparkR.newJObject #' @examples diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index f6e9b1357561b..2964fdeff0957 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -21,28 +21,24 @@ #' S4 class that represents an LinearSVCModel #' #' @param jobj a Java object reference to the backing Scala LinearSVCModel -#' @export #' @note LinearSVCModel since 2.2.0 setClass("LinearSVCModel", representation(jobj = "jobj")) #' S4 class that represents an LogisticRegressionModel #' #' @param jobj a Java object reference to the backing Scala LogisticRegressionModel -#' @export #' @note LogisticRegressionModel since 2.1.0 setClass("LogisticRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a MultilayerPerceptronClassificationModel #' #' @param jobj a Java object reference to the backing Scala MultilayerPerceptronClassifierWrapper -#' @export #' @note MultilayerPerceptronClassificationModel since 2.1.0 setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a NaiveBayesModel #' #' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper -#' @export #' @note NaiveBayesModel since 2.0.0 setClass("NaiveBayesModel", representation(jobj = "jobj")) @@ -82,7 +78,6 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @rdname spark.svmLinear #' @aliases spark.svmLinear,SparkDataFrame,formula-method #' @name spark.svmLinear -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -131,7 +126,6 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu #' @return \code{predict} returns the predicted values based on a LinearSVCModel. #' @rdname spark.svmLinear #' @aliases predict,LinearSVCModel,SparkDataFrame-method -#' @export #' @note predict(LinearSVCModel) since 2.2.0 setMethod("predict", signature(object = "LinearSVCModel"), function(object, newData) { @@ -146,7 +140,6 @@ setMethod("predict", signature(object = "LinearSVCModel"), #' \code{numClasses} (number of classes), \code{numFeatures} (number of features). #' @rdname spark.svmLinear #' @aliases summary,LinearSVCModel-method -#' @export #' @note summary(LinearSVCModel) since 2.2.0 setMethod("summary", signature(object = "LinearSVCModel"), function(object) { @@ -169,7 +162,6 @@ setMethod("summary", signature(object = "LinearSVCModel"), #' #' @rdname spark.svmLinear #' @aliases write.ml,LinearSVCModel,character-method -#' @export #' @note write.ml(LogisticRegression, character) since 2.2.0 setMethod("write.ml", signature(object = "LinearSVCModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -257,7 +249,6 @@ function(object, path, overwrite = FALSE) { #' @rdname spark.logit #' @aliases spark.logit,SparkDataFrame,formula-method #' @name spark.logit -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -374,7 +365,6 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") #' The list includes \code{coefficients} (coefficients matrix of the fitted model). #' @rdname spark.logit #' @aliases summary,LogisticRegressionModel-method -#' @export #' @note summary(LogisticRegressionModel) since 2.1.0 setMethod("summary", signature(object = "LogisticRegressionModel"), function(object) { @@ -402,7 +392,6 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' @return \code{predict} returns the predicted values based on an LogisticRegressionModel. #' @rdname spark.logit #' @aliases predict,LogisticRegressionModel,SparkDataFrame-method -#' @export #' @note predict(LogisticRegressionModel) since 2.1.0 setMethod("predict", signature(object = "LogisticRegressionModel"), function(object, newData) { @@ -417,7 +406,6 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), #' #' @rdname spark.logit #' @aliases write.ml,LogisticRegressionModel,character-method -#' @export #' @note write.ml(LogisticRegression, character) since 2.1.0 setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -458,7 +446,6 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char #' @aliases spark.mlp,SparkDataFrame,formula-method #' @name spark.mlp #' @seealso \link{read.ml} -#' @export #' @examples #' \dontrun{ #' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") @@ -517,7 +504,6 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), #' For \code{weights}, it is a numeric vector with length equal to the expected #' given the architecture (i.e., for 8-10-2 network, 112 connection weights). #' @rdname spark.mlp -#' @export #' @aliases summary,MultilayerPerceptronClassificationModel-method #' @note summary(MultilayerPerceptronClassificationModel) since 2.1.0 setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel"), @@ -538,7 +524,6 @@ setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel #' "prediction". #' @rdname spark.mlp #' @aliases predict,MultilayerPerceptronClassificationModel-method -#' @export #' @note predict(MultilayerPerceptronClassificationModel) since 2.1.0 setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel"), function(object, newData) { @@ -553,7 +538,6 @@ setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel #' #' @rdname spark.mlp #' @aliases write.ml,MultilayerPerceptronClassificationModel,character-method -#' @export #' @seealso \link{write.ml} #' @note write.ml(MultilayerPerceptronClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationModel", @@ -585,7 +569,6 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode #' @aliases spark.naiveBayes,SparkDataFrame,formula-method #' @name spark.naiveBayes #' @seealso e1071: \url{https://cran.r-project.org/package=e1071} -#' @export #' @examples #' \dontrun{ #' data <- as.data.frame(UCBAdmissions) @@ -624,7 +607,6 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form #' The list includes \code{apriori} (the label distribution) and #' \code{tables} (conditional probabilities given the target label). #' @rdname spark.naiveBayes -#' @export #' @note summary(NaiveBayesModel) since 2.0.0 setMethod("summary", signature(object = "NaiveBayesModel"), function(object) { @@ -648,7 +630,6 @@ setMethod("summary", signature(object = "NaiveBayesModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named #' "prediction". #' @rdname spark.naiveBayes -#' @export #' @note predict(NaiveBayesModel) since 2.0.0 setMethod("predict", signature(object = "NaiveBayesModel"), function(object, newData) { @@ -662,7 +643,6 @@ setMethod("predict", signature(object = "NaiveBayesModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.naiveBayes -#' @export #' @seealso \link{write.ml} #' @note write.ml(NaiveBayesModel, character) since 2.0.0 setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"), diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index a25bf81c6d977..900be685824da 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -20,28 +20,24 @@ #' S4 class that represents a BisectingKMeansModel #' #' @param jobj a Java object reference to the backing Scala BisectingKMeansModel -#' @export #' @note BisectingKMeansModel since 2.2.0 setClass("BisectingKMeansModel", representation(jobj = "jobj")) #' S4 class that represents a GaussianMixtureModel #' #' @param jobj a Java object reference to the backing Scala GaussianMixtureModel -#' @export #' @note GaussianMixtureModel since 2.1.0 setClass("GaussianMixtureModel", representation(jobj = "jobj")) #' S4 class that represents a KMeansModel #' #' @param jobj a Java object reference to the backing Scala KMeansModel -#' @export #' @note KMeansModel since 2.0.0 setClass("KMeansModel", representation(jobj = "jobj")) #' S4 class that represents an LDAModel #' #' @param jobj a Java object reference to the backing Scala LDAWrapper -#' @export #' @note LDAModel since 2.1.0 setClass("LDAModel", representation(jobj = "jobj")) @@ -68,7 +64,6 @@ setClass("LDAModel", representation(jobj = "jobj")) #' @rdname spark.bisectingKmeans #' @aliases spark.bisectingKmeans,SparkDataFrame,formula-method #' @name spark.bisectingKmeans -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -117,7 +112,6 @@ setMethod("spark.bisectingKmeans", signature(data = "SparkDataFrame", formula = #' (cluster centers of the transformed data; cluster is NULL if is.loaded is TRUE), #' and \code{is.loaded} (whether the model is loaded from a saved file). #' @rdname spark.bisectingKmeans -#' @export #' @note summary(BisectingKMeansModel) since 2.2.0 setMethod("summary", signature(object = "BisectingKMeansModel"), function(object) { @@ -144,7 +138,6 @@ setMethod("summary", signature(object = "BisectingKMeansModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns the predicted values based on a bisecting k-means model. #' @rdname spark.bisectingKmeans -#' @export #' @note predict(BisectingKMeansModel) since 2.2.0 setMethod("predict", signature(object = "BisectingKMeansModel"), function(object, newData) { @@ -160,7 +153,6 @@ setMethod("predict", signature(object = "BisectingKMeansModel"), #' or \code{"classes"} for assigned classes. #' @return \code{fitted} returns a SparkDataFrame containing fitted values. #' @rdname spark.bisectingKmeans -#' @export #' @note fitted since 2.2.0 setMethod("fitted", signature(object = "BisectingKMeansModel"), function(object, method = c("centers", "classes")) { @@ -181,7 +173,6 @@ setMethod("fitted", signature(object = "BisectingKMeansModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.bisectingKmeans -#' @export #' @note write.ml(BisectingKMeansModel, character) since 2.2.0 setMethod("write.ml", signature(object = "BisectingKMeansModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -208,7 +199,6 @@ setMethod("write.ml", signature(object = "BisectingKMeansModel", path = "charact #' @rdname spark.gaussianMixture #' @name spark.gaussianMixture #' @seealso mixtools: \url{https://cran.r-project.org/package=mixtools} -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -251,7 +241,6 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = #' \code{sigma} (sigma), \code{loglik} (loglik), and \code{posterior} (posterior). #' @aliases spark.gaussianMixture,SparkDataFrame,formula-method #' @rdname spark.gaussianMixture -#' @export #' @note summary(GaussianMixtureModel) since 2.1.0 setMethod("summary", signature(object = "GaussianMixtureModel"), function(object) { @@ -291,7 +280,6 @@ setMethod("summary", signature(object = "GaussianMixtureModel"), #' "prediction". #' @aliases predict,GaussianMixtureModel,SparkDataFrame-method #' @rdname spark.gaussianMixture -#' @export #' @note predict(GaussianMixtureModel) since 2.1.0 setMethod("predict", signature(object = "GaussianMixtureModel"), function(object, newData) { @@ -306,7 +294,6 @@ setMethod("predict", signature(object = "GaussianMixtureModel"), #' #' @aliases write.ml,GaussianMixtureModel,character-method #' @rdname spark.gaussianMixture -#' @export #' @note write.ml(GaussianMixtureModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -336,7 +323,6 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact #' @rdname spark.kmeans #' @aliases spark.kmeans,SparkDataFrame,formula-method #' @name spark.kmeans -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -385,7 +371,6 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula" #' (the actual number of cluster centers. When using initMode = "random", #' \code{clusterSize} may not equal to \code{k}). #' @rdname spark.kmeans -#' @export #' @note summary(KMeansModel) since 2.0.0 setMethod("summary", signature(object = "KMeansModel"), function(object) { @@ -413,7 +398,6 @@ setMethod("summary", signature(object = "KMeansModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns the predicted values based on a k-means model. #' @rdname spark.kmeans -#' @export #' @note predict(KMeansModel) since 2.0.0 setMethod("predict", signature(object = "KMeansModel"), function(object, newData) { @@ -431,7 +415,6 @@ setMethod("predict", signature(object = "KMeansModel"), #' @param ... additional argument(s) passed to the method. #' @return \code{fitted} returns a SparkDataFrame containing fitted values. #' @rdname fitted -#' @export #' @examples #' \dontrun{ #' model <- spark.kmeans(trainingData, ~ ., 2) @@ -458,7 +441,6 @@ setMethod("fitted", signature(object = "KMeansModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.kmeans -#' @export #' @note write.ml(KMeansModel, character) since 2.0.0 setMethod("write.ml", signature(object = "KMeansModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -496,7 +478,6 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"), #' @rdname spark.lda #' @aliases spark.lda,SparkDataFrame-method #' @seealso topicmodels: \url{https://cran.r-project.org/package=topicmodels} -#' @export #' @examples #' \dontrun{ #' text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm") @@ -558,7 +539,6 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), #' It is only for distributed LDA model (i.e., optimizer = "em")} #' @rdname spark.lda #' @aliases summary,LDAModel-method -#' @export #' @note summary(LDAModel) since 2.1.0 setMethod("summary", signature(object = "LDAModel"), function(object, maxTermsPerTopic) { @@ -596,7 +576,6 @@ setMethod("summary", signature(object = "LDAModel"), #' perplexity of the training data if missing argument "data". #' @rdname spark.lda #' @aliases spark.perplexity,LDAModel-method -#' @export #' @note spark.perplexity(LDAModel) since 2.1.0 setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"), function(object, data) { @@ -611,7 +590,6 @@ setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFr #' vectors named "topicDistribution". #' @rdname spark.lda #' @aliases spark.posterior,LDAModel,SparkDataFrame-method -#' @export #' @note spark.posterior(LDAModel) since 2.1.0 setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"), function(object, newData) { @@ -626,7 +604,6 @@ setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkData #' #' @rdname spark.lda #' @aliases write.ml,LDAModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(LDAModel, character) since 2.1.0 setMethod("write.ml", signature(object = "LDAModel", path = "character"), diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R index dfcb45a1b66c9..e2394906d8012 100644 --- a/R/pkg/R/mllib_fpm.R +++ b/R/pkg/R/mllib_fpm.R @@ -20,7 +20,6 @@ #' S4 class that represents a FPGrowthModel #' #' @param jobj a Java object reference to the backing Scala FPGrowthModel -#' @export #' @note FPGrowthModel since 2.2.0 setClass("FPGrowthModel", slots = list(jobj = "jobj")) @@ -45,7 +44,6 @@ setClass("FPGrowthModel", slots = list(jobj = "jobj")) #' @rdname spark.fpGrowth #' @name spark.fpGrowth #' @aliases spark.fpGrowth,SparkDataFrame-method -#' @export #' @examples #' \dontrun{ #' raw_data <- read.df( @@ -109,7 +107,6 @@ setMethod("spark.fpGrowth", signature(data = "SparkDataFrame"), #' and \code{freq} (frequency of the itemset). #' @rdname spark.fpGrowth #' @aliases freqItemsets,FPGrowthModel-method -#' @export #' @note spark.freqItemsets(FPGrowthModel) since 2.2.0 setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), function(object) { @@ -125,7 +122,6 @@ setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), #' and \code{condfidence} (confidence). #' @rdname spark.fpGrowth #' @aliases associationRules,FPGrowthModel-method -#' @export #' @note spark.associationRules(FPGrowthModel) since 2.2.0 setMethod("spark.associationRules", signature(object = "FPGrowthModel"), function(object) { @@ -138,7 +134,6 @@ setMethod("spark.associationRules", signature(object = "FPGrowthModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.fpGrowth #' @aliases predict,FPGrowthModel-method -#' @export #' @note predict(FPGrowthModel) since 2.2.0 setMethod("predict", signature(object = "FPGrowthModel"), function(object, newData) { @@ -153,7 +148,6 @@ setMethod("predict", signature(object = "FPGrowthModel"), #' if the output path exists. #' @rdname spark.fpGrowth #' @aliases write.ml,FPGrowthModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(FPGrowthModel, character) since 2.2.0 setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"), diff --git a/R/pkg/R/mllib_recommendation.R b/R/pkg/R/mllib_recommendation.R index 5441c4a4022a9..9a77b07462585 100644 --- a/R/pkg/R/mllib_recommendation.R +++ b/R/pkg/R/mllib_recommendation.R @@ -20,7 +20,6 @@ #' S4 class that represents an ALSModel #' #' @param jobj a Java object reference to the backing Scala ALSWrapper -#' @export #' @note ALSModel since 2.1.0 setClass("ALSModel", representation(jobj = "jobj")) @@ -55,7 +54,6 @@ setClass("ALSModel", representation(jobj = "jobj")) #' @rdname spark.als #' @aliases spark.als,SparkDataFrame-method #' @name spark.als -#' @export #' @examples #' \dontrun{ #' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), @@ -118,7 +116,6 @@ setMethod("spark.als", signature(data = "SparkDataFrame"), #' and \code{rank} (rank of the matrix factorization model). #' @rdname spark.als #' @aliases summary,ALSModel-method -#' @export #' @note summary(ALSModel) since 2.1.0 setMethod("summary", signature(object = "ALSModel"), function(object) { @@ -139,7 +136,6 @@ setMethod("summary", signature(object = "ALSModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.als #' @aliases predict,ALSModel-method -#' @export #' @note predict(ALSModel) since 2.1.0 setMethod("predict", signature(object = "ALSModel"), function(object, newData) { @@ -155,7 +151,6 @@ setMethod("predict", signature(object = "ALSModel"), #' #' @rdname spark.als #' @aliases write.ml,ALSModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(ALSModel, character) since 2.1.0 setMethod("write.ml", signature(object = "ALSModel", path = "character"), diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index 545be5e1d89f0..95c1a29905197 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -21,21 +21,18 @@ #' S4 class that represents a AFTSurvivalRegressionModel #' #' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper -#' @export #' @note AFTSurvivalRegressionModel since 2.0.0 setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a generalized linear model #' #' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper -#' @export #' @note GeneralizedLinearRegressionModel since 2.0.0 setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj")) #' S4 class that represents an IsotonicRegressionModel #' #' @param jobj a Java object reference to the backing Scala IsotonicRegressionModel -#' @export #' @note IsotonicRegressionModel since 2.1.0 setClass("IsotonicRegressionModel", representation(jobj = "jobj")) @@ -85,7 +82,6 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @return \code{spark.glm} returns a fitted generalized linear model. #' @rdname spark.glm #' @name spark.glm -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -211,7 +207,6 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @aliases glm -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -244,7 +239,6 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat #' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in #' the data, the coefficients matrix only provides coefficients. #' @rdname spark.glm -#' @export #' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), function(object) { @@ -290,7 +284,6 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), #' @rdname spark.glm #' @param x summary object of fitted generalized linear model returned by \code{summary} function. -#' @export #' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0 print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { if (x$is.loaded) { @@ -324,7 +317,6 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { #' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named #' "prediction". #' @rdname spark.glm -#' @export #' @note predict(GeneralizedLinearRegressionModel) since 1.5.0 setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), function(object, newData) { @@ -338,7 +330,6 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.glm -#' @export #' @note write.ml(GeneralizedLinearRegressionModel, character) since 2.0.0 setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -363,7 +354,6 @@ setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", pat #' @rdname spark.isoreg #' @aliases spark.isoreg,SparkDataFrame,formula-method #' @name spark.isoreg -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -412,7 +402,6 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula" #' and \code{predictions} (predictions associated with the boundaries at the same index). #' @rdname spark.isoreg #' @aliases summary,IsotonicRegressionModel-method -#' @export #' @note summary(IsotonicRegressionModel) since 2.1.0 setMethod("summary", signature(object = "IsotonicRegressionModel"), function(object) { @@ -429,7 +418,6 @@ setMethod("summary", signature(object = "IsotonicRegressionModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.isoreg #' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method -#' @export #' @note predict(IsotonicRegressionModel) since 2.1.0 setMethod("predict", signature(object = "IsotonicRegressionModel"), function(object, newData) { @@ -444,7 +432,6 @@ setMethod("predict", signature(object = "IsotonicRegressionModel"), #' #' @rdname spark.isoreg #' @aliases write.ml,IsotonicRegressionModel,character-method -#' @export #' @note write.ml(IsotonicRegression, character) since 2.1.0 setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -477,7 +464,6 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @return \code{spark.survreg} returns a fitted AFT survival regression model. #' @rdname spark.survreg #' @seealso survival: \url{https://cran.r-project.org/package=survival} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(ovarian) @@ -517,7 +503,6 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' The list includes the model's \code{coefficients} (features, coefficients, #' intercept and log(scale)). #' @rdname spark.survreg -#' @export #' @note summary(AFTSurvivalRegressionModel) since 2.0.0 setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), function(object) { @@ -537,7 +522,6 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values #' on the original scale of the data (mean predicted value at scale = 1.0). #' @rdname spark.survreg -#' @export #' @note predict(AFTSurvivalRegressionModel) since 2.0.0 setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), function(object, newData) { @@ -550,7 +534,6 @@ setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), #' @param overwrite overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @rdname spark.survreg -#' @export #' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0 #' @seealso \link{write.ml} setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"), diff --git a/R/pkg/R/mllib_stat.R b/R/pkg/R/mllib_stat.R index 3e013f1d45e38..f8c3329359961 100644 --- a/R/pkg/R/mllib_stat.R +++ b/R/pkg/R/mllib_stat.R @@ -20,7 +20,6 @@ #' S4 class that represents an KSTest #' #' @param jobj a Java object reference to the backing Scala KSTestWrapper -#' @export #' @note KSTest since 2.1.0 setClass("KSTest", representation(jobj = "jobj")) @@ -52,7 +51,6 @@ setClass("KSTest", representation(jobj = "jobj")) #' @name spark.kstest #' @seealso \href{http://spark.apache.org/docs/latest/mllib-statistics.html#hypothesis-testing}{ #' MLlib: Hypothesis Testing} -#' @export #' @examples #' \dontrun{ #' data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25)) @@ -94,7 +92,6 @@ setMethod("spark.kstest", signature(data = "SparkDataFrame"), #' parameters tested against) and \code{degreesOfFreedom} (degrees of freedom of the test). #' @rdname spark.kstest #' @aliases summary,KSTest-method -#' @export #' @note summary(KSTest) since 2.1.0 setMethod("summary", signature(object = "KSTest"), function(object) { @@ -117,7 +114,6 @@ setMethod("summary", signature(object = "KSTest"), #' @rdname spark.kstest #' @param x summary object of KSTest returned by \code{summary}. -#' @export #' @note print.summary.KSTest since 2.1.0 print.summary.KSTest <- function(x, ...) { jobj <- x$jobj diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 4e5ddf22ee16d..6769be038efa9 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -20,42 +20,36 @@ #' S4 class that represents a GBTRegressionModel #' #' @param jobj a Java object reference to the backing Scala GBTRegressionModel -#' @export #' @note GBTRegressionModel since 2.1.0 setClass("GBTRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a GBTClassificationModel #' #' @param jobj a Java object reference to the backing Scala GBTClassificationModel -#' @export #' @note GBTClassificationModel since 2.1.0 setClass("GBTClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a RandomForestRegressionModel #' #' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel -#' @export #' @note RandomForestRegressionModel since 2.1.0 setClass("RandomForestRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a RandomForestClassificationModel #' #' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel -#' @export #' @note RandomForestClassificationModel since 2.1.0 setClass("RandomForestClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a DecisionTreeRegressionModel #' #' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel -#' @export #' @note DecisionTreeRegressionModel since 2.3.0 setClass("DecisionTreeRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a DecisionTreeClassificationModel #' #' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel -#' @export #' @note DecisionTreeClassificationModel since 2.3.0 setClass("DecisionTreeClassificationModel", representation(jobj = "jobj")) @@ -179,7 +173,6 @@ print.summary.decisionTree <- function(x) { #' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. #' @rdname spark.gbt #' @name spark.gbt -#' @export #' @examples #' \dontrun{ #' # fit a Gradient Boosted Tree Regression Model @@ -261,7 +254,6 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), #' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.gbt #' @aliases summary,GBTRegressionModel-method -#' @export #' @note summary(GBTRegressionModel) since 2.1.0 setMethod("summary", signature(object = "GBTRegressionModel"), function(object) { @@ -275,7 +267,6 @@ setMethod("summary", signature(object = "GBTRegressionModel"), #' @param x summary object of Gradient Boosted Tree regression model or classification model #' returned by \code{summary}. #' @rdname spark.gbt -#' @export #' @note print.summary.GBTRegressionModel since 2.1.0 print.summary.GBTRegressionModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -285,7 +276,6 @@ print.summary.GBTRegressionModel <- function(x, ...) { #' @rdname spark.gbt #' @aliases summary,GBTClassificationModel-method -#' @export #' @note summary(GBTClassificationModel) since 2.1.0 setMethod("summary", signature(object = "GBTClassificationModel"), function(object) { @@ -297,7 +287,6 @@ setMethod("summary", signature(object = "GBTClassificationModel"), # Prints the summary of Gradient Boosted Tree Classification Model #' @rdname spark.gbt -#' @export #' @note print.summary.GBTClassificationModel since 2.1.0 print.summary.GBTClassificationModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -310,7 +299,6 @@ print.summary.GBTClassificationModel <- function(x, ...) { #' "prediction". #' @rdname spark.gbt #' @aliases predict,GBTRegressionModel-method -#' @export #' @note predict(GBTRegressionModel) since 2.1.0 setMethod("predict", signature(object = "GBTRegressionModel"), function(object, newData) { @@ -319,7 +307,6 @@ setMethod("predict", signature(object = "GBTRegressionModel"), #' @rdname spark.gbt #' @aliases predict,GBTClassificationModel-method -#' @export #' @note predict(GBTClassificationModel) since 2.1.0 setMethod("predict", signature(object = "GBTClassificationModel"), function(object, newData) { @@ -334,7 +321,6 @@ setMethod("predict", signature(object = "GBTClassificationModel"), #' which means throw exception if the output path exists. #' @aliases write.ml,GBTRegressionModel,character-method #' @rdname spark.gbt -#' @export #' @note write.ml(GBTRegressionModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -343,7 +329,6 @@ setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character #' @aliases write.ml,GBTClassificationModel,character-method #' @rdname spark.gbt -#' @export #' @note write.ml(GBTClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -402,7 +387,6 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' @return \code{spark.randomForest} returns a fitted Random Forest model. #' @rdname spark.randomForest #' @name spark.randomForest -#' @export #' @examples #' \dontrun{ #' # fit a Random Forest Regression Model @@ -480,7 +464,6 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo #' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.randomForest #' @aliases summary,RandomForestRegressionModel-method -#' @export #' @note summary(RandomForestRegressionModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestRegressionModel"), function(object) { @@ -494,7 +477,6 @@ setMethod("summary", signature(object = "RandomForestRegressionModel"), #' @param x summary object of Random Forest regression model or classification model #' returned by \code{summary}. #' @rdname spark.randomForest -#' @export #' @note print.summary.RandomForestRegressionModel since 2.1.0 print.summary.RandomForestRegressionModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -504,7 +486,6 @@ print.summary.RandomForestRegressionModel <- function(x, ...) { #' @rdname spark.randomForest #' @aliases summary,RandomForestClassificationModel-method -#' @export #' @note summary(RandomForestClassificationModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestClassificationModel"), function(object) { @@ -516,7 +497,6 @@ setMethod("summary", signature(object = "RandomForestClassificationModel"), # Prints the summary of Random Forest Classification Model #' @rdname spark.randomForest -#' @export #' @note print.summary.RandomForestClassificationModel since 2.1.0 print.summary.RandomForestClassificationModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -529,7 +509,6 @@ print.summary.RandomForestClassificationModel <- function(x, ...) { #' "prediction". #' @rdname spark.randomForest #' @aliases predict,RandomForestRegressionModel-method -#' @export #' @note predict(RandomForestRegressionModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestRegressionModel"), function(object, newData) { @@ -538,7 +517,6 @@ setMethod("predict", signature(object = "RandomForestRegressionModel"), #' @rdname spark.randomForest #' @aliases predict,RandomForestClassificationModel-method -#' @export #' @note predict(RandomForestClassificationModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestClassificationModel"), function(object, newData) { @@ -554,7 +532,6 @@ setMethod("predict", signature(object = "RandomForestClassificationModel"), #' #' @aliases write.ml,RandomForestRegressionModel,character-method #' @rdname spark.randomForest -#' @export #' @note write.ml(RandomForestRegressionModel, character) since 2.1.0 setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -563,7 +540,6 @@ setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = " #' @aliases write.ml,RandomForestClassificationModel,character-method #' @rdname spark.randomForest -#' @export #' @note write.ml(RandomForestClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -617,7 +593,6 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' @return \code{spark.decisionTree} returns a fitted Decision Tree model. #' @rdname spark.decisionTree #' @name spark.decisionTree -#' @export #' @examples #' \dontrun{ #' # fit a Decision Tree Regression Model @@ -690,7 +665,6 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo #' trees). #' @rdname spark.decisionTree #' @aliases summary,DecisionTreeRegressionModel-method -#' @export #' @note summary(DecisionTreeRegressionModel) since 2.3.0 setMethod("summary", signature(object = "DecisionTreeRegressionModel"), function(object) { @@ -704,7 +678,6 @@ setMethod("summary", signature(object = "DecisionTreeRegressionModel"), #' @param x summary object of Decision Tree regression model or classification model #' returned by \code{summary}. #' @rdname spark.decisionTree -#' @export #' @note print.summary.DecisionTreeRegressionModel since 2.3.0 print.summary.DecisionTreeRegressionModel <- function(x, ...) { print.summary.decisionTree(x) @@ -714,7 +687,6 @@ print.summary.DecisionTreeRegressionModel <- function(x, ...) { #' @rdname spark.decisionTree #' @aliases summary,DecisionTreeClassificationModel-method -#' @export #' @note summary(DecisionTreeClassificationModel) since 2.3.0 setMethod("summary", signature(object = "DecisionTreeClassificationModel"), function(object) { @@ -726,7 +698,6 @@ setMethod("summary", signature(object = "DecisionTreeClassificationModel"), # Prints the summary of Decision Tree Classification Model #' @rdname spark.decisionTree -#' @export #' @note print.summary.DecisionTreeClassificationModel since 2.3.0 print.summary.DecisionTreeClassificationModel <- function(x, ...) { print.summary.decisionTree(x) @@ -739,7 +710,6 @@ print.summary.DecisionTreeClassificationModel <- function(x, ...) { #' "prediction". #' @rdname spark.decisionTree #' @aliases predict,DecisionTreeRegressionModel-method -#' @export #' @note predict(DecisionTreeRegressionModel) since 2.3.0 setMethod("predict", signature(object = "DecisionTreeRegressionModel"), function(object, newData) { @@ -748,7 +718,6 @@ setMethod("predict", signature(object = "DecisionTreeRegressionModel"), #' @rdname spark.decisionTree #' @aliases predict,DecisionTreeClassificationModel-method -#' @export #' @note predict(DecisionTreeClassificationModel) since 2.3.0 setMethod("predict", signature(object = "DecisionTreeClassificationModel"), function(object, newData) { @@ -764,7 +733,6 @@ setMethod("predict", signature(object = "DecisionTreeClassificationModel"), #' #' @aliases write.ml,DecisionTreeRegressionModel,character-method #' @rdname spark.decisionTree -#' @export #' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0 setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -773,7 +741,6 @@ setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = " #' @aliases write.ml,DecisionTreeClassificationModel,character-method #' @rdname spark.decisionTree -#' @export #' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0 setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index a53c92c2c4815..7d04bffcba3a4 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -31,7 +31,6 @@ #' MLlib model below. #' @rdname write.ml #' @name write.ml -#' @export #' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, #' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, #' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, @@ -48,7 +47,6 @@ NULL #' MLlib model below. #' @rdname predict #' @name predict -#' @export #' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, #' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, #' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, @@ -75,7 +73,6 @@ predict_internal <- function(object, newData) { #' @return A fitted MLlib model. #' @rdname read.ml #' @name read.ml -#' @export #' @seealso \link{write.ml} #' @examples #' \dontrun{ diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 65f418740c643..9831fc3cc6d01 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -29,7 +29,6 @@ #' @param ... additional structField objects #' @return a structType object #' @rdname structType -#' @export #' @examples #'\dontrun{ #' schema <- structType(structField("a", "integer"), structField("c", "string"), @@ -49,7 +48,6 @@ structType <- function(x, ...) { #' @rdname structType #' @method structType jobj -#' @export structType.jobj <- function(x, ...) { obj <- structure(list(), class = "structType") obj$jobj <- x @@ -59,7 +57,6 @@ structType.jobj <- function(x, ...) { #' @rdname structType #' @method structType structField -#' @export structType.structField <- function(x, ...) { fields <- list(x, ...) if (!all(sapply(fields, inherits, "structField"))) { @@ -76,7 +73,6 @@ structType.structField <- function(x, ...) { #' @rdname structType #' @method structType character -#' @export structType.character <- function(x, ...) { if (!is.character(x)) { stop("schema must be a DDL-formatted string.") @@ -119,7 +115,6 @@ print.structType <- function(x, ...) { #' @param ... additional argument(s) passed to the method. #' @return A structField object. #' @rdname structField -#' @export #' @examples #'\dontrun{ #' field1 <- structField("a", "integer") @@ -137,7 +132,6 @@ structField <- function(x, ...) { #' @rdname structField #' @method structField jobj -#' @export structField.jobj <- function(x, ...) { obj <- structure(list(), class = "structField") obj$jobj <- x @@ -212,7 +206,6 @@ checkType <- function(type) { #' @param type The data type of the field #' @param nullable A logical vector indicating whether or not the field is nullable #' @rdname structField -#' @export structField.character <- function(x, type, nullable = TRUE, ...) { if (class(x) != "character") { stop("Field name must be a string.") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 965471f3b07a0..a480ac606f10d 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -35,7 +35,6 @@ connExists <- function(env) { #' Also terminates the backend this R session is connected to. #' @rdname sparkR.session.stop #' @name sparkR.session.stop -#' @export #' @note sparkR.session.stop since 2.0.0 sparkR.session.stop <- function() { env <- .sparkREnv @@ -84,7 +83,6 @@ sparkR.session.stop <- function() { #' @rdname sparkR.session.stop #' @name sparkR.stop -#' @export #' @note sparkR.stop since 1.4.0 sparkR.stop <- function() { sparkR.session.stop() @@ -103,7 +101,6 @@ sparkR.stop <- function() { #' @param sparkPackages Character vector of package coordinates #' @seealso \link{sparkR.session} #' @rdname sparkR.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init("local[2]", "SparkR", "/home/spark") @@ -270,7 +267,6 @@ sparkR.sparkContext <- function( #' @param jsc The existing JavaSparkContext created with SparkR.init() #' @seealso \link{sparkR.session} #' @rdname sparkRSQL.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() @@ -298,7 +294,6 @@ sparkRSQL.init <- function(jsc = NULL) { #' @param jsc The existing JavaSparkContext created with SparkR.init() #' @seealso \link{sparkR.session} #' @rdname sparkRHive.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() @@ -347,7 +342,6 @@ sparkRHive.init <- function(jsc = NULL) { #' @param enableHiveSupport enable support for Hive, fallback if not built with Hive support; once #' set, this cannot be turned off on an existing session #' @param ... named Spark properties passed to the method. -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -442,7 +436,6 @@ sparkR.session <- function( #' @return the SparkUI URL, or NA if it is disabled, or not started. #' @rdname sparkR.uiWebUrl #' @name sparkR.uiWebUrl -#' @export #' @examples #'\dontrun{ #' sparkR.session() diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index c8af798830b30..497f18c763048 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -37,7 +37,6 @@ setOldClass("jobj") #' @name crosstab #' @aliases crosstab,SparkDataFrame,character,character-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -63,7 +62,6 @@ setMethod("crosstab", #' @rdname cov #' @aliases cov,SparkDataFrame-method #' @family stat functions -#' @export #' @examples #' #' \dontrun{ @@ -92,7 +90,6 @@ setMethod("cov", #' @name corr #' @aliases corr,SparkDataFrame-method #' @family stat functions -#' @export #' @examples #' #' \dontrun{ @@ -124,7 +121,6 @@ setMethod("corr", #' @name freqItems #' @aliases freqItems,SparkDataFrame,character-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -168,7 +164,6 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' @name approxQuantile #' @aliases approxQuantile,SparkDataFrame,character,numeric,numeric-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -205,7 +200,6 @@ setMethod("approxQuantile", #' @aliases sampleBy,SparkDataFrame,character,list,numeric-method #' @name sampleBy #' @family stat functions -#' @export #' @examples #'\dontrun{ #' df <- read.json("/path/to/file.json") diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R index 8390bd5e6de72..fc83463f72cd4 100644 --- a/R/pkg/R/streaming.R +++ b/R/pkg/R/streaming.R @@ -28,7 +28,6 @@ NULL #' @seealso \link{read.stream} #' #' @param ssq A Java object reference to the backing Scala StreamingQuery -#' @export #' @note StreamingQuery since 2.2.0 #' @note experimental setClass("StreamingQuery", @@ -45,7 +44,6 @@ streamingQuery <- function(ssq) { } #' @rdname show -#' @export #' @note show(StreamingQuery) since 2.2.0 setMethod("show", "StreamingQuery", function(object) { @@ -70,7 +68,6 @@ setMethod("show", "StreamingQuery", #' @aliases queryName,StreamingQuery-method #' @family StreamingQuery methods #' @seealso \link{write.stream} -#' @export #' @examples #' \dontrun{ queryName(sq) } #' @note queryName(StreamingQuery) since 2.2.0 @@ -85,7 +82,6 @@ setMethod("queryName", #' @name explain #' @aliases explain,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ explain(sq) } #' @note explain(StreamingQuery) since 2.2.0 @@ -104,7 +100,6 @@ setMethod("explain", #' @name lastProgress #' @aliases lastProgress,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ lastProgress(sq) } #' @note lastProgress(StreamingQuery) since 2.2.0 @@ -129,7 +124,6 @@ setMethod("lastProgress", #' @name status #' @aliases status,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ status(sq) } #' @note status(StreamingQuery) since 2.2.0 @@ -150,7 +144,6 @@ setMethod("status", #' @name isActive #' @aliases isActive,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ isActive(sq) } #' @note isActive(StreamingQuery) since 2.2.0 @@ -177,7 +170,6 @@ setMethod("isActive", #' @name awaitTermination #' @aliases awaitTermination,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ awaitTermination(sq, 10000) } #' @note awaitTermination(StreamingQuery) since 2.2.0 @@ -202,7 +194,6 @@ setMethod("awaitTermination", #' @name stopQuery #' @aliases stopQuery,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ stopQuery(sq) } #' @note stopQuery(StreamingQuery) since 2.2.0 diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 164cd6d01a347..f1b5ecaa017df 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -108,7 +108,6 @@ isRDD <- function(name, env) { #' #' @param key the object to be hashed #' @return the hash code as an integer -#' @export #' @examples #'\dontrun{ #' hashCode(1L) # 1 diff --git a/R/pkg/R/window.R b/R/pkg/R/window.R index 0799d841e5dc9..396b27bee80c6 100644 --- a/R/pkg/R/window.R +++ b/R/pkg/R/window.R @@ -29,7 +29,6 @@ #' @rdname windowPartitionBy #' @name windowPartitionBy #' @aliases windowPartitionBy,character-method -#' @export #' @examples #' \dontrun{ #' ws <- orderBy(windowPartitionBy("key1", "key2"), "key3") @@ -52,7 +51,6 @@ setMethod("windowPartitionBy", #' @rdname windowPartitionBy #' @name windowPartitionBy #' @aliases windowPartitionBy,Column-method -#' @export #' @note windowPartitionBy(Column) since 2.0.0 setMethod("windowPartitionBy", signature(col = "Column"), @@ -78,7 +76,6 @@ setMethod("windowPartitionBy", #' @rdname windowOrderBy #' @name windowOrderBy #' @aliases windowOrderBy,character-method -#' @export #' @examples #' \dontrun{ #' ws <- windowOrderBy("key1", "key2") @@ -101,7 +98,6 @@ setMethod("windowOrderBy", #' @rdname windowOrderBy #' @name windowOrderBy #' @aliases windowOrderBy,Column-method -#' @export #' @note windowOrderBy(Column) since 2.0.0 setMethod("windowOrderBy", signature(col = "Column"), From 98a5c0a35f0a24730f5074522939acf57ef95422 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 5 Mar 2018 10:50:00 -0800 Subject: [PATCH 429/774] [SPARK-22882][ML][TESTS] ML test for structured streaming: ml.classification ## What changes were proposed in this pull request? adding Structured Streaming tests for all Models/Transformers in spark.ml.classification ## How was this patch tested? N/A Author: WeichenXu Closes #20121 from WeichenXu123/ml_stream_test_classification. --- .../DecisionTreeClassifierSuite.scala | 29 +-- .../classification/GBTClassifierSuite.scala | 77 ++---- .../ml/classification/LinearSVCSuite.scala | 15 +- .../LogisticRegressionSuite.scala | 229 +++++++----------- .../MultilayerPerceptronClassifierSuite.scala | 44 ++-- .../ml/classification/NaiveBayesSuite.scala | 47 ++-- .../ml/classification/OneVsRestSuite.scala | 21 +- .../ProbabilisticClassifierSuite.scala | 29 +-- .../RandomForestClassifierSuite.scala | 16 +- 9 files changed, 202 insertions(+), 305 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 38b265d62611b..eeb0324187c5b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -23,15 +23,14 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, + DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -class DecisionTreeClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { import DecisionTreeClassifierSuite.compareAPIs import testImplicits._ @@ -251,20 +250,18 @@ class DecisionTreeClassifierSuite MLTestingUtils.checkCopyAndUids(dt, newTree) - val predictions = newTree.transform(newData) - .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) - .collect() - - predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => - assert(pred === rawPred.argmax, - s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") - val sum = rawPred.toArray.sum - assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, - "probability prediction mismatch") + testTransformer[(Vector, Double)](newData, newTree, + "prediction", "rawPrediction", "probability") { + case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, DecisionTreeClassificationModel](newTree, newData) + Vector, DecisionTreeClassificationModel](this, newTree, newData) } test("training with 1-category categorical feature") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 978f89c459f0a..092b4a01d5b0d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -26,13 +26,12 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.LogLoss -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.util.Utils @@ -40,8 +39,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ import GBTClassifierSuite.compareAPIs @@ -126,14 +124,15 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext // should predict all zeros binaryModel.setThresholds(Array(0.0, 1.0)) - val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect() - assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](df, binaryModel, "prediction") { + case Row(prediction: Double) => prediction === 0.0 + } // should predict all ones binaryModel.setThresholds(Array(1.0, 0.0)) - val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect() - assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0)) - + testTransformer[(Double, Vector)](df, binaryModel, "prediction") { + case Row(prediction: Double) => prediction === 1.0 + } val gbtBase = new GBTClassifier val model = gbtBase.fit(df) @@ -141,15 +140,18 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext // constant threshold scaling is the same as no thresholds binaryModel.setThresholds(Array(1.0, 1.0)) - val scaledPredictions = binaryModel.transform(df).select("prediction").collect() - assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => - scaled.getDouble(0) === base.getDouble(0) - }) + testTransformerByGlobalCheckFunc[(Double, Vector)](df, binaryModel, "prediction") { + scaledPredictions: Seq[Row] => + assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => + scaled.getDouble(0) === base.getDouble(0) + }) + } // force it to use the predict method model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1)) - val predictionsWithPredict = model.transform(df).select("prediction").collect() - assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](df, model, "prediction") { + case Row(prediction: Double) => prediction === 0.0 + } } test("GBTClassifier: Predictor, Classifier methods") { @@ -169,61 +171,30 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext val blas = BLAS.getInstance() val validationDataset = validationData.toDF(labelCol, featuresCol) - val results = gbtModel.transform(validationDataset) - // check that raw prediction is tree predictions dot tree weights - results.select(rawPredictionCol, featuresCol).collect().foreach { - case Row(raw: Vector, features: Vector) => + testTransformer[(Double, Vector)](validationDataset, gbtModel, + "rawPrediction", "features", "probability", "prediction") { + case Row(raw: Vector, features: Vector, prob: Vector, pred: Double) => assert(raw.size === 2) + // check that raw prediction is tree predictions dot tree weights val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction) val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1) assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps) - } - // Compare rawPrediction with probability - results.select(rawPredictionCol, probabilityCol).collect().foreach { - case Row(raw: Vector, prob: Vector) => - assert(raw.size === 2) + // Compare rawPrediction with probability assert(prob.size === 2) // Note: we should check other loss types for classification if they are added val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value)) assert(prob(0) ~== predFromRaw(0) relTol eps) assert(prob(1) ~== predFromRaw(1) relTol eps) assert(prob(0) + prob(1) ~== 1.0 absTol absEps) - } - // Compare prediction with probability - results.select(predictionCol, probabilityCol).collect().foreach { - case Row(pred: Double, prob: Vector) => + // Compare prediction with probability val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 assert(pred == predFromProb) } - // force it to use raw2prediction - gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("") - val resultsUsingRaw2Predict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use probability2prediction - gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol) - val resultsUsingProb2Predict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use predict - gbtModel.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, GBTClassificationModel](gbtModel, validationDataset) + Vector, GBTClassificationModel](this, gbtModel, validationDataset) } test("GBT parameter stepSize should be in interval (0, 1]") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index 41a5d22dd6283..a93825b8a812d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -21,20 +21,18 @@ import scala.util.Random import breeze.linalg.{DenseVector => BDV} -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.optim.aggregator.HingeAggregator import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.udf -class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LinearSVCSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -141,10 +139,11 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau threshold: Double, expected: Set[(Int, Double)]): Unit = { model.setThreshold(threshold) - val results = model.transform(df).select("id", "prediction").collect() - .map(r => (r.getInt(0), r.getDouble(1))) - .toSet - assert(results === expected, s"Failed for threshold = $threshold") + testTransformerByGlobalCheckFunc[(Int, Vector)](df, model, "id", "prediction") { + rows: Seq[Row] => + val results = rows.map(r => (r.getInt(0), r.getDouble(1))).toSet + assert(results === expected, s"Failed for threshold = $threshold") + } } def checkResults(threshold: Double, expected: Set[(Int, Double)]): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index a5f81a38face9..9987cbf6ba116 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -22,22 +22,20 @@ import scala.language.existentials import scala.util.Random import scala.util.control.Breaks._ -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix, Vector, Vectors} import org.apache.spark.ml.optim.aggregator.LogisticAggregator import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, lit, rand} import org.apache.spark.sql.types.LongType -class LogisticRegressionSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -332,15 +330,14 @@ class LogisticRegressionSuite val binaryModel = blr.fit(smallBinaryDataset) binaryModel.setThreshold(1.0) - val binaryZeroPredictions = - binaryModel.transform(smallBinaryDataset).select("prediction").collect() - assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } binaryModel.setThreshold(0.0) - val binaryOnePredictions = - binaryModel.transform(smallBinaryDataset).select("prediction").collect() - assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0)) - + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") { + row => assert(row.getDouble(0) === 1.0) + } val mlr = new LogisticRegression().setFamily("multinomial") val model = mlr.fit(smallMultinomialDataset) @@ -348,31 +345,36 @@ class LogisticRegressionSuite // should predict all zeros model.setThresholds(Array(1, 1000, 1000)) - val zeroPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(zeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } // should predict all ones model.setThresholds(Array(1000, 1, 1000)) - val onePredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(onePredictions.forall(_.getDouble(0) === 1.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 1.0) + } // should predict all twos model.setThresholds(Array(1000, 1000, 1)) - val twoPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(twoPredictions.forall(_.getDouble(0) === 2.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 2.0) + } // constant threshold scaling is the same as no thresholds model.setThresholds(Array(1000, 1000, 1000)) - val scaledPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => - scaled.getDouble(0) === base.getDouble(0) - }) + testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model, + "prediction") { scaledPredictions: Seq[Row] => + assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => + scaled.getDouble(0) === base.getDouble(0) + }) + } // force it to use the predict method model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1, 1)) - val predictionsWithPredict = - model.transform(smallMultinomialDataset).select("prediction").collect() - assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } } test("logistic regression doesn't fit intercept when fitIntercept is off") { @@ -403,21 +405,19 @@ class LogisticRegressionSuite // Modify model params, and check that the params worked. model.setThreshold(1.0) - val predAllZero = model.transform(smallBinaryDataset) - .select("prediction", "myProbability") - .collect() - .map { case Row(pred: Double, prob: Vector) => pred } - assert(predAllZero.forall(_ === 0), - s"With threshold=1.0, expected predictions to be all 0, but only" + - s" ${predAllZero.count(_ === 0)} of ${smallBinaryDataset.count()} were 0.") + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), + model, "prediction", "myProbability") { rows => + val predAllZero = rows.map(_.getDouble(0)) + assert(predAllZero.forall(_ === 0), + s"With threshold=1.0, expected predictions to be all 0, but only" + + s" ${predAllZero.count(_ === 0)} of ${smallBinaryDataset.count()} were 0.") + } // Call transform with params, and check that the params worked. - val predNotAllZero = - model.transform(smallBinaryDataset, model.threshold -> 0.0, - model.probabilityCol -> "myProb") - .select("prediction", "myProb") - .collect() - .map { case Row(pred: Double, prob: Vector) => pred } - assert(predNotAllZero.exists(_ !== 0.0)) + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), + model.copy(ParamMap(model.threshold -> 0.0, + model.probabilityCol -> "myProb")), "prediction", "myProb") { + rows => assert(rows.map(_.getDouble(0)).exists(_ !== 0.0)) + } // Call fit() with new params, and check as many params as we can. lr.setThresholds(Array(0.6, 0.4)) @@ -441,10 +441,10 @@ class LogisticRegressionSuite val numFeatures = smallMultinomialDataset.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val results = model.transform(smallMultinomialDataset) - // check that raw prediction is coefficients dot features + intercept - results.select("rawPrediction", "features").collect().foreach { - case Row(raw: Vector, features: Vector) => + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), + model, "rawPrediction", "features", "probability") { + case Row(raw: Vector, features: Vector, prob: Vector) => + // check that raw prediction is coefficients dot features + intercept assert(raw.size === 3) val margins = Array.tabulate(3) { k => var margin = 0.0 @@ -455,12 +455,7 @@ class LogisticRegressionSuite margin } assert(raw ~== Vectors.dense(margins) relTol eps) - } - - // Compare rawPrediction with probability - results.select("rawPrediction", "probability").collect().foreach { - case Row(raw: Vector, prob: Vector) => - assert(raw.size === 3) + // Compare rawPrediction with probability assert(prob.size === 3) val max = raw.toArray.max val subtract = if (max > 0) max else 0.0 @@ -472,39 +467,8 @@ class LogisticRegressionSuite assert(prob(2) ~== 1.0 - probFromRaw1 - probFromRaw0 relTol eps) } - // Compare prediction with probability - results.select("prediction", "probability").collect().foreach { - case Row(pred: Double, prob: Vector) => - val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 - assert(pred == predFromProb) - } - - // force it to use raw2prediction - model.setRawPredictionCol("rawPrediction").setProbabilityCol("") - val resultsUsingRaw2Predict = - model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use probability2prediction - model.setRawPredictionCol("").setProbabilityCol("probability") - val resultsUsingProb2Predict = - model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use predict - model.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, LogisticRegressionModel](model, smallMultinomialDataset) + Vector, LogisticRegressionModel](this, model, smallMultinomialDataset) } test("binary logistic regression: Predictor, Classifier methods") { @@ -517,51 +481,22 @@ class LogisticRegressionSuite val numFeatures = smallBinaryDataset.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val results = model.transform(smallBinaryDataset) - - // Compare rawPrediction with probability - results.select("rawPrediction", "probability").collect().foreach { - case Row(raw: Vector, prob: Vector) => + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), + model, "rawPrediction", "probability", "prediction") { + case Row(raw: Vector, prob: Vector, pred: Double) => + // Compare rawPrediction with probability assert(raw.size === 2) assert(prob.size === 2) val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1))) assert(prob(1) ~== probFromRaw1 relTol eps) assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps) - } - - // Compare prediction with probability - results.select("prediction", "probability").collect().foreach { - case Row(pred: Double, prob: Vector) => + // Compare prediction with probability val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 assert(pred == predFromProb) } - // force it to use raw2prediction - model.setRawPredictionCol("rawPrediction").setProbabilityCol("") - val resultsUsingRaw2Predict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use probability2prediction - model.setRawPredictionCol("").setProbabilityCol("probability") - val resultsUsingProb2Predict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use predict - model.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, LogisticRegressionModel](model, smallBinaryDataset) + Vector, LogisticRegressionModel](this, model, smallBinaryDataset) } test("coefficients and intercept methods") { @@ -616,19 +551,21 @@ class LogisticRegressionSuite LabeledPoint(1.0, Vectors.dense(0.0, 1000.0)), LabeledPoint(1.0, Vectors.dense(0.0, -1.0)) ).toDF() - val results = model.transform(overFlowData).select("rawPrediction", "probability").collect() - - // probabilities are correct when margins have to be adjusted - val raw1 = results(0).getAs[Vector](0) - val prob1 = results(0).getAs[Vector](1) - assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0)) - assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps) - - // probabilities are correct when margins don't have to be adjusted - val raw2 = results(1).getAs[Vector](0) - val prob2 = results(1).getAs[Vector](1) - assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0)) - assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps) + + testTransformerByGlobalCheckFunc[(Double, Vector)](overFlowData.toDF(), + model, "rawPrediction", "probability") { results: Seq[Row] => + // probabilities are correct when margins have to be adjusted + val raw1 = results(0).getAs[Vector](0) + val prob1 = results(0).getAs[Vector](1) + assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0)) + assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps) + + // probabilities are correct when margins don't have to be adjusted + val raw2 = results(1).getAs[Vector](0) + val prob2 = results(1).getAs[Vector](1) + assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0)) + assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps) + } } test("MultiClassSummarizer") { @@ -2567,10 +2504,13 @@ class LogisticRegressionSuite val model1 = lr.fit(smallBinaryDataset) val lr2 = new LogisticRegression().setInitialModel(model1).setMaxIter(5).setFamily("binomial") val model2 = lr2.fit(smallBinaryDataset) - val predictions1 = model1.transform(smallBinaryDataset).select("prediction").collect() - val predictions2 = model2.transform(smallBinaryDataset).select("prediction").collect() - predictions1.zip(predictions2).foreach { case (Row(p1: Double), Row(p2: Double)) => - assert(p1 === p2) + val binaryExpected = model1.transform(smallBinaryDataset).select("prediction").collect() + .map(_.getDouble(0)) + for (model <- Seq(model1, model2)) { + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), model, + "prediction") { rows: Seq[Row] => + rows.map(_.getDouble(0)).toArray === binaryExpected + } } assert(model2.summary.totalIterations === 1) @@ -2579,10 +2519,13 @@ class LogisticRegressionSuite val lr4 = new LogisticRegression() .setInitialModel(model3).setMaxIter(5).setFamily("multinomial") val model4 = lr4.fit(smallMultinomialDataset) - val predictions3 = model3.transform(smallMultinomialDataset).select("prediction").collect() - val predictions4 = model4.transform(smallMultinomialDataset).select("prediction").collect() - predictions3.zip(predictions4).foreach { case (Row(p1: Double), Row(p2: Double)) => - assert(p1 === p2) + val multinomialExpected = model3.transform(smallMultinomialDataset).select("prediction") + .collect().map(_.getDouble(0)) + for (model <- Seq(model3, model4)) { + testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model, + "prediction") { rows: Seq[Row] => + rows.map(_.getDouble(0)).toArray === multinomialExpected + } } assert(model4.summary.totalIterations === 1) } @@ -2638,8 +2581,8 @@ class LogisticRegressionSuite LabeledPoint(4.0, Vectors.dense(2.0))).toDF() val mlr = new LogisticRegression().setFamily("multinomial") val model = mlr.fit(constantData) - val results = model.transform(constantData) - results.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantData, model, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity))) assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0))) @@ -2653,8 +2596,8 @@ class LogisticRegressionSuite LabeledPoint(0.0, Vectors.dense(1.0)), LabeledPoint(0.0, Vectors.dense(2.0))).toDF() val modelZeroLabel = mlr.setFitIntercept(false).fit(constantZeroData) - val resultsZero = modelZeroLabel.transform(constantZeroData) - resultsZero.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantZeroData, modelZeroLabel, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(prob === Vectors.dense(Array(1.0))) assert(pred === 0.0) @@ -2666,8 +2609,8 @@ class LogisticRegressionSuite val constantDataWithMetadata = constantData .select(constantData("label").as("label", labelMeta), constantData("features")) val modelWithMetadata = mlr.setFitIntercept(true).fit(constantDataWithMetadata) - val resultsWithMetadata = modelWithMetadata.transform(constantDataWithMetadata) - resultsWithMetadata.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantDataWithMetadata, modelWithMetadata, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity, 0.0))) assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index d3141ec708560..daa58a56896d7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -17,22 +17,17 @@ package org.apache.spark.ml.classification -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -import org.apache.spark.sql.functions._ -class MultilayerPerceptronClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -75,11 +70,9 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(dataset) - val result = model.transform(dataset) MLTestingUtils.checkCopyAndUids(trainer, model) - val predictionAndLabels = result.select("prediction", "label").collect() - predictionAndLabels.foreach { case Row(p: Double, l: Double) => - assert(p == l) + testTransformer[(Vector, Double)](dataset.toDF(), model, "prediction", "label") { + case Row(p: Double, l: Double) => assert(p == l) } } @@ -99,13 +92,12 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(strongDataset) - val result = model.transform(strongDataset) - result.select("probability", "expectedProbability").collect().foreach { - case Row(p: Vector, e: Vector) => - assert(p ~== e absTol 1e-3) + testTransformer[(Vector, Double, Vector)](strongDataset.toDF(), model, + "probability", "expectedProbability") { + case Row(p: Vector, e: Vector) => assert(p ~== e absTol 1e-3) } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, MultilayerPerceptronClassificationModel](model, strongDataset) + Vector, MultilayerPerceptronClassificationModel](this, model, strongDataset) } test("test model probability") { @@ -118,11 +110,10 @@ class MultilayerPerceptronClassifierSuite .setSolver("l-bfgs") val model = trainer.fit(dataset) model.setProbabilityCol("probability") - val result = model.transform(dataset) - val features2prob = udf { features: Vector => model.mlpModel.predict(features) } - result.select(features2prob(col("features")), col("probability")).collect().foreach { - case Row(p1: Vector, p2: Vector) => - assert(p1 ~== p2 absTol 1e-3) + testTransformer[(Vector, Double)](dataset.toDF(), model, "features", "probability") { + case Row(features: Vector, prob: Vector) => + val prob2 = model.mlpModel.predict(features) + assert(prob ~== prob2 absTol 1e-3) } } @@ -175,9 +166,6 @@ class MultilayerPerceptronClassifierSuite val model = trainer.fit(dataFrame) val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label").rdd.map { - case Row(p: Double, l: Double) => (p, l) - } // train multinomial logistic regression val lr = new LogisticRegressionWithLBFGS() .setIntercept(true) @@ -189,8 +177,12 @@ class MultilayerPerceptronClassifierSuite lrModel.predict(data.rdd.map(p => OldVectors.fromML(p.features))).zip(data.rdd.map(_.label)) // MLP's predictions should not differ a lot from LR's. val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) - val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) - assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100) + testTransformerByGlobalCheckFunc[(Double, Vector)](dataFrame, model, "prediction", "label") { + rows: Seq[Row] => + val mlpPredictionAndLabels = rows.map(x => (x.getDouble(0), x.getDouble(1))) + val mlpMetrics = new MulticlassMetrics(sc.makeRDD(mlpPredictionAndLabels)) + assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100) + } } test("read/write: MultilayerPerceptronClassifier") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 0d3adf993383f..49115c8a4db30 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -28,12 +28,11 @@ import org.apache.spark.ml.classification.NaiveBayesSuite._ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, Row} -class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -56,13 +55,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa bernoulliDataset = generateNaiveBayesInput(pi, theta, 100, seed, "bernoulli").toDF() } - def validatePrediction(predictionAndLabels: DataFrame): Unit = { - val numOfErrorPredictions = predictionAndLabels.collect().count { + def validatePrediction(predictionAndLabels: Seq[Row]): Unit = { + val numOfErrorPredictions = predictionAndLabels.filter { case Row(prediction: Double, label: Double) => prediction != label - } + }.length // At least 80% of the predictions should be on. - assert(numOfErrorPredictions < predictionAndLabels.count() / 5) + assert(numOfErrorPredictions < predictionAndLabels.length / 5) } def validateModelFit( @@ -92,10 +91,10 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } def validateProbabilities( - featureAndProbabilities: DataFrame, + featureAndProbabilities: Seq[Row], model: NaiveBayesModel, modelType: String): Unit = { - featureAndProbabilities.collect().foreach { + featureAndProbabilities.foreach { case Row(features: Vector, probability: Vector) => assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) val expected = modelType match { @@ -154,15 +153,18 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val validationDataset = generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") - validatePrediction(predictionAndLabels) + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "prediction", "label") { predictionAndLabels: Seq[Row] => + validatePrediction(predictionAndLabels) + } - val featureAndProbabilities = model.transform(validationDataset) - .select("features", "probability") - validateProbabilities(featureAndProbabilities, model, "multinomial") + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "features", "probability") { featureAndProbabilities: Seq[Row] => + validateProbabilities(featureAndProbabilities, model, "multinomial") + } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, NaiveBayesModel](model, testDataset) + Vector, NaiveBayesModel](this, model, testDataset) } test("Naive Bayes with weighted samples") { @@ -210,15 +212,18 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val validationDataset = generateNaiveBayesInput(piArray, thetaArray, nPoints, 20, "bernoulli").toDF() - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") - validatePrediction(predictionAndLabels) + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "prediction", "label") { predictionAndLabels: Seq[Row] => + validatePrediction(predictionAndLabels) + } - val featureAndProbabilities = model.transform(validationDataset) - .select("features", "probability") - validateProbabilities(featureAndProbabilities, model, "bernoulli") + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "features", "probability") { featureAndProbabilities: Seq[Row] => + validateProbabilities(featureAndProbabilities, model, "bernoulli") + } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, NaiveBayesModel](model, testDataset) + Vector, NaiveBayesModel](this, model, testDataset) } test("detect negative values") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 25bad59b9c9cf..11e88367108b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -17,26 +17,24 @@ package org.apache.spark.ml.classification -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.StringIndexer -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.Metadata -class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class OneVsRestSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -85,10 +83,6 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol) assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3)) - val ovaResults = transformedDataset.select("prediction", "label").rdd.map { - row => (row.getDouble(0), row.getDouble(1)) - } - val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses) lr.optimizer.setRegParam(0.1).setNumIterations(100) @@ -97,8 +91,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau // determine the #confusion matrix in each class. // bound how much error we allow compared to multinomial logistic regression. val expectedMetrics = new MulticlassMetrics(results) - val ovaMetrics = new MulticlassMetrics(ovaResults) - assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400) + + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), ovaModel, + "prediction", "label") { rows => + val ovaResults = rows.map { row => (row.getDouble(0), row.getDouble(1)) } + val ovaMetrics = new MulticlassMetrics(sc.makeRDD(ovaResults)) + assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400) + } } test("one-vs-rest: tuning parallelism does not change output") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index d649ceac949c4..1c8c9829f18d1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.MLTest import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.sql.{Dataset, Row} @@ -122,13 +123,15 @@ object ProbabilisticClassifierSuite { def testPredictMethods[ FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]]( - model: M, testData: Dataset[_]): Unit = { + mlTest: MLTest, model: M, testData: Dataset[_]): Unit = { val allColModel = model.copy(ParamMap.empty) .setRawPredictionCol("rawPredictionAll") .setProbabilityCol("probabilityAll") .setPredictionCol("predictionAll") - val allColResult = allColModel.transform(testData) + + val allColResult = allColModel.transform(testData.select(allColModel.getFeaturesCol)) + .select(allColModel.getFeaturesCol, "rawPredictionAll", "probabilityAll", "predictionAll") for (rawPredictionCol <- Seq("", "rawPredictionSingle")) { for (probabilityCol <- Seq("", "probabilitySingle")) { @@ -138,22 +141,14 @@ object ProbabilisticClassifierSuite { .setProbabilityCol(probabilityCol) .setPredictionCol(predictionCol) - val result = newModel.transform(allColResult) - - import org.apache.spark.sql.functions._ - - val resultRawPredictionCol = - if (rawPredictionCol.isEmpty) col("rawPredictionAll") else col(rawPredictionCol) - val resultProbabilityCol = - if (probabilityCol.isEmpty) col("probabilityAll") else col(probabilityCol) - val resultPredictionCol = - if (predictionCol.isEmpty) col("predictionAll") else col(predictionCol) + import allColResult.sparkSession.implicits._ - result.select( - resultRawPredictionCol, col("rawPredictionAll"), - resultProbabilityCol, col("probabilityAll"), - resultPredictionCol, col("predictionAll") - ).collect().foreach { + mlTest.testTransformer[(Vector, Vector, Vector, Double)](allColResult, newModel, + if (rawPredictionCol.isEmpty) "rawPredictionAll" else rawPredictionCol, + "rawPredictionAll", + if (probabilityCol.isEmpty) "probabilityAll" else probabilityCol, "probabilityAll", + if (predictionCol.isEmpty) "predictionAll" else predictionCol, "predictionAll" + ) { case Row( rawPredictionSingle: Vector, rawPredictionAll: Vector, probabilitySingle: Vector, probabilityAll: Vector, diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 2cca2e6c04698..02a9d5c2a18c0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -23,11 +23,10 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} @@ -35,8 +34,7 @@ import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestClassifier]]. */ -class RandomForestClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { import RandomForestClassifierSuite.compareAPIs import testImplicits._ @@ -143,11 +141,8 @@ class RandomForestClassifierSuite MLTestingUtils.checkCopyAndUids(rf, model) - val predictions = model.transform(df) - .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) - .collect() - - predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => + testTransformer[(Vector, Double)](df, model, "prediction", "rawPrediction", + "probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) => assert(pred === rawPred.argmax, s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") val sum = rawPred.toArray.sum @@ -155,8 +150,9 @@ class RandomForestClassifierSuite "probability prediction mismatch") assert(probPred.toArray.sum ~== 1.0 relTol 1E-5) } + ProbabilisticClassifierSuite.testPredictMethods[ - Vector, RandomForestClassificationModel](model, df) + Vector, RandomForestClassificationModel](this, model, df) } test("Fitting without numClasses in metadata") { From ba622f45caa808a9320c1f7ba4a4f344365dcf90 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 5 Mar 2018 20:43:03 +0100 Subject: [PATCH 430/774] [SPARK-23585][SQL] Add interpreted execution to UnwrapOption ## What changes were proposed in this pull request? The PR adds interpreted execution to UnwrapOption. ## How was this patch tested? added UT Author: Marco Gaido Closes #20736 from mgaido91/SPARK-23586. --- .../sql/catalyst/expressions/objects/objects.scala | 10 ++++++++-- .../catalyst/expressions/ObjectExpressionsSuite.scala | 11 ++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 80618af1e859f..03cc8eaceb4e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -382,8 +382,14 @@ case class UnwrapOption( override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = { + val inputObject = child.eval(input) + if (inputObject == null) { + null + } else { + inputObject.asInstanceOf[Option[_]].orNull + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 3edcc02f15264..d95db5867b19c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, UnwrapOption} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{IntegerType, ObjectType} @@ -66,4 +66,13 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvalutionWithUnsafeProjection( mapEncoder.serializer.head, mapExpected, mapInputRow) } + + test("SPARK-23585: UnwrapOption should support interpreted execution") { + val cls = classOf[Option[Int]] + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val unwrapObject = UnwrapOption(IntegerType, inputObject) + Seq((Some(1), 1), (None, null), (null, null)).foreach { case (input, expected) => + checkEvaluation(unwrapObject, expected, InternalRow.fromSeq(Seq(input))) + } + } } From b0f422c3861a5a3831e481b8ffac08f6fa085d00 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 5 Mar 2018 13:23:01 -0800 Subject: [PATCH 431/774] [SPARK-23559][SS] Add epoch ID to DataWriterFactory. ## What changes were proposed in this pull request? Add an epoch ID argument to DataWriterFactory for use in streaming. As a side effect of passing in this value, DataWriter will now have a consistent lifecycle; commit() or abort() ends the lifecycle of a DataWriter instance in any execution mode. I considered making a separate streaming interface and adding the epoch ID only to that one, but I think it requires a lot of extra work for no real gain. I think it makes sense to define epoch 0 as the one and only epoch of a non-streaming query. ## How was this patch tested? existing unit tests Author: Jose Torres Closes #20710 from jose-torres/api2. --- .../sql/kafka010/KafkaStreamWriter.scala | 5 +++- .../sql/sources/v2/writer/DataWriter.java | 12 ++++++--- .../sources/v2/writer/DataWriterFactory.java | 5 +++- .../v2/writer/streaming/StreamWriter.java | 19 +++++++------- .../datasources/v2/WriteToDataSourceV2.scala | 25 +++++++++++++------ .../streaming/MicroBatchExecution.scala | 7 ++++++ .../sources/PackedRowWriterFactory.scala | 5 +++- .../streaming/sources/memoryV2.scala | 5 +++- .../sources/v2/SimpleWritableDataSource.scala | 10 ++++++-- 9 files changed, 65 insertions(+), 28 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index 9307bfc001c03..ae5b5c52d514e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -65,7 +65,10 @@ case class KafkaStreamWriterFactory( topic: Option[String], producerParams: Map[String, String], schema: StructType) extends DataWriterFactory[InternalRow] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 53941a89ba94e..39bf458298862 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int)} and is + * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -31,13 +31,17 @@ * the {@link #write(Object)}, {@link #abort()} is called afterwards and the remaining records will * not be processed. If all records are successfully written, {@link #commit()} is called. * + * Once a data writer returns successfully from {@link #commit()} or {@link #abort()}, its lifecycle + * is over and Spark will not use it again. + * * If this data writer succeeds(all records are successfully written and {@link #commit()} * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an - * exception will be sent to the driver side, and Spark will retry this writing task for some times, - * each time {@link DataWriterFactory#createDataWriter(int, int)} gets a different `attemptNumber`, - * and finally call {@link DataSourceWriter#abort(WriterCommitMessage[])} if all retry fail. + * exception will be sent to the driver side, and Spark may retry this writing task a few times. + * In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a + * different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} + * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task * takes too long to finish. Different from retried tasks, which are launched one by one after the diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index ea95442511ce5..c2c2ab73257e8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -48,6 +48,9 @@ public interface DataWriterFactory extends Serializable { * same task id but different attempt number, which means there are multiple * tasks with the same task id running at the same time. Implementations can * use this attempt number to distinguish writers of different task attempts. + * @param epochId A monotonically increasing id for streaming queries that are split in to + * discrete periods of execution. For non-streaming queries, + * this ID will always be 0. */ - DataWriter createDataWriter(int partitionId, int attemptNumber); + DataWriter createDataWriter(int partitionId, int attemptNumber, long epochId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java index 4913341bd505d..a316b2a4c1d82 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java @@ -23,11 +23,10 @@ import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * A {@link DataSourceWriter} for use with structured streaming. This writer handles commits and - * aborts relative to an epoch ID determined by the execution engine. + * A {@link DataSourceWriter} for use with structured streaming. * - * {@link DataWriter} implementations generated by a StreamWriter may be reused for multiple epochs, - * and so must reset any internal state after a successful commit. + * Streaming queries are divided into intervals of data called epochs, with a monotonically + * increasing numeric ID. This writer handles commits and aborts for each successive epoch. */ @InterfaceStability.Evolving public interface StreamWriter extends DataSourceWriter { @@ -39,21 +38,21 @@ public interface StreamWriter extends DataSourceWriter { * If this method fails (by throwing an exception), this writing job is considered to have been * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}. * - * To support exactly-once processing, writer implementations should ensure that this method is - * idempotent. The execution engine may call commit() multiple times for the same epoch - * in some circumstances. + * The execution engine may call commit() multiple times for the same epoch in some circumstances. + * To support exactly-once data semantics, implementations must ensure that multiple commits for + * the same epoch are idempotent. */ void commit(long epochId, WriterCommitMessage[] messages); /** - * Aborts this writing job because some data writers are failed and keep failing when retry, or + * Aborts this writing job because some data writers are failed and keep failing when retried, or * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. * - * Unless the abort is triggered by the failure of commit, the given messages should have some - * null slots as there maybe only a few data writers that are committed before the abort + * Unless the abort is triggered by the failure of commit, the given messages will have some + * null slots, as there may be only a few data writers that were committed before the abort * happens, or some data writers were committed but their commit messages haven't reached the * driver when the abort is triggered. So this is just a "best effort" for data sources to * clean up the data left by data writers. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 41cdfc80d8a19..e80b44c1cdc66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution} import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -132,7 +132,8 @@ object DataWritingSparkTask extends Logging { val stageId = context.stageId() val partId = context.partitionId() val attemptId = context.attemptNumber() - val dataWriter = writeTask.createDataWriter(partId, attemptId) + val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") + val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { @@ -172,7 +173,6 @@ object DataWritingSparkTask extends Logging { writeTask: DataWriterFactory[InternalRow], context: TaskContext, iter: Iterator[InternalRow]): WriterCommitMessage = { - val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) val epochCoordinator = EpochCoordinatorRef.get( context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) @@ -180,10 +180,15 @@ object DataWritingSparkTask extends Logging { var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong do { + var dataWriter: DataWriter[InternalRow] = null // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { try { - iter.foreach(dataWriter.write) + dataWriter = writeTask.createDataWriter( + context.partitionId(), context.attemptNumber(), currentEpoch) + while (iter.hasNext) { + dataWriter.write(iter.next()) + } logInfo(s"Writer for partition ${context.partitionId()} is committing.") val msg = dataWriter.commit() logInfo(s"Writer for partition ${context.partitionId()} committed.") @@ -196,9 +201,10 @@ object DataWritingSparkTask extends Logging { // Continuous shutdown always involves an interrupt. Just finish the task. } })(catchBlock = { - // If there is an error, abort this writer + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so runContinuous will stop executing at this point. logError(s"Writer for partition ${context.partitionId()} is aborting.") - dataWriter.abort() + if (dataWriter != null) dataWriter.abort() logError(s"Writer for partition ${context.partitionId()} aborted.") }) } while (!context.isInterrupted()) @@ -211,9 +217,12 @@ class InternalRowDataWriterFactory( rowWriterFactory: DataWriterFactory[Row], schema: StructType) extends DataWriterFactory[InternalRow] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { new InternalRowDataWriter( - rowWriterFactory.createDataWriter(partitionId, attemptNumber), + rowWriterFactory.createDataWriter(partitionId, attemptNumber, epochId), RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 6bd03972c301d..ff4be9c7ab874 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -469,6 +469,9 @@ class MicroBatchExecution( case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } + sparkSessionToRunBatch.sparkContext.setLocalProperty( + MicroBatchExecution.BATCH_ID_KEY, currentBatchId.toString) + reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionToRunBatch, @@ -518,3 +521,7 @@ class MicroBatchExecution( Optional.ofNullable(scalaOption.orNull) } } + +object MicroBatchExecution { + val BATCH_ID_KEY = "streaming.sql.batchId" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index 248295e401a0d..e07355aa37dba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -31,7 +31,10 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat * for production-quality sinks. It's intended for use in tests. */ case object PackedRowWriterFactory extends DataWriterFactory[Row] { - def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { new PackedRowDataWriter() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index f960208155e3b..5f58246083bb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -147,7 +147,10 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) } case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { - def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { new MemoryDataWriter(partitionId, outputMode) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 36dd2a350a055..a5007fa321359 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -207,7 +207,10 @@ private[v2] object SimpleCounter { class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[Row] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") val fs = filePath.getFileSystem(conf.value) @@ -240,7 +243,10 @@ class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[InternalRow] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") val fs = filePath.getFileSystem(conf.value) From f2cab56ca22ed5db5ff604cd78cdb55aaa58f651 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Mon, 5 Mar 2018 14:57:32 -0800 Subject: [PATCH 432/774] [SPARK-23040][CORE] Returns interruptible iterator for shuffle reader ## What changes were proposed in this pull request? Before this commit, a non-interruptible iterator is returned if aggregator or ordering is specified. This commit also ensures that sorter is closed even when task is cancelled(killed) in the middle of sorting. ## How was this patch tested? Add a unit test in JobCancellationSuite Author: Xianjin YE Closes #20449 from advancedxy/SPARK-23040. --- .../shuffle/BlockStoreShuffleReader.scala | 9 ++- .../apache/spark/JobCancellationSuite.scala | 65 ++++++++++++++++++- 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index edd69715c9602..85e7e56a04a7d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -94,7 +94,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( } // Sort the output if there is a sort ordering defined. - dep.keyOrdering match { + val resultIter = dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => // Create an ExternalSorter to sort the data. val sorter = @@ -103,9 +103,16 @@ private[spark] class BlockStoreShuffleReader[K, C]( context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + context.addTaskCompletionListener(_ => { + sorter.stop() + }) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) case None => aggregatedIter } + // Use another interruptible iterator here to support task cancellation as aggregator or(and) + // sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) } } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 8a77aea75a992..3b793bb231cf3 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future @@ -26,7 +27,7 @@ import scala.concurrent.duration._ import org.scalatest.BeforeAndAfter import org.scalatest.Matchers -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart} import org.apache.spark.util.ThreadUtils /** @@ -40,6 +41,10 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft override def afterEach() { try { resetSparkContext() + JobCancellationSuite.taskStartedSemaphore.drainPermits() + JobCancellationSuite.taskCancelledSemaphore.drainPermits() + JobCancellationSuite.twoJobsSharingStageSemaphore.drainPermits() + JobCancellationSuite.executionOfInterruptibleCounter.set(0) } finally { super.afterEach() } @@ -320,6 +325,62 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft f2.get() } + test("interruptible iterator of shuffle reader") { + // In this test case, we create a Spark job of two stages. The second stage is cancelled during + // execution and a counter is used to make sure that the corresponding tasks are indeed + // cancelled. + import JobCancellationSuite._ + sc = new SparkContext("local[2]", "test interruptible iterator") + + val taskCompletedSem = new Semaphore(0) + + sc.addSparkListener(new SparkListener { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + // release taskCancelledSemaphore when cancelTasks event has been posted + if (stageCompleted.stageInfo.stageId == 1) { + taskCancelledSemaphore.release(1000) + } + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + if (taskEnd.stageId == 1) { // make sure tasks are completed + taskCompletedSem.release() + } + } + }) + + val f = sc.parallelize(1 to 1000).map { i => (i, i) } + .repartitionAndSortWithinPartitions(new HashPartitioner(1)) + .mapPartitions { iter => + taskStartedSemaphore.release() + iter + }.foreachAsync { x => + if (x._1 >= 10) { + // This block of code is partially executed. It will be blocked when x._1 >= 10 and the + // next iteration will be cancelled if the source iterator is interruptible. Then in this + // case, the maximum num of increment would be 10(|1...10|) + taskCancelledSemaphore.acquire() + } + executionOfInterruptibleCounter.getAndIncrement() + } + + taskStartedSemaphore.acquire() + // Job is cancelled when: + // 1. task in reduce stage has been started, guaranteed by previous line. + // 2. task in reduce stage is blocked after processing at most 10 records as + // taskCancelledSemaphore is not released until cancelTasks event is posted + // After job being cancelled, task in reduce stage will be cancelled and no more iteration are + // executed. + f.cancel() + + val e = intercept[SparkException](f.get()).getCause + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + + // Make sure tasks are indeed completed. + taskCompletedSem.acquire() + assert(executionOfInterruptibleCounter.get() <= 10) + } + def testCount() { // Cancel before launching any tasks { @@ -381,7 +442,9 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft object JobCancellationSuite { + // To avoid any headaches, reset these global variables in the companion class's afterEach block val taskStartedSemaphore = new Semaphore(0) val taskCancelledSemaphore = new Semaphore(0) val twoJobsSharingStageSemaphore = new Semaphore(0) + val executionOfInterruptibleCounter = new AtomicInteger(0) } From 508573958dc9b6402e684cd6dd37202deaaa97f6 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 5 Mar 2018 15:03:27 -0800 Subject: [PATCH 433/774] [SPARK-23538][CORE] Remove custom configuration for SSL client. These options were used to configure the built-in JRE SSL libraries when downloading files from HTTPS servers. But because they were also used to set up the now (long) removed internal HTTPS file server, their default configuration chose convenience over security by having overly lenient settings. This change removes the configuration options that affect the JRE SSL libraries. The JRE trust store can still be configured via system properties (or globally in the JRE security config). The only lost functionality is not being able to disable the default hostname verifier when using spark-submit, which should be fine since Spark itself is not using https for any internal functionality anymore. I also removed the HTTP-related code from the REPL class loader, since we haven't had a HTTP server for REPL-generated classes for a while. Author: Marcelo Vanzin Closes #20723 from vanzin/SPARK-23538. --- .../org/apache/spark/SecurityManager.scala | 45 ------------ .../scala/org/apache/spark/util/Utils.scala | 15 ---- .../org/apache/spark/SSLSampleConfigs.scala | 68 ------------------- .../apache/spark/SecurityManagerSuite.scala | 45 ------------ docs/security.md | 4 -- .../spark/repl/ExecutorClassLoader.scala | 53 ++------------- 6 files changed, 7 insertions(+), 223 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 2519d266879aa..da1c89cd78901 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -256,51 +256,6 @@ private[spark] class SecurityManager( // the default SSL configuration - it will be used by all communication layers unless overwritten private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) - // SSL configuration for the file server. This is used by Utils.setupSecureURLConnection(). - val fileServerSSLOptions = getSSLOptions("fs") - val (sslSocketFactory, hostnameVerifier) = if (fileServerSSLOptions.enabled) { - val trustStoreManagers = - for (trustStore <- fileServerSSLOptions.trustStore) yield { - val input = Files.asByteSource(fileServerSSLOptions.trustStore.get).openStream() - - try { - val ks = KeyStore.getInstance(KeyStore.getDefaultType) - ks.load(input, fileServerSSLOptions.trustStorePassword.get.toCharArray) - - val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) - tmf.init(ks) - tmf.getTrustManagers - } finally { - input.close() - } - } - - lazy val credulousTrustStoreManagers = Array({ - logWarning("Using 'accept-all' trust manager for SSL connections.") - new X509TrustManager { - override def getAcceptedIssuers: Array[X509Certificate] = null - - override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {} - - override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {} - }: TrustManager - }) - - require(fileServerSSLOptions.protocol.isDefined, - "spark.ssl.protocol is required when enabling SSL connections.") - - val sslContext = SSLContext.getInstance(fileServerSSLOptions.protocol.get) - sslContext.init(null, trustStoreManagers.getOrElse(credulousTrustStoreManagers), null) - - val hostVerifier = new HostnameVerifier { - override def verify(s: String, sslSession: SSLSession): Boolean = true - } - - (Some(sslContext.getSocketFactory), Some(hostVerifier)) - } else { - (None, None) - } - def getSSLOptions(module: String): SSLOptions = { val opts = SSLOptions.parse(sparkConf, s"spark.ssl.$module", Some(defaultSSLOptions)) logDebug(s"Created SSL options for $module: $opts") diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d493663f0b168..2e2a4a259e9af 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -673,7 +673,6 @@ private[spark] object Utils extends Logging { logDebug("fetchFile not using security") uc = new URL(url).openConnection() } - Utils.setupSecureURLConnection(uc, securityMgr) val timeoutMs = conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000 @@ -2363,20 +2362,6 @@ private[spark] object Utils extends Logging { PropertyConfigurator.configure(pro) } - /** - * If the given URL connection is HttpsURLConnection, it sets the SSL socket factory and - * the host verifier from the given security manager. - */ - def setupSecureURLConnection(urlConnection: URLConnection, sm: SecurityManager): URLConnection = { - urlConnection match { - case https: HttpsURLConnection => - sm.sslSocketFactory.foreach(https.setSSLSocketFactory) - sm.hostnameVerifier.foreach(https.setHostnameVerifier) - https - case connection => connection - } - } - def invoke( clazz: Class[_], obj: AnyRef, diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala deleted file mode 100644 index 33270bec6247c..0000000000000 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.io.File - -object SSLSampleConfigs { - val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath - val untrustedKeyStorePath = new File( - this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath - val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath - - val enabledAlgorithms = - // A reasonable set of TLSv1.2 Oracle security provider suites - "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + - "TLS_RSA_WITH_AES_256_CBC_SHA256, " + - "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, " + - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + - "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, " + - // and their equivalent names in the IBM Security provider - "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + - "SSL_RSA_WITH_AES_256_CBC_SHA256, " + - "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256, " + - "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + - "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256" - - def sparkSSLConfig(): SparkConf = { - val conf = new SparkConf(loadDefaults = false) - conf.set("spark.ssl.enabled", "true") - conf.set("spark.ssl.keyStore", keyStorePath) - conf.set("spark.ssl.keyStorePassword", "password") - conf.set("spark.ssl.keyPassword", "password") - conf.set("spark.ssl.trustStore", trustStorePath) - conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) - conf.set("spark.ssl.protocol", "TLSv1.2") - conf - } - - def sparkSSLConfigUntrusted(): SparkConf = { - val conf = new SparkConf(loadDefaults = false) - conf.set("spark.ssl.enabled", "true") - conf.set("spark.ssl.keyStore", untrustedKeyStorePath) - conf.set("spark.ssl.keyStorePassword", "password") - conf.set("spark.ssl.keyPassword", "password") - conf.set("spark.ssl.trustStore", trustStorePath) - conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) - conf.set("spark.ssl.protocol", "TLSv1.2") - conf - } - -} diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 106ece7aed0a4..e357299770a2e 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -370,51 +370,6 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(securityManager.checkModifyPermissions("user1") === false) } - test("ssl on setup") { - val conf = SSLSampleConfigs.sparkSSLConfig() - val expectedAlgorithms = Set( - "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384", - "TLS_RSA_WITH_AES_256_CBC_SHA256", - "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256", - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", - "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256", - "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384", - "SSL_RSA_WITH_AES_256_CBC_SHA256", - "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256", - "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256", - "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256") - - val securityManager = new SecurityManager(conf) - - assert(securityManager.fileServerSSLOptions.enabled === true) - - assert(securityManager.sslSocketFactory.isDefined === true) - assert(securityManager.hostnameVerifier.isDefined === true) - - assert(securityManager.fileServerSSLOptions.trustStore.isDefined === true) - assert(securityManager.fileServerSSLOptions.trustStore.get.getName === "truststore") - assert(securityManager.fileServerSSLOptions.keyStore.isDefined === true) - assert(securityManager.fileServerSSLOptions.keyStore.get.getName === "keystore") - assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password")) - assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password")) - assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) - assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2")) - assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms) - } - - test("ssl off setup") { - val file = File.createTempFile("SSLOptionsSuite", "conf", Utils.createTempDir()) - - System.setProperty("spark.ssl.configFile", file.getAbsolutePath) - val conf = new SparkConf() - - val securityManager = new SecurityManager(conf) - - assert(securityManager.fileServerSSLOptions.enabled === false) - assert(securityManager.sslSocketFactory.isDefined === false) - assert(securityManager.hostnameVerifier.isDefined === false) - } - test("missing secret authentication key") { val conf = new SparkConf().set("spark.authenticate", "true") val mgr = new SecurityManager(conf) diff --git a/docs/security.md b/docs/security.md index 0f384b411812a..913d9df50eb1c 100644 --- a/docs/security.md +++ b/docs/security.md @@ -44,10 +44,6 @@ component-specific configuration namespaces used to override the default setting Config Namespace Component - - spark.ssl.fs - File download client (used to download jars and files from HTTPS-enabled servers). - spark.ssl.ui Spark application Web UI diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 127f67329f266..4dc399827ffed 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -17,12 +17,10 @@ package org.apache.spark.repl -import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, InputStream, IOException} -import java.net.{HttpURLConnection, URI, URL, URLEncoder} +import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, InputStream} +import java.net.{URI, URL, URLEncoder} import java.nio.channels.Channels -import scala.util.control.NonFatal - import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.xbean.asm5._ import org.apache.xbean.asm5.Opcodes._ @@ -30,13 +28,13 @@ import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.util.{ParentClassLoader, Utils} +import org.apache.spark.util.ParentClassLoader /** - * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, used to load classes - * defined by the interpreter when the REPL is used. Allows the user to specify if user class path - * should be first. This class loader delegates getting/finding resources to parent loader, which - * makes sense until REPL never provide resource dynamically. + * A ClassLoader that reads classes from a Hadoop FileSystem or Spark RPC endpoint, used to load + * classes defined by the interpreter when the REPL is used. Allows the user to specify if user + * class path should be first. This class loader delegates getting/finding resources to parent + * loader, which makes sense until REPL never provide resource dynamically. * * Note: [[ClassLoader]] will preferentially load class from parent. Only when parent is null or * the load failed, that it will call the overridden `findClass` function. To avoid the potential @@ -60,7 +58,6 @@ class ExecutorClassLoader( private val fetchFn: (String) => InputStream = uri.getScheme() match { case "spark" => getClassFileInputStreamFromSparkRPC - case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer case _ => val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) getClassFileInputStreamFromFileSystem(fileSystem) @@ -113,42 +110,6 @@ class ExecutorClassLoader( } } - private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { - val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { - val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) - val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) - newuri.toURL - } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)) - } - val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(), - SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection] - // Set the connection timeouts (for testing purposes) - if (httpUrlConnectionTimeoutMillis != -1) { - connection.setConnectTimeout(httpUrlConnectionTimeoutMillis) - connection.setReadTimeout(httpUrlConnectionTimeoutMillis) - } - connection.connect() - try { - if (connection.getResponseCode != 200) { - // Close the error stream so that the connection is eligible for re-use - try { - connection.getErrorStream.close() - } catch { - case ioe: IOException => - logError("Exception while closing error stream", ioe) - } - throw new ClassNotFoundException(s"Class file not found at URL $url") - } else { - connection.getInputStream - } - } catch { - case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] => - connection.disconnect() - throw e - } - } - private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)( pathInDirectory: String): InputStream = { val path = new Path(directory, pathInDirectory) From 7706eea6a8bdcd73e9dde5212368f8825e2f1801 Mon Sep 17 00:00:00 2001 From: Yogesh Garg Date: Mon, 5 Mar 2018 15:53:10 -0800 Subject: [PATCH 434/774] [SPARK-18630][PYTHON][ML] Move del method from JavaParams to JavaWrapper; add tests The `__del__` method that explicitly detaches the object was moved from `JavaParams` to `JavaWrapper` class, this way model summaries could also be garbage collected in Java. A test case was added to make sure that relevant error messages are thrown after the objects are deleted. I ran pyspark tests agains `pyspark-ml` module `./python/run-tests --python-executables=$(which python) --modules=pyspark-ml` Author: Yogesh Garg Closes #20724 from yogeshg/java_wrapper_memory. --- python/pyspark/ml/tests.py | 39 ++++++++++++++++++++++++++++++++++++ python/pyspark/ml/wrapper.py | 8 ++++---- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 116885969345c..6dee6938d8916 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -173,6 +173,45 @@ class MockModel(MockTransformer, Model, HasFake): pass +class JavaWrapperMemoryTests(SparkSessionTestCase): + + def test_java_object_gets_detached(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + + model = lr.fit(df) + summary = model.summary + + self.assertIsInstance(model, JavaWrapper) + self.assertIsInstance(summary, JavaWrapper) + self.assertIsInstance(model, JavaParams) + self.assertNotIsInstance(summary, JavaParams) + + error_no_object = 'Target Object ID does not exist for this gateway' + + self.assertIn("LinearRegression_", model._java_obj.toString()) + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + model.__del__() + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + try: + summary.__del__() + except: + pass + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + summary._java_obj.toString() + + class ParamTypeConversionTests(PySparkTestCase): """ Test that param type conversion happens. diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 0f846fbc5b5ef..5061f6434794a 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -36,6 +36,10 @@ def __init__(self, java_obj=None): super(JavaWrapper, self).__init__() self._java_obj = java_obj + def __del__(self): + if SparkContext._active_spark_context and self._java_obj is not None: + SparkContext._active_spark_context._gateway.detach(self._java_obj) + @classmethod def _create_from_java_class(cls, java_class, *args): """ @@ -100,10 +104,6 @@ class JavaParams(JavaWrapper, Params): __metaclass__ = ABCMeta - def __del__(self): - if SparkContext._active_spark_context: - SparkContext._active_spark_context._gateway.detach(self._java_obj) - def _make_java_param_pair(self, param, value): """ Makes a Java param pair. From f6b49f9d1b6f218408197f7272c1999fe3d94328 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 6 Mar 2018 01:37:51 +0100 Subject: [PATCH 435/774] [SPARK-23586][SQL] Add interpreted execution to WrapOption ## What changes were proposed in this pull request? The PR adds interpreted execution to WrapOption. ## How was this patch tested? added UT Author: Marco Gaido Closes #20741 from mgaido91/SPARK-23586_2. --- .../sql/catalyst/expressions/objects/objects.scala | 3 +-- .../catalyst/expressions/ObjectExpressionsSuite.scala | 11 ++++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 03cc8eaceb4e6..d832fe0a6857c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -422,8 +422,7 @@ case class WrapOption(child: Expression, optType: DataType) override def inputTypes: Seq[AbstractDataType] = optType :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = Option(child.eval(input)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index d95db5867b19c..d535578a7eb06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, UnwrapOption} +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{IntegerType, ObjectType} @@ -75,4 +75,13 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(unwrapObject, expected, InternalRow.fromSeq(Seq(input))) } } + + test("SPARK-23586: WrapOption should support interpreted execution") { + val cls = ObjectType(classOf[java.lang.Integer]) + val inputObject = BoundReference(0, cls, nullable = true) + val wrapObject = WrapOption(inputObject, cls) + Seq((1, Some(1)), (null, None)).foreach { case (input, expected) => + checkEvaluation(wrapObject, expected, InternalRow.fromSeq(Seq(input))) + } + } } From 8c5b34c425bda2079a1ff969b12c067f2bb3f18f Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Mon, 5 Mar 2018 16:49:24 -0800 Subject: [PATCH 436/774] =?UTF-8?q?[SPARK-23604][SQL]=20Change=20Statistic?= =?UTF-8?q?s.isEmpty=20to=20!Statistics.hasNonNul=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …lValue ## What changes were proposed in this pull request? Parquet 1.9 will change the semantics of Statistics.isEmpty slightly to reflect if the null value count has been set. That breaks a timestamp interoperability test that cares only about whether there are column values present in the statistics of a written file for an INT96 column. Fix by using Statistics.hasNonNullValue instead. ## How was this patch tested? Unit tests continue to pass against Parquet 1.8, and also pass against a Parquet build including PARQUET-1217. Author: Henry Robinson Closes #20740 from henryr/spark-23604. --- .../datasources/parquet/ParquetInteroperabilitySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index fbd83a0fa425a..9c75965639d8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -184,7 +184,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS // when the data is read back as mentioned above, b/c int96 is unsigned. This // assert makes sure this holds even if we change parquet versions (if eg. there // were ever statistics even on unsigned columns). - assert(columnStats.isEmpty) + assert(!columnStats.hasNonNullValue) } // These queries should return the entire dataset with the conversion applied, From ad640a5affceaaf3979e25848628fb1dfcdf932a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 5 Mar 2018 20:35:14 -0800 Subject: [PATCH 437/774] [SPARK-23303][SQL] improve the explain result for data source v2 relations ## What changes were proposed in this pull request? The proposed explain format: **[streaming header] [RelationV2/ScanV2] [data source name] [output] [pushed filters] [options]** **streaming header**: if it's a streaming relation, put a "Streaming" at the beginning. **RelationV2/ScanV2**: if it's a logical plan, put a "RelationV2", else, put a "ScanV2" **data source name**: the simple class name of the data source implementation **output**: a string of the plan output attributes **pushed filters**: a string of all the filters that have been pushed to this data source **options**: all the options to create the data source reader. The current explain result for data source v2 relation is unreadable: ``` == Parsed Logical Plan == 'Filter ('i > 6) +- AnalysisBarrier +- Project [j#1] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Analyzed Logical Plan == j: int Project [j#1] +- Filter (i#0 > 6) +- Project [j#1, i#0] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Optimized Logical Plan == Project [j#1] +- Filter isnotnull(i#0) +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Physical Plan == *(1) Project [j#1] +- *(1) Filter isnotnull(i#0) +- *(1) DataSourceV2Scan [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 ``` after this PR ``` == Parsed Logical Plan == 'Project [unresolvedalias('j, None)] +- AnalysisBarrier +- RelationV2 AdvancedDataSourceV2[i#0, j#1] == Analyzed Logical Plan == j: int Project [j#1] +- RelationV2 AdvancedDataSourceV2[i#0, j#1] == Optimized Logical Plan == RelationV2 AdvancedDataSourceV2[j#1] == Physical Plan == *(1) ScanV2 AdvancedDataSourceV2[j#1] ``` ------- ``` == Analyzed Logical Plan == i: int, j: int Filter (i#88 > 3) +- RelationV2 JavaAdvancedDataSourceV2[i#88, j#89] == Optimized Logical Plan == Filter isnotnull(i#88) +- RelationV2 JavaAdvancedDataSourceV2[i#88, j#89] (Pushed Filters: [GreaterThan(i,3)]) == Physical Plan == *(1) Filter isnotnull(i#88) +- *(1) ScanV2 JavaAdvancedDataSourceV2[i#88, j#89] (Pushed Filters: [GreaterThan(i,3)]) ``` an example for streaming query ``` == Parsed Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming RelationV2 MemoryStreamDataSource[value#25] == Analyzed Logical Plan == value: string, count(1): bigint Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming RelationV2 MemoryStreamDataSource[value#25] == Optimized Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject value#25.toString, obj#4: java.lang.String +- Streaming RelationV2 MemoryStreamDataSource[value#25] == Physical Plan == *(4) HashAggregate(keys=[value#6], functions=[count(1)], output=[value#6, count(1)#11L]) +- StateStoreSave [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5], Complete, 0 +- *(3) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- StateStoreRestore [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5] +- *(2) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- Exchange hashpartitioning(value#6, 5) +- *(1) HashAggregate(keys=[value#6], functions=[partial_count(1)], output=[value#6, count#16L]) +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- *(1) MapElements , obj#5: java.lang.String +- *(1) DeserializeToObject value#25.toString, obj#4: java.lang.String +- *(1) ScanV2 MemoryStreamDataSource[value#25] ``` ## How was this patch tested? N/A Author: Wenchen Fan Closes #20647 from cloud-fan/explain. --- .../kafka010/KafkaContinuousSourceSuite.scala | 2 +- .../sql/kafka010/KafkaContinuousTest.scala | 2 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../v2/DataSourceReaderHolder.scala | 64 ------------- .../datasources/v2/DataSourceV2Relation.scala | 34 +++++-- .../datasources/v2/DataSourceV2ScanExec.scala | 18 +++- .../datasources/v2/DataSourceV2Strategy.scala | 8 +- .../v2/DataSourceV2StringFormat.scala | 94 +++++++++++++++++++ .../streaming/MicroBatchExecution.scala | 29 +++++- .../continuous/ContinuousExecution.scala | 8 +- .../spark/sql/streaming/StreamSuite.scala | 12 ++- .../spark/sql/streaming/StreamTest.scala | 4 +- .../continuous/ContinuousSuite.scala | 11 +-- 13 files changed, 183 insertions(+), 105 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index f679e9bfc0450..aab8ec42189fb 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -60,7 +60,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index 48ac3fc1e8f9d..fa1468a3943c8 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index f2b3ff7615e74..e017fd9b84d21 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -124,7 +124,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case StreamingDataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + case StreamingDataSourceV2Relation(_, _, _, reader: KafkaContinuousReader) => reader } }) }.distinct diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala deleted file mode 100644 index 81219e9771bd8..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import java.util.Objects - -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.sources.v2.reader._ - -/** - * A base class for data source reader holder with customized equals/hashCode methods. - */ -trait DataSourceReaderHolder { - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The held data source reader. - */ - def reader: DataSourceReader - - /** - * The metadata of this data source reader that can be used for equality test. - */ - private def metadata: Seq[Any] = { - val filters: Any = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Nil - } - Seq(output, reader.getClass, filters) - } - - def canEqual(other: Any): Boolean - - override def equals(other: Any): Boolean = other match { - case other: DataSourceReaderHolder => - canEqual(other) && metadata.length == other.metadata.length && - metadata.zip(other.metadata).forall { case (l, r) => l == r } - case _ => false - } - - override def hashCode(): Int = { - metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index cc6cb631e3f06..2b282ffae2390 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -35,15 +35,12 @@ case class DataSourceV2Relation( options: Map[String, String], projection: Seq[AttributeReference], filters: Option[Seq[Expression]] = None, - userSpecifiedSchema: Option[StructType] = None) extends LeafNode with MultiInstanceRelation { + userSpecifiedSchema: Option[StructType] = None) + extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ - override def simpleString: String = { - s"DataSourceV2Relation(source=${source.name}, " + - s"schema=[${output.map(a => s"$a ${a.dataType.simpleString}").mkString(", ")}], " + - s"filters=[${pushedFilters.mkString(", ")}], options=$options)" - } + override def simpleString: String = "RelationV2 " + metadataString override lazy val schema: StructType = reader.readSchema() @@ -107,19 +104,36 @@ case class DataSourceV2Relation( } /** - * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical - * to the non-streaming relation. + * A specialization of [[DataSourceV2Relation]] with the streaming bit set to true. + * + * Note that, this plan has a mutable reader, so Spark won't apply operator push-down for this plan, + * to avoid making the plan mutable. We should consolidate this plan and [[DataSourceV2Relation]] + * after we figure out how to apply operator push-down for streaming data sources. */ case class StreamingDataSourceV2Relation( output: Seq[AttributeReference], + source: DataSourceV2, + options: Map[String, String], reader: DataSourceReader) - extends LeafNode with DataSourceReaderHolder with MultiInstanceRelation { + extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { + override def isStreaming: Boolean = true - override def canEqual(other: Any): Boolean = other.isInstanceOf[StreamingDataSourceV2Relation] + override def simpleString: String = "Streaming RelationV2 " + metadataString override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. + override def equals(other: Any): Boolean = other match { + case other: StreamingDataSourceV2Relation => + output == other.output && reader.getClass == other.reader.getClass && options == other.options + case _ => false + } + + override def hashCode(): Int = { + Seq(output, source, options).hashCode() + } + override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 7d9581be4db89..cb691ba297076 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType @@ -36,10 +37,23 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( output: Seq[AttributeReference], + @transient source: DataSourceV2, + @transient options: Map[String, String], @transient reader: DataSourceReader) - extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { + extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { - override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] + override def simpleString: String = "ScanV2 " + metadataString + + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. + override def equals(other: Any): Boolean = other match { + case other: DataSourceV2ScanExec => + output == other.output && reader.getClass == other.reader.getClass && options == other.options + case _ => false + } + + override def hashCode(): Int = { + Seq(output, source, options).hashCode() + } override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index c4e7644683c36..1ac9572de6412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,11 +23,11 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case relation: DataSourceV2Relation => - DataSourceV2ScanExec(relation.output, relation.reader) :: Nil + case r: DataSourceV2Relation => + DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil - case relation: StreamingDataSourceV2Relation => - DataSourceV2ScanExec(relation.output, relation.reader) :: Nil + case r: StreamingDataSourceV2Relation => + DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala new file mode 100644 index 0000000000000..aed55a429bfd7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.util.Utils + +/** + * A trait that can be used by data source v2 related query plans(both logical and physical), to + * provide a string format of the data source information for explain. + */ +trait DataSourceV2StringFormat { + + /** + * The instance of this data source implementation. Note that we only consider its class in + * equals/hashCode, not the instance itself. + */ + def source: DataSourceV2 + + /** + * The output of the data source reader, w.r.t. column pruning. + */ + def output: Seq[Attribute] + + /** + * The options for this data source reader. + */ + def options: Map[String, String] + + /** + * The created data source reader. Here we use it to get the filters that has been pushed down + * so far, itself doesn't take part in the equals/hashCode. + */ + def reader: DataSourceReader + + private lazy val filters = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Set.empty + } + + private def sourceName: String = source match { + case registered: DataSourceRegister => registered.shortName() + case _ => source.getClass.getSimpleName.stripSuffix("$") + } + + def metadataString: String = { + val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] + + if (filters.nonEmpty) { + entries += "Filters" -> filters.mkString("[", ", ", "]") + } + + // TODO: we should only display some standard options like path, table, etc. + if (options.nonEmpty) { + entries += "Options" -> Utils.redact(options).map { + case (k, v) => s"$k=$v" + }.mkString("[", ",", "]") + } + + val outputStr = Utils.truncatedString(output, "[", ", ", "]") + + val entriesStr = if (entries.nonEmpty) { + Utils.truncatedString(entries.map { + case (key, value) => key + ": " + StringUtils.abbreviate(value, 100) + }, " (", ", ", ")") + } else { + "" + } + + s"$sourceName$outputStr$entriesStr" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index ff4be9c7ab874..6e231970f4a22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql.execution.streaming import java.util.Optional import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} +import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -52,6 +52,9 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty + private val readerToDataSourceMap = + MutableMap.empty[MicroBatchReader, (DataSourceV2, Map[String, String])] + private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -97,6 +100,7 @@ class MicroBatchExecution( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 + readerToDataSourceMap(reader) = dataSourceV2 -> options logInfo(s"Using MicroBatchReader [$reader] from " + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") StreamingExecutionRelation(reader, output)(sparkSession) @@ -419,8 +423,19 @@ class MicroBatchExecution( toJava(current), Optional.of(availableV2)) logDebug(s"Retrieving data from $reader: $current -> $availableV2") - Some(reader -> - new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) + + val (source, options) = reader match { + // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2` + // implementation. We provide a fake one here for explain. + case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String] + // Provide a fake value here just in case something went wrong, e.g. the reader gives + // a wrong `equals` implementation. + case _ => readerToDataSourceMap.getOrElse(reader, { + FakeDataSourceV2 -> Map.empty[String, String] + }) + } + Some(reader -> StreamingDataSourceV2Relation( + reader.readSchema().toAttributes, source, options, reader)) case _ => None } } @@ -525,3 +540,7 @@ class MicroBatchExecution( object MicroBatchExecution { val BATCH_ID_KEY = "streaming.sql.batchId" } + +object MemoryStreamDataSource extends DataSourceV2 + +object FakeDataSourceV2 extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index daebd1dd010ac..1758b3844bd62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} @@ -167,7 +167,7 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { - case ContinuousExecutionRelation(_, _, output) => + case ContinuousExecutionRelation(source, options, output) => val reader = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = reader.readSchema().toAttributes @@ -180,7 +180,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - new StreamingDataSourceV2Relation(newOutput, reader) + StreamingDataSourceV2Relation(newOutput, source, options, reader) } // Rewire the plan to use the new attributes that were returned by the source. @@ -201,7 +201,7 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case StreamingDataSourceV2Relation(_, r: ContinuousReader) => r + case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index d1a04833390f5..c1ec1eba69fb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,20 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("Streaming RelationV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithoutExtended).size === 0) + assert("ScanV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) + assert("Streaming RelationV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithExtended).size === 3) + assert("ScanV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 08f722ecb10e5..e44aef09f1f3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -629,8 +629,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan .collect { - case StreamingExecutionRelation(s, _) => s - case StreamingDataSourceV2Relation(_, r) => r + case r: StreamingExecutionRelation => r.source + case r: StreamingDataSourceV2Relation => r.reader } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 4b4ed82dc6520..f5884b9c8de12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.streaming.continuous -import java.util.UUID - -import org.apache.spark.{SparkContext, SparkEnv, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -43,7 +40,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From e8a259d66dda0d4c76f3af8933676bade8a7451d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 6 Mar 2018 13:55:13 +0100 Subject: [PATCH 438/774] [SPARK-23594][SQL] GetExternalRowField should support interpreted execution ## What changes were proposed in this pull request? This pr added interpreted execution for `GetExternalRowField`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20746 from maropu/SPARK-23594. --- .../expressions/objects/objects.scala | 14 ++++++++++--- .../expressions/ObjectExpressionsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index d832fe0a6857c..97e3ff88858d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1358,11 +1358,19 @@ case class GetExternalRowField( override def dataType: DataType = ObjectType(classOf[Object]) - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") - private val errMsg = s"The ${index}th field '$fieldName' of input row cannot be null." + override def eval(input: InternalRow): Any = { + val inputRow = child.eval(input).asInstanceOf[Row] + if (inputRow == null) { + throw new RuntimeException("The input external row cannot be null.") + } + if (inputRow.isNullAt(index)) { + throw new RuntimeException(errMsg) + } + inputRow.get(index) + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the field is null. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index d535578a7eb06..0f376c4b63c15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.objects._ @@ -84,4 +85,23 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(wrapObject, expected, InternalRow.fromSeq(Seq(input))) } } + + test("SPARK-23594 GetExternalRowField should support interpreted execution") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") + Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) => + checkEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) + } + + // If an input row or a field are null, a runtime exception will be thrown + val errMsg1 = intercept[RuntimeException] { + evaluate(getRowField, InternalRow.fromSeq(Seq(null))) + }.getMessage + assert(errMsg1 === "The input external row cannot be null.") + + val errMsg2 = intercept[RuntimeException] { + evaluate(getRowField, InternalRow.fromSeq(Seq(Row(null)))) + }.getMessage + assert(errMsg2 === "The 0th field 'c0' of input row cannot be null.") + } } From 8bceb899dc3220998a4ea4021f3b477f78faaca8 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 6 Mar 2018 08:52:28 -0600 Subject: [PATCH 439/774] [SPARK-23601][BUILD] Remove .md5 files from release ## What changes were proposed in this pull request? Remove .md5 files from release artifacts ## How was this patch tested? N/A Author: Sean Owen Closes #20737 from srowen/SPARK-23601. --- dev/create-release/release-build.sh | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index a3579f21fc539..c00b00b845401 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -164,8 +164,6 @@ if [[ "$1" == "package" ]]; then tar cvzf spark-$SPARK_VERSION.tgz spark-$SPARK_VERSION echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour --output spark-$SPARK_VERSION.tgz.asc \ --detach-sig spark-$SPARK_VERSION.tgz - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md MD5 spark-$SPARK_VERSION.tgz > \ - spark-$SPARK_VERSION.tgz.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha512 rm -rf spark-$SPARK_VERSION @@ -215,9 +213,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output $R_DIST_NAME.asc \ --detach-sig $R_DIST_NAME - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 $R_DIST_NAME > \ - $R_DIST_NAME.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 $R_DIST_NAME > \ $R_DIST_NAME.sha512 @@ -234,9 +229,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output $PYTHON_DIST_NAME.asc \ --detach-sig $PYTHON_DIST_NAME - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 $PYTHON_DIST_NAME > \ - $PYTHON_DIST_NAME.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 $PYTHON_DIST_NAME > \ $PYTHON_DIST_NAME.sha512 @@ -247,9 +239,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 spark-$SPARK_VERSION-bin-$NAME.tgz > \ - spark-$SPARK_VERSION-bin-$NAME.tgz.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 spark-$SPARK_VERSION-bin-$NAME.tgz > \ spark-$SPARK_VERSION-bin-$NAME.tgz.sha512 @@ -382,18 +371,11 @@ if [[ "$1" == "publish-release" ]]; then find . -type f |grep -v \.jar |grep -v \.pom | xargs rm echo "Creating hash and signature files" - # this must have .asc, .md5 and .sha1 - it really doesn't like anything else there + # this must have .asc and .sha1 - it really doesn't like anything else there for file in $(find . -type f) do echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ --detach-sig --armour $file; - if [ $(command -v md5) ]; then - # Available on OS X; -q to keep only hash - md5 -q $file > $file.md5 - else - # Available on Linux; cut to keep only hash - md5sum $file | cut -f1 -d' ' > $file.md5 - fi sha1sum $file | cut -f1 -d' ' > $file.sha1 done From 4c587eb4887623c839854c1505f495de42898229 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 6 Mar 2018 17:42:17 +0100 Subject: [PATCH 440/774] [SPARK-23590][SQL] Add interpreted execution to CreateExternalRow ## What changes were proposed in this pull request? The PR adds interpreted execution to CreateExternalRow ## How was this patch tested? added UT Author: Marco Gaido Closes #20749 from mgaido91/SPARK-23590. --- .../spark/sql/catalyst/expressions/objects/objects.scala | 6 ++++-- .../sql/catalyst/expressions/ExpressionEvalHelper.scala | 4 +++- .../sql/catalyst/expressions/ObjectExpressionsSuite.scala | 8 +++++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 97e3ff88858d0..721d589709131 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1111,8 +1111,10 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) override def nullable: Boolean = false - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = { + val values = children.map(_.eval(input)).toArray + new GenericRowWithSchema(values, schema) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericRowWithSchema].getName diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4c8eab19c5cc..29f0cc0d991aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -24,6 +24,7 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer} import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -60,7 +61,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { /** * Check the equality between result of expression and expected value, it will handle - * Array[Byte], Spread[Double], and MapData. + * Array[Byte], Spread[Double], MapData and Row. */ protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = { (result, expected) match { @@ -88,6 +89,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result + case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 0f376c4b63c15..50e57737a4612 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} -import org.apache.spark.sql.types.{IntegerType, ObjectType} +import org.apache.spark.sql.types._ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -86,6 +86,12 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-23590: CreateExternalRow should support interpreted execution") { + val schema = new StructType().add("a", IntegerType).add("b", StringType) + val createExternalRow = CreateExternalRow(Seq(Literal(1), Literal("x")), schema) + checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq())) + } + test("SPARK-23594 GetExternalRowField should support interpreted execution") { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") From 04e71c31603af3a13bc13300df799f003fe185f7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 7 Mar 2018 17:01:29 +0800 Subject: [PATCH 441/774] [MINOR][YARN] Add disable yarn.nodemanager.vmem-check-enabled option to memLimitExceededLogMessage My spark application sometimes will throw `Container killed by YARN for exceeding memory limits`. Even I increased `spark.yarn.executor.memoryOverhead` to 10G, this error still happen. The latest config: memory-config And error message: ``` ExecutorLostFailure (executor 121 exited caused by one of the running tasks) Reason: Container killed by YARN for exceeding memory limits. 30.7 GB of 30 GB physical memory used. Consider boosting spark.yarn.executor.memoryOverhead. ``` This is because of [Linux glibc >= 2.10 (RHEL 6) malloc may show excessive virtual memory usage](https://www.ibm.com/developerworks/community/blogs/kevgrig/entry/linux_glibc_2_10_rhel_6_malloc_may_show_excessive_virtual_memory_usage?lang=en). So disable `yarn.nodemanager.vmem-check-enabled` looks like a good option as [MapR mentioned ](https://mapr.com/blog/best-practices-yarn-resource-management). This PR add disable `yarn.nodemanager.vmem-check-enabled` option to memLimitExceededLogMessage. More details: https://issues.apache.org/jira/browse/YARN-4714 https://stackoverflow.com/a/31450291 https://stackoverflow.com/a/42091255 After this PR: yarn N/A Author: Yuming Wang Author: Yuming Wang Closes #20735 from wangyum/YARN-4714. Change-Id: Ie10836e2c07b6384d228c3f9e89f802823bd9f16 --- .../scala/org/apache/spark/deploy/yarn/YarnAllocator.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 506adb363aa90..a537243d641cb 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -736,7 +736,8 @@ private object YarnAllocator { def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { val matcher = pattern.matcher(diagnostics) val diag = if (matcher.find()) " " + matcher.group() + "." else "" - ("Container killed by YARN for exceeding memory limits." + diag - + " Consider boosting spark.yarn.executor.memoryOverhead.") + s"Container killed by YARN for exceeding memory limits. $diag " + + "Consider boosting spark.yarn.executor.memoryOverhead or " + + "disabling yarn.nodemanager.vmem-check-enabled because of YARN-4714." } } From 33c2cb22b3b246a413717042a5f741da04ded69d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 7 Mar 2018 13:10:51 +0100 Subject: [PATCH 442/774] [SPARK-23611][SQL] Add a helper function to check exception for expr evaluation ## What changes were proposed in this pull request? This pr added a helper function in `ExpressionEvalHelper` to check exceptions in all the path of expression evaluation. ## How was this patch tested? Modified the existing tests. Author: Takeshi Yamamuro Closes #20748 from maropu/SPARK-23611. --- .../expressions/ExpressionEvalHelper.scala | 83 ++++++++++++++----- .../expressions/MathExpressionsSuite.scala | 2 +- .../expressions/MiscExpressionsSuite.scala | 2 +- .../expressions/NullExpressionsSuite.scala | 2 +- .../expressions/ObjectExpressionsSuite.scala | 17 ++-- .../expressions/RegexpExpressionsSuite.scala | 8 +- .../expressions/StringExpressionsSuite.scala | 2 +- .../expressions/TimeWindowSuite.scala | 2 +- 8 files changed, 79 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 29f0cc0d991aa..58d0c07622eb9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.reflect.ClassTag + import org.scalacheck.Gen import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.exceptions.TestFailedException @@ -45,11 +47,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } - protected def checkEvaluation( - expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + private def prepareEvaluation(expression: Expression): Expression = { val serializer = new JavaSerializer(new SparkConf()).newInstance val resolver = ResolveTimeZone(new SQLConf) - val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + } + + protected def checkEvaluation( + expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -95,7 +101,31 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( + expression: => Expression, + inputRow: InternalRow, + expectedErrMsg: String): Unit = { + + def checkException(eval: => Unit, testMode: String): Unit = { + withClue(s"($testMode)") { + val errMsg = intercept[T] { + eval + }.getMessage + if (errMsg != expectedErrMsg) { + fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") + } + } + } + val expr = prepareEvaluation(expression) + checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") + checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") + if (GenerateUnsafeProjection.canSupport(expr.dataType)) { + checkException(evaluateWithUnsafeProjection(expr, inputRow), "unsafe mode") + } + } + + protected def evaluateWithoutCodegen( + expression: Expression, inputRow: InternalRow = EmptyRow): Any = { expression.foreach { case n: Nondeterministic => n.initialize(0) case _ => @@ -124,7 +154,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { + val actual = try evaluateWithoutCodegen(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } if (!checkResult(actual, expected, expression.dataType)) { @@ -139,33 +169,29 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val actual = evaluateWithGeneratedMutableProjection(expression, inputRow) + if (!checkResult(actual, expected, expression.dataType)) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + private def evaluateWithGeneratedMutableProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow): Any = { val plan = generateProject( GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) plan.initialize(0) - val actual = plan(inputRow).get(0, expression.dataType) - if (!checkResult(actual, expected, expression.dataType)) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") - } + plan(inputRow).get(0, expression.dataType) } protected def checkEvalutionWithUnsafeProjection( expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - // SPARK-16489 Explicitly doing code generation twice so code gen will fail if - // some expression is reusing variable names across different instances. - // This behavior is tested in ExpressionEvalHelperSuite. - val plan = generateProject( - UnsafeProjection.create( - Alias(expression, s"Optimized($expression)1")() :: - Alias(expression, s"Optimized($expression)2")() :: Nil), - expression) - - val unsafeRow = plan(inputRow) + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" if (expected == null) { @@ -185,6 +211,21 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } + private def evaluateWithUnsafeProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow): InternalRow = { + // SPARK-16489 Explicitly doing code generation twice so code gen will fail if + // some expression is reusing variable names across different instances. + // This behavior is tested in ExpressionEvalHelperSuite. + val plan = generateProject( + UnsafeProjection.create( + Alias(expression, s"Optimized($expression)1")() :: + Alias(expression, s"Optimized($expression)2")() :: Nil), + expression) + + plan(inputRow) + } + protected def checkEvaluationWithOptimization( expression: Expression, expected: Any, @@ -294,7 +335,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { private def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = { val interpret = try { - evaluate(expr, inputRow) + evaluateWithoutCodegen(expr, inputRow) } catch { case e: Exception => fail(s"Exception evaluating $expr", e) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 39e0060d41dd4..3a094079380fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -124,7 +124,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { private def checkNaNWithoutCodegen( expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { + val actual = try evaluateWithoutCodegen(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } if (!actual.asInstanceOf[Double].isNaN) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index facc863081303..a21c139fe71d0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -41,6 +41,6 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("uuid") { checkEvaluation(Length(Uuid()), 36) - assert(evaluate(Uuid()) !== evaluate(Uuid())) + assert(evaluateWithoutCodegen(Uuid()) !== evaluateWithoutCodegen(Uuid())) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index cc6c15cb2c909..424c3a4696077 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -51,7 +51,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("AssertNotNUll") { val ex = intercept[RuntimeException] { - evaluate(AssertNotNull(Literal(null), Seq.empty[String])) + evaluateWithoutCodegen(AssertNotNull(Literal(null), Seq.empty[String])) }.getMessage assert(ex.contains("Null value appeared in non-nullable field")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 50e57737a4612..cbfbb6573ae8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -100,14 +100,13 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } // If an input row or a field are null, a runtime exception will be thrown - val errMsg1 = intercept[RuntimeException] { - evaluate(getRowField, InternalRow.fromSeq(Seq(null))) - }.getMessage - assert(errMsg1 === "The input external row cannot be null.") - - val errMsg2 = intercept[RuntimeException] { - evaluate(getRowField, InternalRow.fromSeq(Seq(Row(null)))) - }.getMessage - assert(errMsg2 === "The 0th field 'c0' of input row cannot be null.") + checkExceptionInExpression[RuntimeException]( + getRowField, + InternalRow.fromSeq(Seq(null)), + "The input external row cannot be null.") + checkExceptionInExpression[RuntimeException]( + getRowField, + InternalRow.fromSeq(Seq(Row(null))), + "The 0th field 'c0' of input row cannot be null.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 2a0a42c65b086..d532dc4f77198 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -100,12 +100,12 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // invalid escaping val invalidEscape = intercept[AnalysisException] { - evaluate("""a""" like """\a""") + evaluateWithoutCodegen("""a""" like """\a""") } assert(invalidEscape.getMessage.contains("pattern")) val endEscape = intercept[AnalysisException] { - evaluate("""a""" like """a\""") + evaluateWithoutCodegen("""a""" like """a\""") } assert(endEscape.getMessage.contains("pattern")) @@ -147,11 +147,11 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkLiteralRow("abc" rlike _, "^bc", false) intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike "**") + evaluateWithoutCodegen("abbbbc" rlike "**") } intercept[java.util.regex.PatternSyntaxException] { val regex = 'a.string.at(0) - evaluate("abbbbc" rlike regex, create_row("**")) + evaluateWithoutCodegen("abbbbc" rlike regex, create_row("**")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 97ddbeba2c5ca..9a1a4da074ce3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -756,7 +756,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // exceptional cases intercept[java.util.regex.PatternSyntaxException] { - evaluate(ParseUrl(Seq(Literal("http://spark.apache.org/path?"), + evaluateWithoutCodegen(ParseUrl(Seq(Literal("http://spark.apache.org/path?"), Literal("QUERY"), Literal("???")))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala index d6c8fcf291842..351d4d0c2eac9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -27,7 +27,7 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva test("time window is unevaluable") { intercept[UnsupportedOperationException] { - evaluate(TimeWindow(Literal(10L), "1 second", "1 second", "0 second")) + evaluateWithoutCodegen(TimeWindow(Literal(10L), "1 second", "1 second", "0 second")) } } From aff7d81cb73133483fc2256ca10e21b4b8101647 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 7 Mar 2018 18:31:59 +0100 Subject: [PATCH 443/774] [SPARK-23591][SQL] Add interpreted execution to EncodeUsingSerializer ## What changes were proposed in this pull request? The PR adds interpreted execution to EncodeUsingSerializer. ## How was this patch tested? added UT Author: Marco Gaido Closes #20751 from mgaido91/SPARK-23591. --- .../expressions/objects/objects.scala | 114 ++++++++++-------- .../expressions/ObjectExpressionsSuite.scala | 16 ++- 2 files changed, 77 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 721d589709131..7bbc3c732e782 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -105,6 +105,61 @@ trait InvokeLike extends Expression with NonSQLExpression { } } +/** + * Common trait for [[DecodeUsingSerializer]] and [[EncodeUsingSerializer]] + */ +trait SerializerSupport { + /** + * If true, Kryo serialization is used, otherwise the Java one is used + */ + val kryo: Boolean + + /** + * The serializer instance to be used for serialization/deserialization in interpreted execution + */ + lazy val serializerInstance: SerializerInstance = SerializerSupport.newSerializer(kryo) + + /** + * Adds a immutable state to the generated class containing a reference to the serializer. + * @return a string containing the name of the variable referencing the serializer + */ + def addImmutableSerializerIfNeeded(ctx: CodegenContext): String = { + val (serializerInstance, serializerInstanceClass) = { + if (kryo) { + ("kryoSerializer", + classOf[KryoSerializerInstance].getName) + } else { + ("javaSerializer", + classOf[JavaSerializerInstance].getName) + } + } + val newSerializerMethod = s"${classOf[SerializerSupport].getName}$$.MODULE$$.newSerializer" + // Code to initialize the serializer + ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializerInstance, v => + s""" + |$v = ($serializerInstanceClass) $newSerializerMethod($kryo); + """.stripMargin) + serializerInstance + } +} + +object SerializerSupport { + /** + * It creates a new `SerializerInstance` which is either a `KryoSerializerInstance` (is + * `useKryo` is set to `true`) or a `JavaSerializerInstance`. + */ + def newSerializer(useKryo: Boolean): SerializerInstance = { + // try conf from env, otherwise create a new one + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + val s = if (useKryo) { + new KryoSerializer(conf) + } else { + new JavaSerializer(conf) + } + s.newInstance() + } +} + /** * Invokes a static function, returning the result. By default, any of the arguments being null * will result in returning null instead of calling the function. @@ -1154,36 +1209,14 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) * @param kryo if true, use Kryo. Otherwise, use Java. */ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) - extends UnaryExpression with NonSQLExpression { + extends UnaryExpression with NonSQLExpression with SerializerSupport { - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def nullSafeEval(input: Any): Any = { + serializerInstance.serialize(input).array() + } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // Code to initialize the serializer. - val (serializer, serializerClass, serializerInstanceClass) = { - if (kryo) { - ("kryoSerializer", - classOf[KryoSerializer].getName, - classOf[KryoSerializerInstance].getName) - } else { - ("javaSerializer", - classOf[JavaSerializer].getName, - classOf[JavaSerializerInstance].getName) - } - } - // try conf from env, otherwise create a new one - val env = s"${classOf[SparkEnv].getName}.get()" - val sparkConf = s"new ${classOf[SparkConf].getName}()" - ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v => - s""" - |if ($env == null) { - | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - |} else { - | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - |} - """.stripMargin) - + val serializer = addImmutableSerializerIfNeeded(ctx) // Code to serialize. val input = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) @@ -1207,33 +1240,10 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) * @param kryo if true, use Kryo. Otherwise, use Java. */ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) - extends UnaryExpression with NonSQLExpression { + extends UnaryExpression with NonSQLExpression with SerializerSupport { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // Code to initialize the serializer. - val (serializer, serializerClass, serializerInstanceClass) = { - if (kryo) { - ("kryoSerializer", - classOf[KryoSerializer].getName, - classOf[KryoSerializerInstance].getName) - } else { - ("javaSerializer", - classOf[JavaSerializer].getName, - classOf[JavaSerializerInstance].getName) - } - } - // try conf from env, otherwise create a new one - val env = s"${classOf[SparkEnv].getName}.get()" - val sparkConf = s"new ${classOf[SparkConf].getName}()" - ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v => - s""" - |if ($env == null) { - | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - |} else { - | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - |} - """.stripMargin) - + val serializer = addImmutableSerializerIfNeeded(ctx) // Code to deserialize. val input = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index cbfbb6573ae8e..346b13277c709 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -109,4 +110,17 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow.fromSeq(Seq(Row(null))), "The 0th field 'c0' of input row cannot be null.") } + + test("SPARK-23591: EncodeUsingSerializer should support interpreted execution") { + val cls = ObjectType(classOf[java.lang.Integer]) + val inputObject = BoundReference(0, cls, nullable = true) + val conf = new SparkConf() + Seq(true, false).foreach { useKryo => + val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) + val expected = serializer.newInstance().serialize(new Integer(1)).array() + val encodeUsingSerializer = EncodeUsingSerializer(inputObject, useKryo) + checkEvaluation(encodeUsingSerializer, expected, InternalRow.fromSeq(Seq(1))) + checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) + } + } } From 53561d27c45db31893bcabd4aca2387fde869b72 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 7 Mar 2018 09:37:42 -0800 Subject: [PATCH 444/774] [SPARK-23291][SQL][R] R's substr should not reduce starting position by 1 when calling Scala API ## What changes were proposed in this pull request? Seems R's substr API treats Scala substr API as zero based and so subtracts the given starting position by 1. Because Scala's substr API also accepts zero-based starting position (treated as the first element), so the current R's substr test results are correct as they all use 1 as starting positions. ## How was this patch tested? Modified tests. Author: Liang-Chi Hsieh Closes #20464 from viirya/SPARK-23291. --- R/pkg/R/column.R | 10 ++++++++-- R/pkg/tests/fulltests/test_sparkSQL.R | 1 + docs/sparkr.md | 4 ++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 9727efc354f10..7926a9a2467ee 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -161,12 +161,18 @@ setMethod("alias", #' @aliases substr,Column-method #' #' @param x a Column. -#' @param start starting position. +#' @param start starting position. It should be 1-base. #' @param stop ending position. +#' @examples +#' \dontrun{ +#' df <- createDataFrame(list(list(a="abcdef"))) +#' collect(select(df, substr(df$a, 1, 4))) # the result is `abcd`. +#' collect(select(df, substr(df$a, 2, 4))) # the result is `bcd`. +#' } #' @note substr since 1.4.0 setMethod("substr", signature(x = "Column"), function(x, start, stop) { - jc <- callJMethod(x@jc, "substr", as.integer(start - 1), as.integer(stop - start + 1)) + jc <- callJMethod(x@jc, "substr", as.integer(start), as.integer(stop - start + 1)) column(jc) }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index bd0a0dcd0674c..439191adb23ea 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1651,6 +1651,7 @@ test_that("string operators", { expect_false(first(select(df, startsWith(df$name, "m")))[[1]]) expect_true(first(select(df, endsWith(df$name, "el")))[[1]]) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + expect_equal(first(select(df, substr(df$name, 4, 6)))[[1]], "hae") if (as.numeric(R.version$major) >= 3 && as.numeric(R.version$minor) >= 3) { expect_true(startsWith("Hello World", "Hello")) expect_false(endsWith("Hello World", "a")) diff --git a/docs/sparkr.md b/docs/sparkr.md index 6685b585a393a..2909247e79e95 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -663,3 +663,7 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma - The `stringsAsFactors` parameter was previously ignored with `collect`, for example, in `collect(createDataFrame(iris), stringsAsFactors = TRUE))`. It has been corrected. - For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`. - A warning can be raised if versions of SparkR package and the Spark JVM do not match. + +## Upgrading to Spark 2.4.0 + + - The `start` parameter of `substr` method was wrongly subtracted by one, previously. In other words, the index specified by `start` parameter was considered as 0-base. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. It has been fixed so the `start` parameter of `substr` method is now 1-base, e.g., therefore to get the same result as `substr(df$a, 2, 5)`, it should be changed to `substr(df$a, 1, 4)`. From c99fc9ad9b600095baba003053dbf84304ca392b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 7 Mar 2018 13:42:06 -0800 Subject: [PATCH 445/774] [SPARK-23550][CORE] Cleanup `Utils`. A few different things going on: - Remove unused methods. - Move JSON methods to the only class that uses them. - Move test-only methods to TestUtils. - Make getMaxResultSize() a config constant. - Reuse functionality from existing libraries (JRE or JavaUtils) where possible. The change also includes changes to a few tests to call `Utils.createTempFile` correctly, so that temp dirs are created under the designated top-level temp dir instead of potentially polluting git index. Author: Marcelo Vanzin Closes #20706 from vanzin/SPARK-23550. --- .../scala/org/apache/spark/TestUtils.scala | 26 +++- .../spark/deploy/SparkSubmitArguments.scala | 4 +- .../org/apache/spark/executor/Executor.scala | 4 +- .../spark/internal/config/ConfigBuilder.scala | 3 +- .../spark/internal/config/package.scala | 5 + .../spark/scheduler/TaskSetManager.scala | 3 +- .../org/apache/spark/util/JsonProtocol.scala | 124 +++++++++-------- .../scala/org/apache/spark/util/Utils.scala | 131 +----------------- .../sort/UnsafeShuffleWriterSuite.java | 2 +- .../scala/org/apache/spark/DriverSuite.scala | 2 +- .../spark/deploy/SparkSubmitSuite.scala | 18 +-- .../spark/scheduler/ReplayListenerSuite.scala | 12 +- .../org/apache/spark/util/UtilsSuite.scala | 1 + scalastyle-config.xml | 2 +- .../datasources/orc/OrcSourceSuite.scala | 4 +- .../metric/SQLMetricsTestUtils.scala | 6 +- .../sql/sources/PartitionedWriteSuite.scala | 15 +- .../HiveThriftServer2Suites.scala | 2 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 20 +-- .../spark/streaming/CheckpointSuite.scala | 2 +- .../spark/streaming/MapWithStateSuite.scala | 2 +- 21 files changed, 152 insertions(+), 236 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 93e7ee3d2a404..b5c4c705dcbc7 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -22,7 +22,7 @@ import java.net.{HttpURLConnection, URI, URL} import java.nio.charset.StandardCharsets import java.security.SecureRandom import java.security.cert.X509Certificate -import java.util.Arrays +import java.util.{Arrays, Properties} import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl._ @@ -35,6 +35,7 @@ import scala.sys.process.{Process, ProcessLogger} import scala.util.Try import com.google.common.io.{ByteStreams, Files} +import org.apache.log4j.PropertyConfigurator import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -256,6 +257,29 @@ private[spark] object TestUtils { s"Can't find $numExecutors executors before $timeout milliseconds elapsed") } + /** + * config a log4j properties used for testsuite + */ + def configTestLog4j(level: String): Unit = { + val pro = new Properties() + pro.put("log4j.rootLogger", s"$level, console") + pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender") + pro.put("log4j.appender.console.target", "System.err") + pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout") + pro.put("log4j.appender.console.layout.ConversionPattern", + "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n") + PropertyConfigurator.configure(pro) + } + + /** + * Lists files recursively. + */ + def recursiveList(f: File): Array[File] = { + require(f.isDirectory) + val current = f.listFiles + current ++ current.filter(_.isDirectory).flatMap(recursiveList) + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 9db7a1fe3106d..e7796d4ddbe34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import java.io.{ByteArrayOutputStream, PrintStream} +import java.io.{ByteArrayOutputStream, File, PrintStream} import java.lang.reflect.InvocationTargetException import java.net.URI import java.nio.charset.StandardCharsets @@ -233,7 +233,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Set name from main class if not given name = Option(name).orElse(Option(mainClass)).orNull if (name == null && primaryResource != null) { - name = Utils.stripDirectory(primaryResource) + name = new File(primaryResource).getName() } // Action should be SUBMIT unless otherwise specified diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 2c3a8ef74800b..dcec3ec21b546 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -35,6 +35,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task, TaskDescription} @@ -141,8 +142,7 @@ private[spark] class Executor( conf.getSizeAsBytes("spark.task.maxDirectResultSize", 1L << 20), RpcUtils.maxMessageSizeBytes(conf)) - // Limit of bytes for total size of results (default is 1GB) - private val maxResultSize = Utils.getMaxResultSize(conf) + private val maxResultSize = conf.get(MAX_RESULT_SIZE) // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index b0cd7110a3b47..f27aca03773a9 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -23,6 +23,7 @@ import java.util.regex.PatternSyntaxException import scala.util.matching.Regex import org.apache.spark.network.util.{ByteUnit, JavaUtils} +import org.apache.spark.util.Utils private object ConfigHelpers { @@ -45,7 +46,7 @@ private object ConfigHelpers { } def stringToSeq[T](str: String, converter: String => T): Seq[T] = { - str.split(",").map(_.trim()).filter(_.nonEmpty).map(converter) + Utils.stringToSeq(str).map(converter) } def seqToString[T](v: Seq[T], stringConverter: T => String): String = { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index bbfcfbaa7363c..a313ad0554a3a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -520,4 +520,9 @@ package object config { .checkValue(v => v > 0, "The threshold should be positive.") .createWithDefault(10000000) + private[spark] val MAX_RESULT_SIZE = ConfigBuilder("spark.driver.maxResultSize") + .doc("Size limit for results.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1g") + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 886c2c99f1ff3..d958658527f6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -64,8 +64,7 @@ private[spark] class TaskSetManager( val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5) - // Limit of bytes for total size of results (default is 1GB) - val maxResultSize = Utils.getMaxResultSize(conf) + val maxResultSize = conf.get(config.MAX_RESULT_SIZE) val speculationEnabled = conf.getBoolean("spark.speculation", false) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index ff83301d631c4..40383fe05026b 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -48,7 +48,7 @@ import org.apache.spark.storage._ * To ensure that we provide these guarantees, follow these rules when modifying these methods: * * - Never delete any JSON fields. - * - Any new JSON fields should be optional; use `Utils.jsonOption` when reading these fields + * - Any new JSON fields should be optional; use `jsonOption` when reading these fields * in `*FromJson` methods. */ private[spark] object JsonProtocol { @@ -408,7 +408,7 @@ private[spark] object JsonProtocol { ("Loss Reason" -> reason.map(_.toString)) case taskKilled: TaskKilled => ("Kill Reason" -> taskKilled.reason) - case _ => Utils.emptyJson + case _ => emptyJson } ("Reason" -> reason) ~ json } @@ -422,7 +422,7 @@ private[spark] object JsonProtocol { def jobResultToJson(jobResult: JobResult): JValue = { val result = Utils.getFormattedClassName(jobResult) val json = jobResult match { - case JobSucceeded => Utils.emptyJson + case JobSucceeded => emptyJson case jobFailed: JobFailed => JObject("Exception" -> exceptionToJson(jobFailed.exception)) } @@ -573,7 +573,7 @@ private[spark] object JsonProtocol { def taskStartFromJson(json: JValue): SparkListenerTaskStart = { val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = - Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskInfo = taskInfoFromJson(json \ "Task Info") SparkListenerTaskStart(stageId, stageAttemptId, taskInfo) } @@ -586,7 +586,7 @@ private[spark] object JsonProtocol { def taskEndFromJson(json: JValue): SparkListenerTaskEnd = { val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = - Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskType = (json \ "Task Type").extract[String] val taskEndReason = taskEndReasonFromJson(json \ "Task End Reason") val taskInfo = taskInfoFromJson(json \ "Task Info") @@ -597,11 +597,11 @@ private[spark] object JsonProtocol { def jobStartFromJson(json: JValue): SparkListenerJobStart = { val jobId = (json \ "Job ID").extract[Int] val submissionTime = - Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L) + jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L) val stageIds = (json \ "Stage IDs").extract[List[JValue]].map(_.extract[Int]) val properties = propertiesFromJson(json \ "Properties") // The "Stage Infos" field was added in Spark 1.2.0 - val stageInfos = Utils.jsonOption(json \ "Stage Infos") + val stageInfos = jsonOption(json \ "Stage Infos") .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { stageIds.map { id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown") @@ -613,7 +613,7 @@ private[spark] object JsonProtocol { def jobEndFromJson(json: JValue): SparkListenerJobEnd = { val jobId = (json \ "Job ID").extract[Int] val completionTime = - Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L) + jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L) val jobResult = jobResultFromJson(json \ "Job Result") SparkListenerJobEnd(jobId, completionTime, jobResult) } @@ -630,15 +630,15 @@ private[spark] object JsonProtocol { def blockManagerAddedFromJson(json: JValue): SparkListenerBlockManagerAdded = { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") val maxMem = (json \ "Maximum Memory").extract[Long] - val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) - val maxOnHeapMem = Utils.jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long]) - val maxOffHeapMem = Utils.jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long]) + val time = jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) + val maxOnHeapMem = jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long]) + val maxOffHeapMem = jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long]) SparkListenerBlockManagerAdded(time, blockManagerId, maxMem, maxOnHeapMem, maxOffHeapMem) } def blockManagerRemovedFromJson(json: JValue): SparkListenerBlockManagerRemoved = { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") - val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) + val time = jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) SparkListenerBlockManagerRemoved(time, blockManagerId) } @@ -648,11 +648,11 @@ private[spark] object JsonProtocol { def applicationStartFromJson(json: JValue): SparkListenerApplicationStart = { val appName = (json \ "App Name").extract[String] - val appId = Utils.jsonOption(json \ "App ID").map(_.extract[String]) + val appId = jsonOption(json \ "App ID").map(_.extract[String]) val time = (json \ "Timestamp").extract[Long] val sparkUser = (json \ "User").extract[String] - val appAttemptId = Utils.jsonOption(json \ "App Attempt ID").map(_.extract[String]) - val driverLogs = Utils.jsonOption(json \ "Driver Logs").map(mapFromJson) + val appAttemptId = jsonOption(json \ "App Attempt ID").map(_.extract[String]) + val driverLogs = jsonOption(json \ "Driver Logs").map(mapFromJson) SparkListenerApplicationStart(appName, appId, time, sparkUser, appAttemptId, driverLogs) } @@ -703,19 +703,19 @@ private[spark] object JsonProtocol { def stageInfoFromJson(json: JValue): StageInfo = { val stageId = (json \ "Stage ID").extract[Int] - val attemptId = Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + val attemptId = jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val stageName = (json \ "Stage Name").extract[String] val numTasks = (json \ "Number of Tasks").extract[Int] val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson) - val parentIds = Utils.jsonOption(json \ "Parent IDs") + val parentIds = jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) - val details = Utils.jsonOption(json \ "Details").map(_.extract[String]).getOrElse("") - val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]) - val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]) - val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String]) + val details = jsonOption(json \ "Details").map(_.extract[String]).getOrElse("") + val submissionTime = jsonOption(json \ "Submission Time").map(_.extract[Long]) + val completionTime = jsonOption(json \ "Completion Time").map(_.extract[Long]) + val failureReason = jsonOption(json \ "Failure Reason").map(_.extract[String]) val accumulatedValues = { - Utils.jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match { + jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) case None => Seq.empty[AccumulableInfo] } @@ -735,17 +735,17 @@ private[spark] object JsonProtocol { def taskInfoFromJson(json: JValue): TaskInfo = { val taskId = (json \ "Task ID").extract[Long] val index = (json \ "Index").extract[Int] - val attempt = Utils.jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1) + val attempt = jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1) val launchTime = (json \ "Launch Time").extract[Long] val executorId = (json \ "Executor ID").extract[String].intern() val host = (json \ "Host").extract[String].intern() val taskLocality = TaskLocality.withName((json \ "Locality").extract[String]) - val speculative = Utils.jsonOption(json \ "Speculative").exists(_.extract[Boolean]) + val speculative = jsonOption(json \ "Speculative").exists(_.extract[Boolean]) val gettingResultTime = (json \ "Getting Result Time").extract[Long] val finishTime = (json \ "Finish Time").extract[Long] val failed = (json \ "Failed").extract[Boolean] - val killed = Utils.jsonOption(json \ "Killed").exists(_.extract[Boolean]) - val accumulables = Utils.jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match { + val killed = jsonOption(json \ "Killed").exists(_.extract[Boolean]) + val accumulables = jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) case None => Seq.empty[AccumulableInfo] } @@ -762,13 +762,13 @@ private[spark] object JsonProtocol { def accumulableInfoFromJson(json: JValue): AccumulableInfo = { val id = (json \ "ID").extract[Long] - val name = Utils.jsonOption(json \ "Name").map(_.extract[String]) - val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } - val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } - val internal = Utils.jsonOption(json \ "Internal").exists(_.extract[Boolean]) + val name = jsonOption(json \ "Name").map(_.extract[String]) + val update = jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } + val value = jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } + val internal = jsonOption(json \ "Internal").exists(_.extract[Boolean]) val countFailedValues = - Utils.jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean]) - val metadata = Utils.jsonOption(json \ "Metadata").map(_.extract[String]) + jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean]) + val metadata = jsonOption(json \ "Metadata").map(_.extract[String]) new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata) } @@ -821,49 +821,49 @@ private[spark] object JsonProtocol { metrics.incDiskBytesSpilled((json \ "Disk Bytes Spilled").extract[Long]) // Shuffle read metrics - Utils.jsonOption(json \ "Shuffle Read Metrics").foreach { readJson => + jsonOption(json \ "Shuffle Read Metrics").foreach { readJson => val readMetrics = metrics.createTempShuffleReadMetrics() readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int]) readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int]) readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long]) - Utils.jsonOption(readJson \ "Remote Bytes Read To Disk") + jsonOption(readJson \ "Remote Bytes Read To Disk") .foreach { v => readMetrics.incRemoteBytesReadToDisk(v.extract[Long])} readMetrics.incLocalBytesRead( - Utils.jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long]) readMetrics.incRecordsRead( - Utils.jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L)) metrics.mergeShuffleReadMetrics() } // Shuffle write metrics // TODO: Drop the redundant "Shuffle" since it's inconsistent with related classes. - Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => + jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => val writeMetrics = metrics.shuffleWriteMetrics writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long]) writeMetrics.incRecordsWritten( - Utils.jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L)) + jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L)) writeMetrics.incWriteTime((writeJson \ "Shuffle Write Time").extract[Long]) } // Output metrics - Utils.jsonOption(json \ "Output Metrics").foreach { outJson => + jsonOption(json \ "Output Metrics").foreach { outJson => val outputMetrics = metrics.outputMetrics outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long]) outputMetrics.setRecordsWritten( - Utils.jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L)) + jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L)) } // Input metrics - Utils.jsonOption(json \ "Input Metrics").foreach { inJson => + jsonOption(json \ "Input Metrics").foreach { inJson => val inputMetrics = metrics.inputMetrics inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) inputMetrics.incRecordsRead( - Utils.jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L)) } // Updated blocks - Utils.jsonOption(json \ "Updated Blocks").foreach { blocksJson => + jsonOption(json \ "Updated Blocks").foreach { blocksJson => metrics.setUpdatedBlockStatuses(blocksJson.extract[List[JValue]].map { blockJson => val id = BlockId((blockJson \ "Block ID").extract[String]) val status = blockStatusFromJson(blockJson \ "Status") @@ -897,7 +897,7 @@ private[spark] object JsonProtocol { val shuffleId = (json \ "Shuffle ID").extract[Int] val mapId = (json \ "Map ID").extract[Int] val reduceId = (json \ "Reduce ID").extract[Int] - val message = Utils.jsonOption(json \ "Message").map(_.extract[String]) + val message = jsonOption(json \ "Message").map(_.extract[String]) new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId, message.getOrElse("Unknown reason")) case `exceptionFailure` => @@ -905,9 +905,9 @@ private[spark] object JsonProtocol { val description = (json \ "Description").extract[String] val stackTrace = stackTraceFromJson(json \ "Stack Trace") val fullStackTrace = - Utils.jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull + jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull // Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x - val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates") + val accumUpdates = jsonOption(json \ "Accumulator Updates") .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulators().map(acc => { acc.toInfo(Some(acc.value), None) @@ -915,21 +915,21 @@ private[spark] object JsonProtocol { ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost case `taskKilled` => - val killReason = Utils.jsonOption(json \ "Kill Reason") + val killReason = jsonOption(json \ "Kill Reason") .map(_.extract[String]).getOrElse("unknown reason") TaskKilled(killReason) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. To provide backward compatibility // for reading those logs, we need to provide default values for all the fields. - val jobId = Utils.jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) - val partitionId = Utils.jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) - val attemptNo = Utils.jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) + val jobId = jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) + val partitionId = jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) + val attemptNo = jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) TaskCommitDenied(jobId, partitionId, attemptNo) case `executorLostFailure` => - val exitCausedByApp = Utils.jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean]) - val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String]) - val reason = Utils.jsonOption(json \ "Loss Reason").map(_.extract[String]) + val exitCausedByApp = jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean]) + val executorId = jsonOption(json \ "Executor ID").map(_.extract[String]) + val reason = jsonOption(json \ "Loss Reason").map(_.extract[String]) ExecutorLostFailure( executorId.getOrElse("Unknown"), exitCausedByApp.getOrElse(true), @@ -968,11 +968,11 @@ private[spark] object JsonProtocol { def rddInfoFromJson(json: JValue): RDDInfo = { val rddId = (json \ "RDD ID").extract[Int] val name = (json \ "Name").extract[String] - val scope = Utils.jsonOption(json \ "Scope") + val scope = jsonOption(json \ "Scope") .map(_.extract[String]) .map(RDDOperationScope.fromJson) - val callsite = Utils.jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") - val parentIds = Utils.jsonOption(json \ "Parent IDs") + val callsite = jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") + val parentIds = jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) val storageLevel = storageLevelFromJson(json \ "Storage Level") @@ -1029,7 +1029,7 @@ private[spark] object JsonProtocol { } def propertiesFromJson(json: JValue): Properties = { - Utils.jsonOption(json).map { value => + jsonOption(json).map { value => val properties = new Properties mapFromJson(json).foreach { case (k, v) => properties.setProperty(k, v) } properties @@ -1058,4 +1058,14 @@ private[spark] object JsonProtocol { e } + /** Return an option that translates JNothing to None */ + private def jsonOption(json: JValue): Option[JValue] = { + json match { + case JNothing => None + case value: JValue => Some(value) + } + } + + private def emptyJson: JObject = JObject(List[JField]()) + } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2e2a4a259e9af..29d26ea2c85df 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -25,7 +25,7 @@ import java.net._ import java.nio.ByteBuffer import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets -import java.nio.file.{Files, Paths} +import java.nio.file.Files import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean @@ -51,9 +51,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException -import org.json4s._ import org.slf4j.Logger import org.apache.spark._ @@ -1017,70 +1015,18 @@ private[spark] object Utils extends Logging { " " + (System.currentTimeMillis - startTimeMs) + " ms" } - private def listFilesSafely(file: File): Seq[File] = { - if (file.exists()) { - val files = file.listFiles() - if (files == null) { - throw new IOException("Failed to list files for dir: " + file) - } - files - } else { - List() - } - } - - /** - * Lists files recursively. - */ - def recursiveList(f: File): Array[File] = { - require(f.isDirectory) - val current = f.listFiles - current ++ current.filter(_.isDirectory).flatMap(recursiveList) - } - /** * Delete a file or directory and its contents recursively. * Don't follow directories if they are symlinks. * Throws an exception if deletion is unsuccessful. */ - def deleteRecursively(file: File) { + def deleteRecursively(file: File): Unit = { if (file != null) { - try { - if (file.isDirectory && !isSymlink(file)) { - var savedIOException: IOException = null - for (child <- listFilesSafely(file)) { - try { - deleteRecursively(child) - } catch { - // In case of multiple exceptions, only last one will be thrown - case ioe: IOException => savedIOException = ioe - } - } - if (savedIOException != null) { - throw savedIOException - } - ShutdownHookManager.removeShutdownDeleteDir(file) - } - } finally { - if (file.delete()) { - logTrace(s"${file.getAbsolutePath} has been deleted") - } else { - // Delete can also fail if the file simply did not exist - if (file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath) - } - } - } + JavaUtils.deleteRecursively(file) + ShutdownHookManager.removeShutdownDeleteDir(file) } } - /** - * Check to see if file is a symbolic link. - */ - def isSymlink(file: File): Boolean = { - return Files.isSymbolicLink(Paths.get(file.toURI)) - } - /** * Determines if a directory contains any files newer than cutoff seconds. * @@ -1828,7 +1774,7 @@ private[spark] object Utils extends Logging { * [[scala.collection.Iterator#size]] because it uses a for loop, which is slightly slower * in the current version of Scala. */ - def getIteratorSize[T](iterator: Iterator[T]): Long = { + def getIteratorSize(iterator: Iterator[_]): Long = { var count = 0L while (iterator.hasNext) { count += 1L @@ -1875,17 +1821,6 @@ private[spark] object Utils extends Logging { obj.getClass.getSimpleName.replace("$", "") } - /** Return an option that translates JNothing to None */ - def jsonOption(json: JValue): Option[JValue] = { - json match { - case JNothing => None - case value: JValue => Some(value) - } - } - - /** Return an empty JSON object */ - def emptyJson: JsonAST.JObject = JObject(List[JField]()) - /** * Return a Hadoop FileSystem with the scheme encoded in the given path. */ @@ -1900,15 +1835,6 @@ private[spark] object Utils extends Logging { getHadoopFileSystem(new URI(path), conf) } - /** - * Return the absolute path of a file in the given directory. - */ - def getFilePath(dir: File, fileName: String): Path = { - assert(dir.isDirectory) - val path = new File(dir, fileName).getAbsolutePath - new Path(path) - } - /** * Whether the underlying operating system is Windows. */ @@ -1931,13 +1857,6 @@ private[spark] object Utils extends Logging { sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing") } - /** - * Strip the directory from a path name - */ - def stripDirectory(path: String): String = { - new File(path).getName - } - /** * Terminates a process waiting for at most the specified duration. * @@ -2348,36 +2267,6 @@ private[spark] object Utils extends Logging { org.apache.log4j.Logger.getRootLogger().setLevel(l) } - /** - * config a log4j properties used for testsuite - */ - def configTestLog4j(level: String): Unit = { - val pro = new Properties() - pro.put("log4j.rootLogger", s"$level, console") - pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender") - pro.put("log4j.appender.console.target", "System.err") - pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout") - pro.put("log4j.appender.console.layout.ConversionPattern", - "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n") - PropertyConfigurator.configure(pro) - } - - def invoke( - clazz: Class[_], - obj: AnyRef, - methodName: String, - args: (Class[_], AnyRef)*): AnyRef = { - val (types, values) = args.unzip - val method = clazz.getDeclaredMethod(methodName, types: _*) - method.setAccessible(true) - method.invoke(obj, values.toSeq: _*) - } - - // Limit of bytes for total size of results (default is 1GB) - def getMaxResultSize(conf: SparkConf): Long = { - memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20 - } - /** * Return the current system LD_LIBRARY_PATH name */ @@ -2610,16 +2499,6 @@ private[spark] object Utils extends Logging { SignalUtils.registerLogger(log) } - /** - * Unions two comma-separated lists of files and filters out empty strings. - */ - def unionFileLists(leftList: Option[String], rightList: Option[String]): Set[String] = { - var allFiles = Set.empty[String] - leftList.foreach { value => allFiles ++= value.split(",") } - rightList.foreach { value => allFiles ++= value.split(",") } - allFiles.filter { _.nonEmpty } - } - /** * Return the jar files pointed by the "spark.jars" property. Spark internally will distribute * these jars through file server. In the YARN mode, it will return an empty list, since YARN diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 24a55df84a240..0d5c5ea7903e9 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -95,7 +95,7 @@ public void tearDown() { @SuppressWarnings("unchecked") public void setUp() throws IOException { MockitoAnnotations.initMocks(this); - tempDir = Utils.createTempDir("test", "test"); + tempDir = Utils.createTempDir(null, "test"); mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 962945e5b6bb1..896cd2e80aaef 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -51,7 +51,7 @@ class DriverSuite extends SparkFunSuite with TimeLimits { */ object DriverWithoutCleanup { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf) sc.parallelize(1 to 100, 4).count() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 803a38d77fb82..d265643a80b4e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.TestUtils import org.apache.spark.TestUtils.JavaSourceFromString import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmit._ @@ -761,18 +762,6 @@ class SparkSubmitSuite } } - test("comma separated list of files are unioned correctly") { - val left = Option("/tmp/a.jar,/tmp/b.jar") - val right = Option("/tmp/c.jar,/tmp/a.jar") - val emptyString = Option("") - Utils.unionFileLists(left, right) should be (Set("/tmp/a.jar", "/tmp/b.jar", "/tmp/c.jar")) - Utils.unionFileLists(emptyString, emptyString) should be (Set.empty) - Utils.unionFileLists(Option("/tmp/a.jar"), emptyString) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(emptyString, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(None, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(Option("/tmp/a.jar"), None) should be (Set("/tmp/a.jar")) - } - test("support glob path") { val tmpJarDir = Utils.createTempDir() val jar1 = TestUtils.createJarWithFiles(Map("test.resource" -> "1"), tmpJarDir) @@ -1042,6 +1031,7 @@ class SparkSubmitSuite assert(exception.getMessage() === "hello") } + } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { @@ -1076,7 +1066,7 @@ object SparkSubmitSuite extends SparkFunSuite with TimeLimits { object JarCreationTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => @@ -1100,7 +1090,7 @@ object JarCreationTest extends Logging { object SimpleApplicationTest { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val configs = Seq("spark.master", "spark.app.name") diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 73e7b3fe8c1de..e24d550a62665 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -47,7 +47,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } test("Simple replay") { - val logFilePath = Utils.getFilePath(testDir, "events.txt") + val logFilePath = getFilePath(testDir, "events.txt") val fstream = fileSystem.create(logFilePath) val writer = new PrintWriter(fstream) val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, @@ -97,7 +97,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp // scalastyle:on println } - val logFilePath = Utils.getFilePath(testDir, "events.lz4.inprogress") + val logFilePath = getFilePath(testDir, "events.lz4.inprogress") val bytes = buffered.toByteArray Utils.tryWithResource(fileSystem.create(logFilePath)) { fstream => fstream.write(bytes, 0, buffered.size) @@ -129,7 +129,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } test("Replay incompatible event log") { - val logFilePath = Utils.getFilePath(testDir, "incompatible.txt") + val logFilePath = getFilePath(testDir, "incompatible.txt") val fstream = fileSystem.create(logFilePath) val writer = new PrintWriter(fstream) val applicationStart = SparkListenerApplicationStart("Incompatible App", None, @@ -226,6 +226,12 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } } + private def getFilePath(dir: File, fileName: String): Path = { + assert(dir.isDirectory) + val path = new File(dir, fileName).getAbsolutePath + new Path(path) + } + /** * A simple listener that buffers all the events it receives. * diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index eaea6b030c154..3b4273184f1e9 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -648,6 +648,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("fetch hcfs dir") { val tempDir = Utils.createTempDir() val sourceDir = new File(tempDir, "source-dir") + sourceDir.mkdir() val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath) val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) val targetDir = new File(tempDir, "target-dir") diff --git a/scalastyle-config.xml b/scalastyle-config.xml index e2fa5754afaee..e65e3aafe5b5b 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -229,7 +229,7 @@ This file is divided into 3 sections: extractOpt - Use Utils.jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter + Use jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter is slower. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 523f7cf77e103..8a3bbd03a26dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -39,8 +39,8 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { protected override def beforeAll(): Unit = { super.beforeAll() - orcTableAsDir = Utils.createTempDir("orctests", "sparksql") - orcTableDir = Utils.createTempDir("orctests", "sparksql") + orcTableAsDir = Utils.createTempDir(namePrefix = "orctests") + orcTableDir = Utils.createTempDir(namePrefix = "orctests") sparkContext .makeRDD(1 to 10) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 122d28798136f..534d8bb629b8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -21,13 +21,13 @@ import java.io.File import scala.collection.mutable.HashMap +import org.apache.spark.TestUtils import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore} import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.util.Utils trait SQLMetricsTestUtils extends SQLTestUtils { @@ -91,7 +91,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { (0 until 100).map(i => (i, i + 1)).toDF("i", "j").repartition(2) .write.format(dataFormat).mode("overwrite").insertInto(tableName) } - assert(Utils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) + assert(TestUtils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) } } @@ -121,7 +121,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { .mode("overwrite") .insertInto(tableName) } - assert(Utils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) + assert(TestUtils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index 0fe33e87318a5..27c983f270bf6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -22,6 +22,7 @@ import java.sql.Timestamp import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.spark.TestUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils @@ -86,15 +87,15 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempDir { f => spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 1).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 2).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", -1).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) } } @@ -106,7 +107,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { .option("maxRecordsPerFile", 1) .mode("overwrite") .parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) } } @@ -133,14 +134,14 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val df = Seq((1, ts)).toDF("i", "ts") withTempPath { f => df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) checkPartitionValues(files.head, "2016-12-01 00:00:00") } withTempPath { f => df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // use timeZone option "GMT" to format partition value. checkPartitionValues(files.head, "2016-12-01 08:00:00") @@ -148,7 +149,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempPath { f => withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // if there isn't timeZone option, then use session local timezone. checkPartitionValues(files.head, "2016-12-01 08:00:00") diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 496f8c82a6c61..b32c547cefefe 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -804,7 +804,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl protected var metastorePath: File = _ protected def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - private val pidDir: File = Utils.createTempDir("thriftserver-pid") + private val pidDir: File = Utils.createTempDir(namePrefix = "thriftserver-pid") protected var logPath: File = _ protected var operationLogPath: File = _ private var logTailingProcess: Process = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 2d31781132edc..079fe45860544 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -330,7 +330,7 @@ class HiveSparkSubmitSuite object SetMetastoreURLTest extends Logging { def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkConf = new SparkConf(loadDefaults = true) val builder = SparkSession.builder() @@ -368,7 +368,7 @@ object SetMetastoreURLTest extends Logging { object SetWarehouseLocationTest extends Logging { def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkConf = new SparkConf(loadDefaults = true).set("spark.ui.enabled", "false") val providedExpectedWarehouseLocation = @@ -447,7 +447,7 @@ object SetWarehouseLocationTest extends Logging { // can load the jar defined with the function. object TemporaryHiveUDFTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -485,7 +485,7 @@ object TemporaryHiveUDFTest extends Logging { // can load the jar defined with the function. object PermanentHiveUDFTest1 extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -523,7 +523,7 @@ object PermanentHiveUDFTest1 extends Logging { // can load the jar defined with the function. object PermanentHiveUDFTest2 extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -558,7 +558,7 @@ object PermanentHiveUDFTest2 extends Logging { // We test if we can load user jars in both driver and executors when HiveContext is used. object SparkSubmitClassLoaderTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val hiveWarehouseLocation = Utils.createTempDir() conf.set("spark.ui.enabled", "false") @@ -628,7 +628,7 @@ object SparkSubmitClassLoaderTest extends Logging { // We test if we can correctly set spark sql configurations when HiveContext is used. object SparkSQLConfTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") // We override the SparkConf to add spark.sql.hive.metastore.version and // spark.sql.hive.metastore.jars to the beginning of the conf entry array. // So, if metadataHive get initialized after we set spark.sql.hive.metastore.version but @@ -669,7 +669,7 @@ object SPARK_9757 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val hiveWarehouseLocation = Utils.createTempDir() val sparkContext = new SparkContext( @@ -718,7 +718,7 @@ object SPARK_11009 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkContext = new SparkContext( new SparkConf() @@ -749,7 +749,7 @@ object SPARK_14244 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkContext = new SparkContext( new SparkConf() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index ee2fd45a7e851..19b621f11759d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -97,7 +97,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => val batchDurationMillis = batchDuration.milliseconds // Setup the stream computation - val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString + val checkpointDir = Utils.createTempDir(namePrefix = this.getClass.getSimpleName()).toString logDebug(s"Using checkpoint directory $checkpointDir") val ssc = createContextForCheckpointOperation(batchDuration) require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName, diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 3b662ec1833aa..06c0c2aa97ee1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -39,7 +39,7 @@ class MapWithStateSuite extends SparkFunSuite before { StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } - checkpointDir = Utils.createTempDir("checkpoint") + checkpointDir = Utils.createTempDir(namePrefix = "checkpoint") } after { From ac76eff6a88f6358a321b84cb5e60fb9d6403419 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 7 Mar 2018 13:51:44 -0800 Subject: [PATCH 446/774] [SPARK-23525][SQL] Support ALTER TABLE CHANGE COLUMN COMMENT for external hive table ## What changes were proposed in this pull request? The following query doesn't work as expected: ``` CREATE EXTERNAL TABLE ext_table(a STRING, b INT, c STRING) PARTITIONED BY (d STRING) LOCATION 'sql/core/spark-warehouse/ext_table'; ALTER TABLE ext_table CHANGE a a STRING COMMENT "new comment"; DESC ext_table; ``` The comment of column `a` is not updated, that's because `HiveExternalCatalog.doAlterTable` ignores table schema changes. To fix the issue, we should call `doAlterTableDataSchema` instead of `doAlterTable`. ## How was this patch tested? Updated `DDLSuite.testChangeColumn`. Author: Xingbo Jiang Closes #20696 from jiangxb1987/alterColumnComment. --- .../spark/sql/execution/command/ddl.scala | 12 ++++++------ .../sql-tests/inputs/change-column.sql | 1 + .../sql-tests/results/change-column.sql.out | 19 ++++++++++++++----- .../sql/execution/command/DDLSuite.scala | 1 + 4 files changed, 22 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 964cbca049b27..bf4d96fa18d0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -314,8 +314,8 @@ case class AlterTableChangeColumnCommand( val resolver = sparkSession.sessionState.conf.resolver DDLUtils.verifyAlterTableType(catalog, table, isView = false) - // Find the origin column from schema by column name. - val originColumn = findColumnByName(table.schema, columnName, resolver) + // Find the origin column from dataSchema by column name. + val originColumn = findColumnByName(table.dataSchema, columnName, resolver) // Throw an AnalysisException if the column name/dataType is changed. if (!columnEqual(originColumn, newColumn, resolver)) { throw new AnalysisException( @@ -324,7 +324,7 @@ case class AlterTableChangeColumnCommand( s"'${newColumn.name}' with type '${newColumn.dataType}'") } - val newSchema = table.schema.fields.map { field => + val newDataSchema = table.dataSchema.fields.map { field => if (field.name == originColumn.name) { // Create a new column from the origin column with the new comment. addComment(field, newColumn.getComment) @@ -332,8 +332,7 @@ case class AlterTableChangeColumnCommand( field } } - val newTable = table.copy(schema = StructType(newSchema)) - catalog.alterTable(newTable) + catalog.alterTableDataSchema(tableName, StructType(newDataSchema)) Seq.empty[Row] } @@ -345,7 +344,8 @@ case class AlterTableChangeColumnCommand( schema.fields.collectFirst { case field if resolver(field.name, name) => field }.getOrElse(throw new AnalysisException( - s"Invalid column reference '$name', table schema is '${schema}'")) + s"Can't find column `$name` given table data columns " + + s"${schema.fieldNames.mkString("[`", "`, `", "`]")}")) } // Add the comment to a column, if comment is empty, return the original column. diff --git a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql index ad0f885f63d3d..2909024e4c9f7 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql @@ -49,6 +49,7 @@ ALTER TABLE global_temp.global_temp_view CHANGE a a INT COMMENT 'this is column -- Change column in partition spec (not supported yet) CREATE TABLE partition_table(a INT, b STRING, c INT, d STRING) USING parquet PARTITIONED BY (c, d); ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT; +ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C'; -- DROP TEST TABLE DROP TABLE test_change; diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index ba8bc936f0c79..ff1ecbcc44c23 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 33 -- !query 0 @@ -154,7 +154,7 @@ ALTER TABLE test_change CHANGE invalid_col invalid_col INT struct<> -- !query 15 output org.apache.spark.sql.AnalysisException -Invalid column reference 'invalid_col', table schema is 'StructType(StructField(a,IntegerType,true), StructField(b,StringType,true), StructField(c,IntegerType,true))'; +Can't find column `invalid_col` given table data columns [`a`, `b`, `c`]; -- !query 16 @@ -291,16 +291,25 @@ ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT -- !query 30 -DROP TABLE test_change +ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C' -- !query 30 schema struct<> -- !query 30 output - +org.apache.spark.sql.AnalysisException +Can't find column `c` given table data columns [`a`, `b`]; -- !query 31 -DROP TABLE partition_table +DROP TABLE test_change -- !query 31 schema struct<> -- !query 31 output + + +-- !query 32 +DROP TABLE partition_table +-- !query 32 schema +struct<> +-- !query 32 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index db9023b7ec8b6..4041176262426 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1597,6 +1597,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { // Ensure that change column will preserve other metadata fields. sql("ALTER TABLE dbx.tab1 CHANGE COLUMN col1 col1 INT COMMENT 'this is col1'") assert(getMetadata("col1").getString("key") == "value") + assert(getMetadata("col1").getString("comment") == "this is col1") } test("drop build-in function") { From 77c91cc746f93e609c412f3a220495d9e931f696 Mon Sep 17 00:00:00 2001 From: jx158167 Date: Wed, 7 Mar 2018 20:08:32 -0800 Subject: [PATCH 447/774] [SPARK-23524] Big local shuffle blocks should not be checked for corruption. ## What changes were proposed in this pull request? In current code, all local blocks will be checked for corruption no matter it's big or not. The reasons are as below: Size in FetchResult for local block is set to be 0 (https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala#L327) SPARK-4105 meant to only check the small blocks(size Closes #20685 from jinxing64/SPARK-23524. --- .../storage/ShuffleBlockFetcherIterator.scala | 14 +++--- .../ShuffleBlockFetcherIteratorSuite.scala | 45 +++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 98b5a735a4529..dd9df74689a13 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -90,7 +90,7 @@ final class ShuffleBlockFetcherIterator( private[this] val startTime = System.currentTimeMillis /** Local blocks to fetch, excluding zero-sized blocks. */ - private[this] val localBlocks = new ArrayBuffer[BlockId]() + private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]() /** Remote blocks to fetch, excluding zero-sized blocks. */ private[this] val remoteBlocks = new HashSet[BlockId]() @@ -316,6 +316,7 @@ final class ShuffleBlockFetcherIterator( * track in-memory are the ManagedBuffer references themselves. */ private[this] def fetchLocalBlocks() { + logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") val iter = localBlocks.iterator while (iter.hasNext) { val blockId = iter.next() @@ -324,7 +325,8 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false)) + results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, + buf.size(), buf, false)) } catch { case e: Exception => // If we see an exception, stop immediately. @@ -397,7 +399,9 @@ final class ShuffleBlockFetcherIterator( } shuffleMetrics.incRemoteBlocksFetched(1) } - bytesInFlight -= size + if (!localBlocks.contains(blockId)) { + bytesInFlight -= size + } if (isNetworkReqDone) { reqsInFlight -= 1 logDebug("Number of requests in flight " + reqsInFlight) @@ -583,8 +587,8 @@ object ShuffleBlockFetcherIterator { * Result of a fetch from a remote block successfully. * @param blockId block id * @param address BlockManager that the block was fetched from. - * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. + * @param size estimated size of the block. Note that this is NOT the exact bytes. + * Size of remote block is used to calculate bytesInFlight. * @param buf `ManagedBuffer` for the content. * @param isNetworkReqDone Is this the last network request for this host in this fetch request. */ diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 5bfe9905ff17b..692ae3bf597e0 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -352,6 +352,51 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } + test("big blocks are not checked for corruption") { + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + doReturn(10000L).when(corruptBuffer).size() + + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + doReturn(corruptBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) + val localBlockLengths = Seq[Tuple2[BlockId, Long]]( + ShuffleBlockId(0, 0, 0) -> corruptBuffer.size() + ) + + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlockLengths = Seq[Tuple2[BlockId, Long]]( + ShuffleBlockId(0, 1, 0) -> corruptBuffer.size() + ) + + val transfer = createMockTransfer( + Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer, ShuffleBlockId(0, 1, 0) -> corruptBuffer)) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (localBmId, localBlockLengths), + (remoteBmId, remoteBlockLengths) + ) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => new LimitedInputStream(in, 10000), + 2048, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true) + // Blocks should be returned without exceptions. + assert(Set(iterator.next()._1, iterator.next()._1) === + Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) + } + test("retry corrupt blocks (disabled)") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) From fe22f32041572596a22e5f7441fa0bfbd9608648 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 8 Mar 2018 10:50:09 +0100 Subject: [PATCH 448/774] [SPARK-23620] Splitting thread dump lines by using the br tag ## What changes were proposed in this pull request? I propose to replace `'\n'` by the `
    ` tag in generated html of thread dump page. The `
    ` tag will split thread lines in more reliable way. For now it could look like on the screen shot if the html is proxied and `'\n'` is replaced by another whitespace. The changes allow to more easily read and copy stack traces. ## How was this patch tested? I tested it manually by checking the thread dump page and its source. Author: Maxim Gekk Closes #20762 from MaxGekk/br-thread-dump. --- .../org/apache/spark/status/api/v1/api.scala | 24 ++++++++++++++++++- .../ui/exec/ExecutorThreadDumpPage.scala | 2 +- .../scala/org/apache/spark/util/Utils.scala | 6 ++--- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 369e98b683b1a..971d7e90fa7b8 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -19,6 +19,8 @@ package org.apache.spark.status.api.v1 import java.lang.{Long => JLong} import java.util.Date +import scala.xml.{NodeSeq, Text} + import com.fasterxml.jackson.annotation.JsonIgnoreProperties import com.fasterxml.jackson.databind.annotation.JsonDeserialize @@ -317,11 +319,31 @@ class RuntimeInfo private[spark]( val javaHome: String, val scalaVersion: String) +case class StackTrace(elems: Seq[String]) { + override def toString: String = elems.mkString + + def html: NodeSeq = { + val withNewLine = elems.foldLeft(NodeSeq.Empty) { (acc, elem) => + if (acc.isEmpty) { + acc :+ Text(elem) + } else { + acc :+
    :+ Text(elem) + } + } + + withNewLine + } + + def mkString(start: String, sep: String, end: String): String = { + elems.mkString(start, sep, end) + } +} + case class ThreadStackTrace( val threadId: Long, val threadName: String, val threadState: Thread.State, - val stackTrace: String, + val stackTrace: StackTrace, val blockedByThreadId: Option[Long], val blockedByLock: String, val holdingLocks: Seq[String]) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index 7a9aaf29a8b05..9bb026c60565e 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -60,7 +60,7 @@ private[ui] class ExecutorThreadDumpPage( {thread.threadName} {thread.threadState} {blockedBy}{heldLocks} - {thread.stackTrace} + {thread.stackTrace.html} } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 29d26ea2c85df..5caedeb526469 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -61,7 +61,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.status.api.v1.ThreadStackTrace +import org.apache.spark.status.api.v1.{StackTrace, ThreadStackTrace} /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -2118,14 +2118,14 @@ private[spark] object Utils extends Logging { private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = { val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap - val stackTrace = threadInfo.getStackTrace.map { frame => + val stackTrace = StackTrace(threadInfo.getStackTrace.map { frame => monitors.get(frame) match { case Some(monitor) => monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" case None => frame.toString } - }.mkString("\n") + }) // use a set to dedup re-entrant locks that are held at multiple places val heldLocks = From 9bb239c8b174d31981dfff63baa38bb8cecfe913 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 8 Mar 2018 20:19:55 +0900 Subject: [PATCH 449/774] [SPARK-23159][PYTHON] Update cloudpickle to v0.4.3 ## What changes were proposed in this pull request? The version of cloudpickle in PySpark was close to version 0.4.0 with some additional backported fixes and some minor additions for Spark related things. This update removes Spark related changes and matches cloudpickle [v0.4.3](https://github.com/cloudpipe/cloudpickle/releases/tag/v0.4.3): Changes by updating to 0.4.3 include: * Fix pickling of named tuples https://github.com/cloudpipe/cloudpickle/pull/113 * Built in type constructors for PyPy compatibility [here](https://github.com/cloudpipe/cloudpickle/commit/d84980ccaafc7982a50d4e04064011f401f17d1b) * Fix memoryview support https://github.com/cloudpipe/cloudpickle/pull/122 * Improved compatibility with other cloudpickle versions https://github.com/cloudpipe/cloudpickle/pull/128 * Several cleanups https://github.com/cloudpipe/cloudpickle/pull/121 and [here](https://github.com/cloudpipe/cloudpickle/commit/c91aaf110441991307f5097f950764079d0f9652) * [MRG] Regression on pickling classes from the __main__ module https://github.com/cloudpipe/cloudpickle/pull/149 * BUG: Handle instance methods of builtin types https://github.com/cloudpipe/cloudpickle/pull/154 * Fix #129 : do not silence RuntimeError in dump() https://github.com/cloudpipe/cloudpickle/pull/153 ## How was this patch tested? Existing pyspark.tests using python 2.7.14, 3.5.2, 3.6.3 Author: Bryan Cutler Closes #20373 from BryanCutler/pyspark-update-cloudpickle-42-SPARK-23159. --- python/pyspark/accumulators.py | 1 - python/pyspark/cloudpickle.py | 320 ++++++++++++++------------------- python/pyspark/serializers.py | 14 +- 3 files changed, 151 insertions(+), 184 deletions(-) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 6ef8cf53cc747..7def676b89a24 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -94,7 +94,6 @@ else: import socketserver as SocketServer import threading -from pyspark.cloudpickle import CloudPickler from pyspark.serializers import read_int, PickleSerializer diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 40e91a2d0655d..ea845b98b3db2 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -57,7 +57,6 @@ import types import weakref -from pyspark.util import _exception_message if sys.version < '3': from pickle import Pickler @@ -181,6 +180,32 @@ def _builtin_type(name): return getattr(types, name) +def _make__new__factory(type_): + def _factory(): + return type_.__new__ + return _factory + + +# NOTE: These need to be module globals so that they're pickleable as globals. +_get_dict_new = _make__new__factory(dict) +_get_frozenset_new = _make__new__factory(frozenset) +_get_list_new = _make__new__factory(list) +_get_set_new = _make__new__factory(set) +_get_tuple_new = _make__new__factory(tuple) +_get_object_new = _make__new__factory(object) + +# Pre-defined set of builtin_function_or_method instances that can be +# serialized. +_BUILTIN_TYPE_CONSTRUCTORS = { + dict.__new__: _get_dict_new, + frozenset.__new__: _get_frozenset_new, + set.__new__: _get_set_new, + list.__new__: _get_list_new, + tuple.__new__: _get_tuple_new, + object.__new__: _get_object_new, +} + + if sys.version_info < (3, 4): def _walk_global_ops(code): """ @@ -237,28 +262,16 @@ def dump(self, obj): if 'recursion' in e.args[0]: msg = """Could not pickle object as excessively deep recursion required.""" raise pickle.PicklingError(msg) - except pickle.PickleError: - raise - except Exception as e: - emsg = _exception_message(e) - if "'i' format requires" in emsg: - msg = "Object too large to serialize: %s" % emsg else: - msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg) - print_exec(sys.stderr) - raise pickle.PicklingError(msg) - + raise def save_memoryview(self, obj): - """Fallback to save_string""" - Pickler.save_string(self, str(obj)) + self.save(obj.tobytes()) + dispatch[memoryview] = save_memoryview - def save_buffer(self, obj): - """Fallback to save_string""" - Pickler.save_string(self,str(obj)) - if PY3: - dispatch[memoryview] = save_memoryview - else: + if not PY3: + def save_buffer(self, obj): + self.save(str(obj)) dispatch[buffer] = save_buffer def save_unsupported(self, obj): @@ -318,6 +331,24 @@ def save_function(self, obj, name=None): Determines what kind of function obj is (e.g. lambda, defined at interactive prompt, etc) and handles the pickling appropriately. """ + try: + should_special_case = obj in _BUILTIN_TYPE_CONSTRUCTORS + except TypeError: + # Methods of builtin types aren't hashable in python 2. + should_special_case = False + + if should_special_case: + # We keep a special-cased cache of built-in type constructors at + # global scope, because these functions are structured very + # differently in different python versions and implementations (for + # example, they're instances of types.BuiltinFunctionType in + # CPython, but they're ordinary types.FunctionType instances in + # PyPy). + # + # If the function we've received is in that cache, we just + # serialize it as a lookup into the cache. + return self.save_reduce(_BUILTIN_TYPE_CONSTRUCTORS[obj], (), obj=obj) + write = self.write if name is None: @@ -344,7 +375,7 @@ def save_function(self, obj, name=None): return self.save_global(obj, name) # a builtin_function_or_method which comes in as an attribute of some - # object (e.g., object.__new__, itertools.chain.from_iterable) will end + # object (e.g., itertools.chain.from_iterable) will end # up with modname "__main__" and so end up here. But these functions # have no __code__ attribute in CPython, so the handling for # user-defined functions below will fail. @@ -352,16 +383,13 @@ def save_function(self, obj, name=None): # for different python versions. if not hasattr(obj, '__code__'): if PY3: - if sys.version_info < (3, 4): - raise pickle.PicklingError("Can't pickle %r" % obj) - else: - rv = obj.__reduce_ex__(self.proto) + rv = obj.__reduce_ex__(self.proto) else: if hasattr(obj, '__self__'): rv = (getattr, (obj.__self__, name)) else: raise pickle.PicklingError("Can't pickle %r" % obj) - return Pickler.save_reduce(self, obj=obj, *rv) + return self.save_reduce(obj=obj, *rv) # if func is lambda, def'ed at prompt, is in main, or is nested, then # we'll pickle the actual function object rather than simply saving a @@ -420,20 +448,18 @@ def save_dynamic_class(self, obj): from global modules. """ clsdict = dict(obj.__dict__) # copy dict proxy to a dict - if not isinstance(clsdict.get('__dict__', None), property): - # don't extract dict that are properties - clsdict.pop('__dict__', None) - clsdict.pop('__weakref__', None) - - # hack as __new__ is stored differently in the __dict__ - new_override = clsdict.get('__new__', None) - if new_override: - clsdict['__new__'] = obj.__new__ - - # namedtuple is a special case for Spark where we use the _load_namedtuple function - if getattr(obj, '_is_namedtuple_', False): - self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields)) - return + clsdict.pop('__weakref__', None) + + # On PyPy, __doc__ is a readonly attribute, so we need to include it in + # the initial skeleton class. This is safe because we know that the + # doc can't participate in a cycle with the original class. + type_kwargs = {'__doc__': clsdict.pop('__doc__', None)} + + # If type overrides __dict__ as a property, include it in the type kwargs. + # In Python 2, we can't set this attribute after construction. + __dict__ = clsdict.pop('__dict__', None) + if isinstance(__dict__, property): + type_kwargs['__dict__'] = __dict__ save = self.save write = self.write @@ -453,23 +479,12 @@ def save_dynamic_class(self, obj): # Push the rehydration function. save(_rehydrate_skeleton_class) - # Mark the start of the args for the rehydration function. + # Mark the start of the args tuple for the rehydration function. write(pickle.MARK) - # On PyPy, __doc__ is a readonly attribute, so we need to include it in - # the initial skeleton class. This is safe because we know that the - # doc can't participate in a cycle with the original class. - doc_dict = {'__doc__': clsdict.pop('__doc__', None)} - - # Create and memoize an empty class with obj's name and bases. - save(type(obj)) - save(( - obj.__name__, - obj.__bases__, - doc_dict, - )) - write(pickle.REDUCE) - self.memoize(obj) + # Create and memoize an skeleton class with obj's name and bases. + tp = type(obj) + self.save_reduce(tp, (obj.__name__, obj.__bases__, type_kwargs), obj=obj) # Now save the rest of obj's __dict__. Any references to obj # encountered while saving will point to the skeleton class. @@ -522,17 +537,22 @@ def save_function_tuple(self, func): self.memoize(func) # save the rest of the func data needed by _fill_function - save(f_globals) - save(defaults) - save(dct) - save(func.__module__) - save(closure_values) + state = { + 'globals': f_globals, + 'defaults': defaults, + 'dict': dct, + 'module': func.__module__, + 'closure_values': closure_values, + } + if hasattr(func, '__qualname__'): + state['qualname'] = func.__qualname__ + save(state) write(pickle.TUPLE) write(pickle.REDUCE) # applies _fill_function on the tuple _extract_code_globals_cache = ( weakref.WeakKeyDictionary() - if sys.version_info >= (2, 7) and not hasattr(sys, "pypy_version_info") + if not hasattr(sys, "pypy_version_info") else {}) @classmethod @@ -608,37 +628,22 @@ def save_global(self, obj, name=None, pack=struct.pack): The name of this method is somewhat misleading: all types get dispatched here. """ - if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": - if obj in _BUILTIN_TYPE_NAMES: - return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) - - if name is None: - name = obj.__name__ - - modname = getattr(obj, "__module__", None) - if modname is None: - try: - # whichmodule() could fail, see - # https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling - modname = pickle.whichmodule(obj, name) - except Exception: - modname = '__main__' + if obj.__module__ == "__main__": + return self.save_dynamic_class(obj) - if modname == '__main__': - themodule = None - else: - __import__(modname) - themodule = sys.modules[modname] - self.modules.add(themodule) + try: + return Pickler.save_global(self, obj, name=name) + except Exception: + if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": + if obj in _BUILTIN_TYPE_NAMES: + return self.save_reduce( + _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) - if hasattr(themodule, name) and getattr(themodule, name) is obj: - return Pickler.save_global(self, obj, name) + typ = type(obj) + if typ is not obj and isinstance(obj, (type, types.ClassType)): + return self.save_dynamic_class(obj) - typ = type(obj) - if typ is not obj and isinstance(obj, (type, types.ClassType)): - self.save_dynamic_class(obj) - else: - raise pickle.PicklingError("Can't pickle %r" % obj) + raise dispatch[type] = save_global dispatch[types.ClassType] = save_global @@ -709,12 +714,7 @@ def save_property(self, obj): dispatch[property] = save_property def save_classmethod(self, obj): - try: - orig_func = obj.__func__ - except AttributeError: # Python 2.6 - orig_func = obj.__get__(None, object) - if isinstance(obj, classmethod): - orig_func = orig_func.__func__ # Unbind + orig_func = obj.__func__ self.save_reduce(type(obj), (orig_func,), obj=obj) dispatch[classmethod] = save_classmethod dispatch[staticmethod] = save_classmethod @@ -754,64 +754,6 @@ def __getattribute__(self, item): if type(operator.attrgetter) is type: dispatch[operator.attrgetter] = save_attrgetter - def save_reduce(self, func, args, state=None, - listitems=None, dictitems=None, obj=None): - # Assert that args is a tuple or None - if not isinstance(args, tuple): - raise pickle.PicklingError("args from reduce() should be a tuple") - - # Assert that func is callable - if not hasattr(func, '__call__'): - raise pickle.PicklingError("func from reduce should be callable") - - save = self.save - write = self.write - - # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ - if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__": - cls = args[0] - if not hasattr(cls, "__new__"): - raise pickle.PicklingError( - "args[0] from __newobj__ args has no __new__") - if obj is not None and cls is not obj.__class__: - raise pickle.PicklingError( - "args[0] from __newobj__ args has the wrong class") - args = args[1:] - save(cls) - - save(args) - write(pickle.NEWOBJ) - else: - save(func) - save(args) - write(pickle.REDUCE) - - if obj is not None: - self.memoize(obj) - - # More new special cases (that work with older protocols as - # well): when __reduce__ returns a tuple with 4 or 5 items, - # the 4th and 5th item should be iterators that provide list - # items and dict items (as (key, value) tuples), or None. - - if listitems is not None: - self._batch_appends(listitems) - - if dictitems is not None: - self._batch_setitems(dictitems) - - if state is not None: - save(state) - write(pickle.BUILD) - - def save_partial(self, obj): - """Partial objects do not serialize correctly in python2.x -- this fixes the bugs""" - self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords)) - - if sys.version_info < (2,7): # 2.7 supports partial pickling - dispatch[partial] = save_partial - - def save_file(self, obj): """Save a file""" try: @@ -867,23 +809,21 @@ def save_not_implemented(self, obj): dispatch[type(Ellipsis)] = save_ellipsis dispatch[type(NotImplemented)] = save_not_implemented - # WeakSet was added in 2.7. - if hasattr(weakref, 'WeakSet'): - def save_weakset(self, obj): - self.save_reduce(weakref.WeakSet, (list(obj),)) - - dispatch[weakref.WeakSet] = save_weakset + def save_weakset(self, obj): + self.save_reduce(weakref.WeakSet, (list(obj),)) - """Special functions for Add-on libraries""" - def inject_addons(self): - """Plug in system. Register additional pickling functions if modules already loaded""" - pass + dispatch[weakref.WeakSet] = save_weakset def save_logger(self, obj): self.save_reduce(logging.getLogger, (obj.name,), obj=obj) dispatch[logging.Logger] = save_logger + """Special functions for Add-on libraries""" + def inject_addons(self): + """Plug in system. Register additional pickling functions if modules already loaded""" + pass + # Tornado support @@ -913,11 +853,12 @@ def dump(obj, file, protocol=2): def dumps(obj, protocol=2): file = StringIO() - - cp = CloudPickler(file,protocol) - cp.dump(obj) - - return file.getvalue() + try: + cp = CloudPickler(file,protocol) + cp.dump(obj) + return file.getvalue() + finally: + file.close() # including pickles unloading functions in this namespace load = pickle.load @@ -1019,18 +960,40 @@ def __reduce__(cls): return cls.__name__ -def _fill_function(func, globals, defaults, dict, module, closure_values): - """ Fills in the rest of function data into the skeleton function object - that were created via _make_skel_func(). +def _fill_function(*args): + """Fills in the rest of function data into the skeleton function object + + The skeleton itself is create by _make_skel_func(). """ - func.__globals__.update(globals) - func.__defaults__ = defaults - func.__dict__ = dict - func.__module__ = module + if len(args) == 2: + func = args[0] + state = args[1] + elif len(args) == 5: + # Backwards compat for cloudpickle v0.4.0, after which the `module` + # argument was introduced + func = args[0] + keys = ['globals', 'defaults', 'dict', 'closure_values'] + state = dict(zip(keys, args[1:])) + elif len(args) == 6: + # Backwards compat for cloudpickle v0.4.1, after which the function + # state was passed as a dict to the _fill_function it-self. + func = args[0] + keys = ['globals', 'defaults', 'dict', 'module', 'closure_values'] + state = dict(zip(keys, args[1:])) + else: + raise ValueError('Unexpected _fill_value arguments: %r' % (args,)) + + func.__globals__.update(state['globals']) + func.__defaults__ = state['defaults'] + func.__dict__ = state['dict'] + if 'module' in state: + func.__module__ = state['module'] + if 'qualname' in state: + func.__qualname__ = state['qualname'] cells = func.__closure__ if cells is not None: - for cell, value in zip(cells, closure_values): + for cell, value in zip(cells, state['closure_values']): if value is not _empty_cell_value: cell_set(cell, value) @@ -1087,13 +1050,6 @@ def _find_module(mod_name): file.close() return path, description -def _load_namedtuple(name, fields): - """ - Loads a class generated by namedtuple - """ - from collections import namedtuple - return namedtuple(name, fields) - """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 91a7f093cec19..917e258d8a602 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -68,6 +68,7 @@ xrange = range from pyspark import cloudpickle +from pyspark.util import _exception_message __all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] @@ -565,7 +566,18 @@ def loads(self, obj, encoding=None): class CloudPickleSerializer(PickleSerializer): def dumps(self, obj): - return cloudpickle.dumps(obj, 2) + try: + return cloudpickle.dumps(obj, 2) + except pickle.PickleError: + raise + except Exception as e: + emsg = _exception_message(e) + if "'i' format requires" in emsg: + msg = "Object too large to serialize: %s" % emsg + else: + msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg) + cloudpickle.print_exec(sys.stderr) + raise pickle.PicklingError(msg) class MarshalSerializer(FramedSerializer): From d6632d185e147fcbe6724545488ad80dce20277e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 8 Mar 2018 20:22:07 +0900 Subject: [PATCH 450/774] [SPARK-23380][PYTHON] Adds a conf for Arrow fallback in toPandas/createDataFrame with Pandas DataFrame ## What changes were proposed in this pull request? This PR adds a configuration to control the fallback of Arrow optimization for `toPandas` and `createDataFrame` with Pandas DataFrame. ## How was this patch tested? Manually tested and unit tests added. You can test this by: **`createDataFrame`** ```python spark.conf.set("spark.sql.execution.arrow.enabled", False) pdf = spark.createDataFrame([[{'a': 1}]]).toPandas() spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", True) spark.createDataFrame(pdf, "a: map") ``` ```python spark.conf.set("spark.sql.execution.arrow.enabled", False) pdf = spark.createDataFrame([[{'a': 1}]]).toPandas() spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", False) spark.createDataFrame(pdf, "a: map") ``` **`toPandas`** ```python spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", True) spark.createDataFrame([[{'a': 1}]]).toPandas() ``` ```python spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", False) spark.createDataFrame([[{'a': 1}]]).toPandas() ``` Author: hyukjinkwon Closes #20678 from HyukjinKwon/SPARK-23380-conf. --- docs/sql-programming-guide.md | 5 + python/pyspark/sql/dataframe.py | 120 ++++++++++++------ python/pyspark/sql/session.py | 22 +++- python/pyspark/sql/tests.py | 84 ++++++++++-- .../apache/spark/sql/internal/SQLConf.scala | 13 +- 5 files changed, 186 insertions(+), 58 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 01e2076555ee6..451b814ab6c53 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1689,6 +1689,10 @@ using the call `toPandas()` and when creating a Spark DataFrame from a Pandas Da `createDataFrame(pandas_df)`. To use Arrow when executing these calls, users need to first set the Spark configuration 'spark.sql.execution.arrow.enabled' to 'true'. This is disabled by default. +In addition, optimizations enabled by 'spark.sql.execution.arrow.enabled' could fallback automatically +to non-Arrow optimization implementation if an error occurs before the actual computation within Spark. +This can be controlled by 'spark.sql.execution.arrow.fallback.enabled'. +
    {% include_example dataframe_with_arrow python/sql/arrow.py %} @@ -1800,6 +1804,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. + - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 9d8e85cde914f..8f90a367e8bf8 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1992,55 +1992,91 @@ def toPandas(self): timezone = None if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": + use_arrow = True try: - from pyspark.sql.types import _check_dataframe_convert_date, \ - _check_dataframe_localize_timestamps, to_arrow_schema + from pyspark.sql.types import to_arrow_schema from pyspark.sql.utils import require_minimum_pyarrow_version + require_minimum_pyarrow_version() - import pyarrow to_arrow_schema(self.schema) - tables = self._collectAsArrow() - if tables: - table = pyarrow.concat_tables(tables) - pdf = table.to_pandas() - pdf = _check_dataframe_convert_date(pdf, self.schema) - return _check_dataframe_localize_timestamps(pdf, timezone) - else: - return pd.DataFrame.from_records([], columns=self.columns) except Exception as e: - msg = ( - "Note: toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true. Please set it to false " - "to disable this.") - raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) - else: - pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - dtype = {} + if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \ + .lower() == "true": + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "Attempts non-optimization as " + "'spark.sql.execution.arrow.fallback.enabled' is set to " + "true." % _exception_message(e)) + warnings.warn(msg) + use_arrow = False + else: + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "For fallback to non-optimization automatically, please set true to " + "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) + raise RuntimeError(msg) + + # Try to use Arrow optimization when the schema is supported and the required version + # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. + if use_arrow: + try: + from pyspark.sql.types import _check_dataframe_convert_date, \ + _check_dataframe_localize_timestamps + import pyarrow + + tables = self._collectAsArrow() + if tables: + table = pyarrow.concat_tables(tables) + pdf = table.to_pandas() + pdf = _check_dataframe_convert_date(pdf, self.schema) + return _check_dataframe_localize_timestamps(pdf, timezone) + else: + return pd.DataFrame.from_records([], columns=self.columns) + except Exception as e: + # We might have to allow fallback here as well but multiple Spark jobs can + # be executed. So, simply fail in this case for now. + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed unexpectedly:\n %s\n" + "Note that 'spark.sql.execution.arrow.fallback.enabled' does " + "not have an effect in such failure in the middle of " + "computation." % _exception_message(e)) + raise RuntimeError(msg) + + # Below is toPandas without Arrow optimization. + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + # SPARK-21766: if an integer field is nullable and has null values, it can be + # inferred by pandas as float column. Once we convert the column with NaN back + # to integer type e.g., np.int16, we will hit exception. So we use the inferred + # float type, not the corrected type from the schema in this case. + if pandas_type is not None and \ + not(isinstance(field.dataType, IntegralType) and field.nullable and + pdf[field.name].isnull().any()): + dtype[field.name] = pandas_type + + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + + if timezone is None: + return pdf + else: + from pyspark.sql.types import _check_series_convert_timestamps_local_tz for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - # SPARK-21766: if an integer field is nullable and has null values, it can be - # inferred by pandas as float column. Once we convert the column with NaN back - # to integer type e.g., np.int16, we will hit exception. So we use the inferred - # float type, not the corrected type from the schema in this case. - if pandas_type is not None and \ - not(isinstance(field.dataType, IntegralType) and field.nullable and - pdf[field.name].isnull().any()): - dtype[field.name] = pandas_type - - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - - if timezone is None: - return pdf - else: - from pyspark.sql.types import _check_series_convert_timestamps_local_tz - for field in self.schema: - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if isinstance(field.dataType, TimestampType): - pdf[field.name] = \ - _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) - return pdf + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if isinstance(field.dataType, TimestampType): + pdf[field.name] = \ + _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) + return pdf def _collectAsArrow(self): """ diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index b3af9b82953f3..215bb3e5c5173 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -666,8 +666,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr try: return self._create_from_pandas_with_arrow(data, schema, timezone) except Exception as e: - warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e)) - # Fallback to create DataFrame without arrow if raise some exception + from pyspark.util import _exception_message + + if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \ + .lower() == "true": + msg = ( + "createDataFrame attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "Attempts non-optimization as " + "'spark.sql.execution.arrow.fallback.enabled' is set to " + "true." % _exception_message(e)) + warnings.warn(msg) + else: + msg = ( + "createDataFrame attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "For fallback to non-optimization automatically, please set true to " + "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) + raise RuntimeError(msg) data = self._convert_from_pandas(data, schema, timezone) if isinstance(schema, StructType): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fa3b7203e10ac..a9fe0b425ad3e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -32,7 +32,9 @@ import datetime import array import ctypes +import warnings import py4j +from contextlib import contextmanager try: import xmlrunner @@ -48,12 +50,13 @@ else: import unittest +from pyspark.util import _exception_message + _pandas_requirement_message = None try: from pyspark.sql.utils import require_minimum_pandas_version require_minimum_pandas_version() except ImportError as e: - from pyspark.util import _exception_message # If Pandas version requirement is not satisfied, skip related tests. _pandas_requirement_message = _exception_message(e) @@ -62,7 +65,6 @@ from pyspark.sql.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() except ImportError as e: - from pyspark.util import _exception_message # If Arrow version requirement is not satisfied, skip related tests. _pyarrow_requirement_message = _exception_message(e) @@ -195,6 +197,28 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() cls.spark.stop() + @contextmanager + def sql_conf(self, pairs): + """ + A convenient context manager to test some configuration specific logic. This sets + `value` to the configuration `key` and then restores it back when it exits. + """ + assert isinstance(pairs, dict), "pairs should be a dictionary." + + keys = pairs.keys() + new_values = pairs.values() + old_values = [self.spark.conf.get(key, None) for key in keys] + for key, new_value in zip(keys, new_values): + self.spark.conf.set(key, new_value) + try: + yield + finally: + for key, old_value in zip(keys, old_values): + if old_value is None: + self.spark.conf.unset(key) + else: + self.spark.conf.set(key, old_value) + def assertPandasEqual(self, expected, result): msg = ("DataFrames are not equal: " + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + @@ -3458,6 +3482,8 @@ def setUpClass(cls): cls.spark.conf.set("spark.sql.session.timeZone", tz) cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + # Disable fallback by default to easily detect the failures. + cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false") cls.schema = StructType([ StructField("1_str_t", StringType(), True), StructField("2_int_t", IntegerType(), True), @@ -3493,20 +3519,30 @@ def create_pandas_data_frame(self): data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) return pd.DataFrame(data=data_dict) - def test_unsupported_datatype(self): + def test_toPandas_fallback_enabled(self): + import pandas as pd + + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) + df = self.spark.createDataFrame([({u'a': 1},)], schema=schema) + with QuietTest(self.sc): + with warnings.catch_warnings(record=True) as warns: + pdf = df.toPandas() + # Catch and check the last UserWarning. + user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempts non-optimization" in _exception_message(user_warns[-1])) + self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) + + def test_toPandas_fallback_disabled(self): schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported type'): df.toPandas() - df = self.spark.createDataFrame([(None,)], schema="a binary") - with QuietTest(self.sc): - with self.assertRaisesRegexp( - Exception, - 'Unsupported type.*\nNote: toPandas attempted Arrow optimization because'): - df.toPandas() - def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + self.data) @@ -3625,7 +3661,7 @@ def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() wrong_schema = StructType(list(reversed(self.schema))) with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"): + with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"): self.spark.createDataFrame(pdf, schema=wrong_schema) def test_createDataFrame_with_names(self): @@ -3650,7 +3686,7 @@ def test_createDataFrame_column_name_encoding(self): def test_createDataFrame_with_single_data_type(self): import pandas as pd with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"): + with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"): self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") def test_createDataFrame_does_not_modify_input(self): @@ -3705,6 +3741,30 @@ def test_createDataFrame_with_int_col_names(self): self.assertEqual(pdf_col_names, df.columns) self.assertEqual(pdf_col_names, df_arrow.columns) + def test_createDataFrame_fallback_enabled(self): + import pandas as pd + + with QuietTest(self.sc): + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + with warnings.catch_warnings(record=True) as warns: + df = self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map") + # Catch and check the last UserWarning. + user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempts non-optimization" in _exception_message(user_warns[-1])) + self.assertEqual(df.collect(), [Row(a={u'a': 1})]) + + def test_createDataFrame_fallback_disabled(self): + import pandas as pd + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported type'): + self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map") + # Regression test for SPARK-23314 def test_timestamp_dst(self): import pandas as pd diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ce3f94618edeb..3f96112659c11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1058,7 +1058,7 @@ object SQLConf { .intConf .createWithDefault(100) - val ARROW_EXECUTION_ENABLE = + val ARROW_EXECUTION_ENABLED = buildConf("spark.sql.execution.arrow.enabled") .doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " + "for use with pyspark.sql.DataFrame.toPandas, and " + @@ -1068,6 +1068,13 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ARROW_FALLBACK_ENABLED = + buildConf("spark.sql.execution.arrow.fallback.enabled") + .doc("When true, optimizations enabled by 'spark.sql.execution.arrow.enabled' will " + + "fallback automatically to non-optimized implementations if an error occurs.") + .booleanConf + .createWithDefault(true) + val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") .doc("When using Apache Arrow, limit the maximum number of records that can be written " + @@ -1518,7 +1525,9 @@ class SQLConf extends Serializable with Logging { def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION) - def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) + def arrowEnabled: Boolean = getConf(ARROW_EXECUTION_ENABLED) + + def arrowFallbackEnabled: Boolean = getConf(ARROW_FALLBACK_ENABLED) def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) From 2cb23a8f51a151970c121015fcbad9beeafa8295 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 8 Mar 2018 20:29:07 +0900 Subject: [PATCH 451/774] [SPARK-23011][SQL][PYTHON] Support alternative function form with group aggregate pandas UDF ## What changes were proposed in this pull request? This PR proposes to support an alternative function from with group aggregate pandas UDF. The current form: ``` def foo(pdf): return ... ``` Takes a single arg that is a pandas DataFrame. With this PR, an alternative form is supported: ``` def foo(key, pdf): return ... ``` The alternative form takes two argument - a tuple that presents the grouping key, and a pandas DataFrame represents the data. ## How was this patch tested? GroupbyApplyTests Author: Li Jin Closes #20295 from icexelloss/SPARK-23011-groupby-apply-key. --- python/pyspark/serializers.py | 18 ++- python/pyspark/sql/functions.py | 25 ++++ python/pyspark/sql/tests.py | 121 ++++++++++++++++-- python/pyspark/sql/types.py | 45 ++++++- python/pyspark/sql/udf.py | 19 +-- python/pyspark/util.py | 16 +++ python/pyspark/worker.py | 49 +++++-- .../python/FlatMapGroupsInPandasExec.scala | 56 +++++++- 8 files changed, 294 insertions(+), 55 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 917e258d8a602..ebf549396f463 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -250,6 +250,15 @@ def __init__(self, timezone): super(ArrowStreamPandasSerializer, self).__init__() self._timezone = timezone + def arrow_to_pandas(self, arrow_column): + from pyspark.sql.types import from_arrow_type, \ + _check_series_convert_date, _check_series_localize_timestamps + + s = arrow_column.to_pandas() + s = _check_series_convert_date(s, from_arrow_type(arrow_column.type)) + s = _check_series_localize_timestamps(s, self._timezone) + return s + def dump_stream(self, iterator, stream): """ Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or @@ -272,16 +281,11 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ - from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \ - _check_dataframe_localize_timestamps import pyarrow as pa reader = pa.open_stream(stream) - schema = from_arrow_schema(reader.schema) + for batch in reader: - pdf = batch.to_pandas() - pdf = _check_dataframe_convert_date(pdf, schema) - pdf = _check_dataframe_localize_timestamps(pdf, self._timezone) - yield [c for _, c in pdf.iteritems()] + yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()] def __repr__(self): return "ArrowStreamPandasSerializer" diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b9c0c57262c5d..dc1341ac74d3d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2267,6 +2267,31 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 1.1094003924504583| +---+-------------------+ + Alternatively, the user can define a function that takes two arguments. + In this case, the grouping key will be passed as the first argument and the data will + be passed as the second argument. The grouping key will be passed as a tuple of numpy + data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in + as a `pandas.DataFrame` containing all columns from the original Spark DataFrame. + This is useful when the user does not want to hardcode grouping key in the function. + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> import pandas as pd # doctest: +SKIP + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP + ... def mean_udf(key, pdf): + ... # key is a tuple of one numpy.int64, which is the value + ... # of 'id' for the current group + ... return pd.DataFrame([key + (pdf.v.mean(),)]) + >>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP + +---+---+ + | id| v| + +---+---+ + | 1|1.5| + | 2|6.0| + +---+---+ + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` 3. GROUPED_AGG diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a9fe0b425ad3e..480815d27333f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3903,7 +3903,7 @@ def foo(df): return df with self.assertRaisesRegexp(ValueError, 'Invalid function'): @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP) - def foo(k, v): + def foo(k, v, w): return k @@ -4476,20 +4476,45 @@ def test_supported_types(self): from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col df = self.data.withColumn("arr", array(col("id"))) - foo_udf = pandas_udf( + # Different forms of group map pandas UDF, results of these are the same + + output_schema = StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType())), + StructField('v1', DoubleType()), + StructField('v2', LongType())]) + + udf1 = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), - StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('arr', ArrayType(LongType())), - StructField('v1', DoubleType()), - StructField('v2', LongType())]), + output_schema, PandasUDFType.GROUPED_MAP ) - result = df.groupby('id').apply(foo_udf).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertPandasEqual(expected, result) + udf2 = pandas_udf( + lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + udf3 = pandas_udf( + lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result1 = df.groupby('id').apply(udf1).sort('id').toPandas() + expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True) + + result2 = df.groupby('id').apply(udf2).sort('id').toPandas() + expected2 = expected1 + + result3 = df.groupby('id').apply(udf3).sort('id').toPandas() + expected3 = expected1 + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4648,6 +4673,80 @@ def test_timestamp_dst(self): result = df.groupby('time').apply(foo_udf).sort('time') self.assertPandasEqual(df.toPandas(), result.toPandas()) + def test_udf_with_key(self): + from pyspark.sql.functions import pandas_udf, col, PandasUDFType + df = self.data + pdf = df.toPandas() + + def foo1(key, pdf): + import numpy as np + assert type(key) == tuple + assert type(key[0]) == np.int64 + + return pdf.assign(v1=key[0], + v2=pdf.v * key[0], + v3=pdf.v * pdf.id, + v4=pdf.v * pdf.id.mean()) + + def foo2(key, pdf): + import numpy as np + assert type(key) == tuple + assert type(key[0]) == np.int64 + assert type(key[1]) == np.int32 + + return pdf.assign(v1=key[0], + v2=key[1], + v3=pdf.v * key[0], + v4=pdf.v + key[1]) + + def foo3(key, pdf): + assert type(key) == tuple + assert len(key) == 0 + return pdf.assign(v1=pdf.v * pdf.id) + + # v2 is int because numpy.int64 * pd.Series results in pd.Series + # v3 is long because pd.Series * pd.Series results in pd.Series + udf1 = pandas_udf( + foo1, + 'id long, v int, v1 long, v2 int, v3 long, v4 double', + PandasUDFType.GROUPED_MAP) + + udf2 = pandas_udf( + foo2, + 'id long, v int, v1 long, v2 int, v3 int, v4 int', + PandasUDFType.GROUPED_MAP) + + udf3 = pandas_udf( + foo3, + 'id long, v int, v1 long', + PandasUDFType.GROUPED_MAP) + + # Test groupby column + result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas() + expected1 = pdf.groupby('id')\ + .apply(lambda x: udf1.func((x.id.iloc[0],), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected1, result1) + + # Test groupby expression + result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas() + expected2 = pdf.groupby(pdf.id % 2)\ + .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected2, result2) + + # Test complex groupby + result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas() + expected3 = pdf.groupby([pdf.id, pdf.v % 2])\ + .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected3, result3) + + # Test empty groupby + result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas() + expected4 = udf3.func((), pdf) + self.assertPandasEqual(expected4, result4) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index cd857402db8f7..1632862d3f1ba 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1695,6 +1695,19 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) +def _check_series_convert_date(series, data_type): + """ + Cast the series to datetime.date if it's a date type, otherwise returns the original series. + + :param series: pandas.Series + :param data_type: a Spark data type for the series + """ + if type(data_type) == DateType: + return series.dt.date + else: + return series + + def _check_dataframe_convert_date(pdf, schema): """ Correct date type value to use datetime.date. @@ -1705,8 +1718,7 @@ def _check_dataframe_convert_date(pdf, schema): :param schema: a Spark schema of the pandas.DataFrame """ for field in schema: - if type(field.dataType) == DateType: - pdf[field.name] = pdf[field.name].dt.date + pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType) return pdf @@ -1725,6 +1737,29 @@ def _get_local_timezone(): return os.environ.get('TZ', 'dateutil/:') +def _check_series_localize_timestamps(s, timezone): + """ + Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone. + + If the input series is not a timestamp series, then the same series is returned. If the input + series is a timestamp series, then a converted series is returned. + + :param s: pandas.Series + :param timezone: the timezone to convert. if None then use local timezone + :return pandas.Series that have been converted to tz-naive + """ + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + from pandas.api.types import is_datetime64tz_dtype + tz = timezone or _get_local_timezone() + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64tz_dtype(s.dtype): + return s.dt.tz_convert(tz).dt.tz_localize(None) + else: + return s + + def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone @@ -1736,12 +1771,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone): from pyspark.sql.utils import require_minimum_pandas_version require_minimum_pandas_version() - from pandas.api.types import is_datetime64tz_dtype - tz = timezone or _get_local_timezone() for column, series in pdf.iteritems(): - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if is_datetime64tz_dtype(series.dtype): - pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None) + pdf[column] = _check_series_localize_timestamps(series, timezone) return pdf diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index b9b490874f4fb..ce804c18e9b14 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -17,6 +17,8 @@ """ User-defined function related classes and functions """ +import sys +import inspect import functools from pyspark import SparkContext, since @@ -24,6 +26,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ _parse_datatype_string, to_arrow_type, to_arrow_schema +from pyspark.util import _get_argspec __all__ = ["UDFRegistration"] @@ -41,18 +44,10 @@ def _create_udf(f, returnType, evalType): PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): - import inspect - import sys from pyspark.sql.utils import require_minimum_pyarrow_version - require_minimum_pyarrow_version() - if sys.version_info[0] < 3: - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. - argspec = inspect.getargspec(f) - else: - argspec = inspect.getfullargspec(f) + argspec = _get_argspec(f) if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \ argspec.varargs is None: @@ -61,11 +56,11 @@ def _create_udf(f, returnType, evalType): "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1: + if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \ + and len(argspec.args) not in (1, 2): raise ValueError( "Invalid function: pandas_udfs with function type GROUPED_MAP " - "must take a single arg that is a pandas DataFrame." - ) + "must take either one argument (data) or two arguments (key, data).") # Set the name of the UserDefinedFunction object to be the name of function f udf_obj = UserDefinedFunction( diff --git a/python/pyspark/util.py b/python/pyspark/util.py index ad4a0bc68ef41..6837b18b7d7a5 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -15,6 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +import sys +import inspect from py4j.protocol import Py4JJavaError __all__ = [] @@ -45,6 +48,19 @@ def _exception_message(excp): return str(excp) +def _get_argspec(f): + """ + Get argspec of a function. Supports both Python 2 and Python 3. + """ + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. + if sys.version_info[0] < 3: + argspec = inspect.getargspec(f) + else: + argspec = inspect.getfullargspec(f) + return argspec + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 89a3a92bc66d6..202cac350aafc 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -34,6 +34,7 @@ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type +from pyspark.util import _get_argspec from pyspark import shuffle pickleSer = PickleSerializer() @@ -91,10 +92,16 @@ def verify_result_length(*a): def wrap_grouped_map_pandas_udf(f, return_type): - def wrapped(*series): + def wrapped(key_series, value_series): import pandas as pd + argspec = _get_argspec(f) + + if len(argspec.args) == 1: + result = f(pd.concat(value_series, axis=1)) + elif len(argspec.args) == 2: + key = tuple(s[0] for s in key_series) + result = f(key, pd.concat(value_series, axis=1)) - result = f(pd.concat(series, axis=1)) if not isinstance(result, pd.DataFrame): raise TypeError("Return type of the user-defined function should be " "pandas.DataFrame, but is {}".format(type(result))) @@ -149,18 +156,36 @@ def read_udfs(pickleSer, infile, eval_type): num_udfs = read_int(infile) udfs = {} call_udf = [] - for i in range(num_udfs): + mapper_str = "" + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + # Create function like this: + # lambda a: f([a[0]], [a[0], a[1]]) + + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + + # See FlatMapGroupsInPandasExec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) - udfs['f%d' % i] = udf - args = ["a[%d]" % o for o in arg_offsets] - call_udf.append("f%d(%s)" % (i, ", ".join(args))) - # Create function like this: - # lambda a: (f0(a0), f1(a1, a2), f2(a3)) - # In the special case of a single UDF this will return a single result rather - # than a tuple of results; this is the format that the JVM side expects. - mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) - mapper = eval(mapper_str, udfs) + udfs['f'] = udf + split_offset = arg_offsets[0] + 1 + arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] + arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]] + mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1)) + else: + # Create function like this: + # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3])) + # In the special case of a single UDF this will return a single result rather + # than a tuple of results; this is the format that the JVM side expects. + for i in range(num_udfs): + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + udfs['f%d' % i] = udf + args = ["a[%d]" % o for o in arg_offsets] + call_udf.append("f%d(%s)" % (i, ", ".join(args))) + mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) + mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index c798fe5a92c54..513e174c7733e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} @@ -75,20 +76,63 @@ case class FlatMapGroupsInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) - val schema = StructType(child.schema.drop(groupingAttributes.length)) val sessionLocalTimeZone = conf.sessionLocalTimeZone val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + // Deduplicate the grouping attributes. + // If a grouping attribute also appears in data attributes, then we don't need to send the + // grouping attribute to Python worker. If a grouping attribute is not in data attributes, + // then we need to send this grouping attribute to python worker. + // + // We use argOffsets to distinguish grouping attributes and data attributes as following: + // + // argOffsets[0] is the length of grouping attributes + // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes + // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes + + val dataAttributes = child.output.drop(groupingAttributes.length) + val groupingIndicesInData = groupingAttributes.map { attribute => + dataAttributes.indexWhere(attribute.semanticEquals) + } + + val groupingArgOffsets = new ArrayBuffer[Int] + val nonDupGroupingAttributes = new ArrayBuffer[Attribute] + val nonDupGroupingSize = groupingIndicesInData.count(_ == -1) + + // Non duplicate grouping attributes are added to nonDupGroupingAttributes and + // their offsets are 0, 1, 2 ... + // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and + // their offsets are n + index, where n is the total number of non duplicate grouping + // attributes and index is the index in the data attributes that the grouping attribute + // is a duplicate of. + + groupingAttributes.zip(groupingIndicesInData).foreach { + case (attribute, index) => + if (index == -1) { + groupingArgOffsets += nonDupGroupingAttributes.length + nonDupGroupingAttributes += attribute + } else { + groupingArgOffsets += index + nonDupGroupingSize + } + } + + val dataArgOffsets = nonDupGroupingAttributes.length until + (nonDupGroupingAttributes.length + dataAttributes.length) + + val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets) + + // Attributes after deduplication + val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes + val dedupSchema = StructType.fromAttributes(dedupAttributes) + inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingAttributes.isEmpty) { Iterator(iter) } else { val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - val dropGrouping = - UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + val dedupProj = UnsafeProjection.create(dedupAttributes, child.output) groupedIter.map { - case (_, groupedRowIter) => groupedRowIter.map(dropGrouping) + case (_, groupedRowIter) => groupedRowIter.map(dedupProj) } } @@ -96,7 +140,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(grouped, context.partitionId(), context) From 7013eea11cb32b1e0038dc751c485da5c94a484b Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 8 Mar 2018 20:38:34 +0900 Subject: [PATCH 452/774] [SPARK-23522][PYTHON] always use sys.exit over builtin exit The exit() builtin is only for interactive use. applications should use sys.exit(). ## What changes were proposed in this pull request? All usage of the builtin `exit()` function is replaced by `sys.exit()`. ## How was this patch tested? I ran `python/run-tests`. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Benjamin Peterson Closes #20682 from benjaminp/sys-exit. --- dev/merge_spark_pr.py | 2 +- dev/run-tests.py | 2 +- examples/src/main/python/avro_inputformat.py | 2 +- examples/src/main/python/kmeans.py | 2 +- examples/src/main/python/logistic_regression.py | 2 +- examples/src/main/python/ml/dataframe_example.py | 2 +- examples/src/main/python/mllib/correlations.py | 2 +- examples/src/main/python/mllib/kmeans.py | 2 +- examples/src/main/python/mllib/logistic_regression.py | 2 +- examples/src/main/python/mllib/random_rdd_generation.py | 2 +- examples/src/main/python/mllib/sampled_rdds.py | 4 ++-- .../python/mllib/streaming_linear_regression_example.py | 2 +- examples/src/main/python/pagerank.py | 2 +- examples/src/main/python/parquet_inputformat.py | 2 +- examples/src/main/python/sort.py | 2 +- .../main/python/sql/streaming/structured_kafka_wordcount.py | 2 +- .../python/sql/streaming/structured_network_wordcount.py | 2 +- .../sql/streaming/structured_network_wordcount_windowed.py | 2 +- .../src/main/python/streaming/direct_kafka_wordcount.py | 2 +- examples/src/main/python/streaming/flume_wordcount.py | 2 +- examples/src/main/python/streaming/hdfs_wordcount.py | 2 +- examples/src/main/python/streaming/kafka_wordcount.py | 2 +- examples/src/main/python/streaming/network_wordcount.py | 2 +- .../src/main/python/streaming/network_wordjoinsentiments.py | 2 +- .../main/python/streaming/recoverable_network_wordcount.py | 2 +- examples/src/main/python/streaming/sql_network_wordcount.py | 2 +- .../src/main/python/streaming/stateful_network_wordcount.py | 2 +- examples/src/main/python/wordcount.py | 2 +- python/pyspark/accumulators.py | 2 +- python/pyspark/broadcast.py | 2 +- python/pyspark/conf.py | 2 +- python/pyspark/context.py | 2 +- python/pyspark/daemon.py | 2 +- python/pyspark/find_spark_home.py | 2 +- python/pyspark/heapq3.py | 3 ++- python/pyspark/ml/classification.py | 3 ++- python/pyspark/ml/clustering.py | 4 +++- python/pyspark/ml/evaluation.py | 3 ++- python/pyspark/ml/feature.py | 2 +- python/pyspark/ml/image.py | 4 +++- python/pyspark/ml/linalg/__init__.py | 2 +- python/pyspark/ml/recommendation.py | 4 +++- python/pyspark/ml/regression.py | 3 ++- python/pyspark/ml/stat.py | 4 +++- python/pyspark/ml/tuning.py | 6 ++++-- python/pyspark/mllib/classification.py | 3 ++- python/pyspark/mllib/clustering.py | 2 +- python/pyspark/mllib/evaluation.py | 3 ++- python/pyspark/mllib/feature.py | 2 +- python/pyspark/mllib/fpm.py | 4 +++- python/pyspark/mllib/linalg/__init__.py | 2 +- python/pyspark/mllib/linalg/distributed.py | 2 +- python/pyspark/mllib/random.py | 3 ++- python/pyspark/mllib/recommendation.py | 3 ++- python/pyspark/mllib/regression.py | 6 ++++-- python/pyspark/mllib/stat/_statistics.py | 2 +- python/pyspark/mllib/tree.py | 3 ++- python/pyspark/mllib/util.py | 2 +- python/pyspark/profiler.py | 3 ++- python/pyspark/rdd.py | 2 +- python/pyspark/serializers.py | 2 +- python/pyspark/shuffle.py | 3 ++- python/pyspark/sql/catalog.py | 3 ++- python/pyspark/sql/column.py | 2 +- python/pyspark/sql/conf.py | 4 +++- python/pyspark/sql/context.py | 2 +- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/group.py | 4 +++- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/session.py | 2 +- python/pyspark/sql/streaming.py | 2 +- python/pyspark/sql/types.py | 2 +- python/pyspark/sql/udf.py | 3 ++- python/pyspark/sql/window.py | 2 +- python/pyspark/streaming/util.py | 3 ++- python/pyspark/util.py | 4 +++- python/pyspark/worker.py | 6 +++--- python/setup.py | 6 +++--- 79 files changed, 120 insertions(+), 86 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 6b244d8184b2c..5ea205fbed4aa 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -510,7 +510,7 @@ def main(): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) try: main() except: diff --git a/dev/run-tests.py b/dev/run-tests.py index fe75ef4411c8c..164c1e2200aa9 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -621,7 +621,7 @@ def _test(): import doctest failure_count = doctest.testmod()[0] if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index 6286ba6541fbd..a18722c687f8b 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -61,7 +61,7 @@ Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. """, file=sys.stderr) - exit(-1) + sys.exit(-1) path = sys.argv[1] diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 92e0a3ae2ee60..a42d711fc505f 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -49,7 +49,7 @@ def closestPoint(p, centers): if len(sys.argv) != 4: print("Usage: kmeans ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("""WARN: This is a naive implementation of KMeans Clustering and is given as an example! Please refer to examples/src/main/python/ml/kmeans_example.py for an diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 01c938454b108..bcc4e0f4e8eae 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -48,7 +48,7 @@ def readPointBatch(iterator): if len(sys.argv) != 3: print("Usage: logistic_regression ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("""WARN: This is a naive implementation of Logistic Regression and is given as an example! diff --git a/examples/src/main/python/ml/dataframe_example.py b/examples/src/main/python/ml/dataframe_example.py index 109f901012c9c..d62cf2338a1fe 100644 --- a/examples/src/main/python/ml/dataframe_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -33,7 +33,7 @@ if __name__ == "__main__": if len(sys.argv) > 2: print("Usage: dataframe_example.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) elif len(sys.argv) == 2: input = sys.argv[1] else: diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py index 0e13546b88e67..089504fa7064b 100755 --- a/examples/src/main/python/mllib/correlations.py +++ b/examples/src/main/python/mllib/correlations.py @@ -31,7 +31,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: correlations ()", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonCorrelations") if len(sys.argv) == 2: filepath = sys.argv[1] diff --git a/examples/src/main/python/mllib/kmeans.py b/examples/src/main/python/mllib/kmeans.py index 002fc75799648..1bdb3e9b4a2af 100755 --- a/examples/src/main/python/mllib/kmeans.py +++ b/examples/src/main/python/mllib/kmeans.py @@ -36,7 +36,7 @@ def parseVector(line): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: kmeans ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="KMeans") lines = sc.textFile(sys.argv[1]) data = lines.map(parseVector) diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index d4f1d34e2d8cf..87efe17375226 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -42,7 +42,7 @@ def parsePoint(line): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: logistic_regression ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonLR") points = sc.textFile(sys.argv[1]).map(parsePoint) iterations = int(sys.argv[2]) diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py index 729bae30b152c..9a429b5f8abdf 100755 --- a/examples/src/main/python/mllib/random_rdd_generation.py +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -29,7 +29,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: random_rdd_generation", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonRandomRDDGeneration") diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py index b7033ab7daeb3..00e7cf4bbcdbf 100755 --- a/examples/src/main/python/mllib/sampled_rdds.py +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -29,7 +29,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: sampled_rdds ", file=sys.stderr) - exit(-1) + sys.exit(-1) if len(sys.argv) == 2: datapath = sys.argv[1] else: @@ -43,7 +43,7 @@ numExamples = examples.count() if numExamples == 0: print("Error: Data file had no samples to load.", file=sys.stderr) - exit(1) + sys.exit(1) print('Loaded data with %d examples from file: %s' % (numExamples, datapath)) # Example: RDD.sample() and RDD.takeSample() diff --git a/examples/src/main/python/mllib/streaming_linear_regression_example.py b/examples/src/main/python/mllib/streaming_linear_regression_example.py index f600496867c11..714c9a0de7217 100644 --- a/examples/src/main/python/mllib/streaming_linear_regression_example.py +++ b/examples/src/main/python/mllib/streaming_linear_regression_example.py @@ -36,7 +36,7 @@ if len(sys.argv) != 3: print("Usage: streaming_linear_regression_example.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonLogisticRegressionWithLBFGSExample") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index 0d6c253d397a0..2c19e8700ab16 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -47,7 +47,7 @@ def parseNeighbors(urls): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: pagerank ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("WARN: This is a naive implementation of PageRank and is given as an example!\n" + "Please refer to PageRank implementation provided by graphx", diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index a3f86cf8999cf..83041f0040a0c 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -45,7 +45,7 @@ /path/to/examples/parquet_inputformat.py Assumes you have Parquet data stored in . """, file=sys.stderr) - exit(-1) + sys.exit(-1) path = sys.argv[1] diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index 81898cf6d5ce6..d3cd985d197e3 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -25,7 +25,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: sort ", file=sys.stderr) - exit(-1) + sys.exit(-1) spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py index 9e8a552b3b10b..921067891352a 100644 --- a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py @@ -49,7 +49,7 @@ print(""" Usage: structured_kafka_wordcount.py """, file=sys.stderr) - exit(-1) + sys.exit(-1) bootstrapServers = sys.argv[1] subscribeType = sys.argv[2] diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount.py b/examples/src/main/python/sql/streaming/structured_network_wordcount.py index c3284c1d01017..9ac392164735b 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount.py @@ -38,7 +38,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: structured_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) host = sys.argv[1] port = int(sys.argv[2]) diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py index db672551504b5..c4e3bbf44cd5a 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py @@ -53,7 +53,7 @@ msg = ("Usage: structured_network_wordcount_windowed.py " " []") print(msg, file=sys.stderr) - exit(-1) + sys.exit(-1) host = sys.argv[1] port = int(sys.argv[2]) diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py index 425df309011a0..c5c186c11f79a 100644 --- a/examples/src/main/python/streaming/direct_kafka_wordcount.py +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: direct_kafka_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") ssc = StreamingContext(sc, 2) diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py index 5d6e6dc36d6f9..c8ea92b61ca6e 100644 --- a/examples/src/main/python/streaming/flume_wordcount.py +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: flume_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingFlumeWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py index f815dd26823d1..f9a5c43a8eaa9 100644 --- a/examples/src/main/python/streaming/hdfs_wordcount.py +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -35,7 +35,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: hdfs_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingHDFSWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index 704f6602e2297..e9ee08b9fd228 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: kafka_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingKafkaWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py index 9010fafb425e6..f3099d2517cd5 100644 --- a/examples/src/main/python/streaming/network_wordcount.py +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -35,7 +35,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py index d51a380a5d5f9..2b5434c0c845a 100644 --- a/examples/src/main/python/streaming/network_wordjoinsentiments.py +++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py @@ -47,7 +47,7 @@ def print_happiest_words(rdd): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: network_wordjoinsentiments.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingNetworkWordJoinSentiments") ssc = StreamingContext(sc, 5) diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py index 52b2639cdf55c..60167dc772544 100644 --- a/examples/src/main/python/streaming/recoverable_network_wordcount.py +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -101,7 +101,7 @@ def filterFunc(wordCount): if len(sys.argv) != 5: print("Usage: recoverable_network_wordcount.py " " ", file=sys.stderr) - exit(-1) + sys.exit(-1) host, port, checkpoint, output = sys.argv[1:] ssc = StreamingContext.getOrCreate(checkpoint, lambda: createContext(host, int(port), output)) diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py index 7f12281c0e3fe..ab3cfc067994d 100644 --- a/examples/src/main/python/streaming/sql_network_wordcount.py +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -48,7 +48,7 @@ def getSparkSessionInstance(sparkConf): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: sql_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) host, port = sys.argv[1:] sc = SparkContext(appName="PythonSqlNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index d7bb61e729f18..d5d1eba6c5969 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: stateful_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount") ssc = StreamingContext(sc, 1) ssc.checkpoint("checkpoint") diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py index 3d5e44d5b2df1..a05e24ff3ff95 100755 --- a/examples/src/main/python/wordcount.py +++ b/examples/src/main/python/wordcount.py @@ -26,7 +26,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: wordcount ", file=sys.stderr) - exit(-1) + sys.exit(-1) spark = SparkSession\ .builder\ diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 7def676b89a24..f730d290273fe 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -265,4 +265,4 @@ def _start_update_server(): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 02fc515fb824a..b3dfc99962a35 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -162,4 +162,4 @@ def clear(self): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 491b3a81972bc..ab429d9ab10de 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -217,7 +217,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 24905f1c97b21..7c664966ed74e 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1035,7 +1035,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 7f06d4288c872..7bed5216eabf3 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -89,7 +89,7 @@ def shutdown(code): signal.signal(SIGTERM, SIG_DFL) # Send SIGHUP to notify workers of shutdown os.kill(0, SIGHUP) - exit(code) + sys.exit(code) def handle_sigterm(*args): shutdown(1) diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py index 212a618b767ab..9cf0e8c8d2fe9 100755 --- a/python/pyspark/find_spark_home.py +++ b/python/pyspark/find_spark_home.py @@ -68,7 +68,7 @@ def is_spark_home(path): return next(path for path in paths if is_spark_home(path)) except StopIteration: print("Could not find valid SPARK_HOME while searching {0}".format(paths), file=sys.stderr) - exit(-1) + sys.exit(-1) if __name__ == "__main__": print(_find_spark_home()) diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index b27e91a4cc251..6af084adcf373 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -884,6 +884,7 @@ def nlargest(n, iterable, key=None): if __name__ == "__main__": import doctest + import sys (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 27ad1e80aa0d3..fbbe3d0307c81 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -16,6 +16,7 @@ # import operator +import sys from multiprocessing.pool import ThreadPool from pyspark import since, keyword_only @@ -2043,4 +2044,4 @@ def _to_java(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 6448b76a0da88..b3d5fb17f6b81 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper @@ -1181,4 +1183,4 @@ def getKeepLastCheckpoint(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 695d8ab27cc96..8eaf07645a37f 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys from abc import abstractmethod, ABCMeta from pyspark import since, keyword_only @@ -446,4 +447,4 @@ def getDistanceMeasure(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 04b07e6a05481..f2e357f0bede5 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3717,4 +3717,4 @@ def setSize(self, value): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 45c936645f2a8..96d702f844839 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -24,6 +24,8 @@ :members: """ +import sys + import numpy as np from pyspark import SparkContext from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string @@ -251,7 +253,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index ad1b487676fa7..6a611a2b5b59d 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -1158,7 +1158,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index e8bcbe4cd34cb..a8eae9bd268d3 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel @@ -480,4 +482,4 @@ def recommendForItemSubset(self, dataset, numUsers): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index f0812bd1d4a39..de0a0fa9f3bf8 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from pyspark import since, keyword_only @@ -1812,4 +1813,4 @@ def __repr__(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 079b0833e1c6d..0eeb5e528434a 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since, SparkContext from pyspark.ml.common import _java2py, _py2java from pyspark.ml.wrapper import _jvm @@ -151,4 +153,4 @@ def corr(dataset, column, method="pearson"): failure_count, test_count = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 6c0cad6cbaaa1..545e24ca05aa5 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -15,9 +15,11 @@ # limitations under the License. # import itertools -import numpy as np +import sys from multiprocessing.pool import ThreadPool +import numpy as np + from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java @@ -727,4 +729,4 @@ def _to_java(self): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index cce703d432b5a..bb281981fd56b 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -16,6 +16,7 @@ # from math import exp +import sys import warnings import numpy @@ -761,7 +762,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index bb687a7da6ffd..0cbabab13a896 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -1048,7 +1048,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 2cd1da3fbf9aa..36cb03369b8c0 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from pyspark import since @@ -542,7 +543,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index e5231dc3a27a8..40ecd2e0ff4be 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -819,7 +819,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": sys.path.pop(0) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index f58ea5dfb0874..de18dad1f675d 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + import numpy from numpy import array from collections import namedtuple @@ -197,7 +199,7 @@ def _test(): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 7b24b3c74a9fa..60d96d8d5ceb8 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -1370,7 +1370,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 4cb802514be52..bba88542167ad 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -1377,7 +1377,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 61213ddf62e8b..a8833cb446923 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -19,6 +19,7 @@ Python package for random data generation. """ +import sys from functools import wraps from pyspark import since @@ -421,7 +422,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 81182881352bb..3d4eae85132bb 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -16,6 +16,7 @@ # import array +import sys from collections import namedtuple from pyspark import SparkContext, since @@ -326,7 +327,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index ea107d400621d..6be45f51862c9 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -15,9 +15,11 @@ # limitations under the License. # +import sys +import warnings + import numpy as np from numpy import array -import warnings from pyspark import RDD, since from pyspark.streaming.dstream import DStream @@ -837,7 +839,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 49b26446dbc32..3c75b132ecad2 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -313,7 +313,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 619fa16d463f5..b05734ce489d9 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -17,6 +17,7 @@ from __future__ import absolute_import +import sys import random from pyspark import SparkContext, RDD, since @@ -654,7 +655,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 97755807ef262..fc7809387b13a 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -521,7 +521,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index 44d17bd629473..3c7656ab5758c 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -19,6 +19,7 @@ import pstats import os import atexit +import sys from pyspark.accumulators import AccumulatorParam @@ -173,4 +174,4 @@ def stats(self): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 93b8974a7e64a..4b44f76747264 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2498,7 +2498,7 @@ def _test(): globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ebf549396f463..15753f77bd903 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -715,4 +715,4 @@ def write_with_length(obj, stream): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index e974cda9fc3e1..02c773302e9da 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -23,6 +23,7 @@ import itertools import operator import random +import sys import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ @@ -810,4 +811,4 @@ def load_partition(j): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 6aef0f22340be..b0d8357f4feec 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from collections import namedtuple @@ -306,7 +307,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 43b38a2cd477c..e05a7b33c11a7 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -660,7 +660,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index 792c420ca6386..d929834aeeaa5 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since from pyspark.rdd import ignore_unicode_prefix @@ -80,7 +82,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(pyspark.sql.conf, globs=globs) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index cc1cd1a5842d9..6cb90399dd616 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -543,7 +543,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8f90a367e8bf8..3fc194d8ec1d1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2231,7 +2231,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dc1341ac74d3d..dff590983b4d9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2404,7 +2404,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ab646535c864c..35cac406e0965 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal @@ -299,7 +301,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 9d05ac7cb39be..803f561ece67b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -970,7 +970,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) sc.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 215bb3e5c5173..e82a9750a0014 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -830,7 +830,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index cc622decfd682..e8966c20a8f42 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -930,7 +930,7 @@ def _test(): globs['spark'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1632862d3f1ba..826aab97e58db 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1890,7 +1890,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index ce804c18e9b14..24dd06c26089c 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -20,6 +20,7 @@ import sys import inspect import functools +import sys from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix @@ -397,7 +398,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index bb841a9b9ff7c..e667fba099fb9 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -264,7 +264,7 @@ def _test(): SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index abbbf6eb9394f..df184471993ff 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -18,6 +18,7 @@ import time from datetime import datetime import traceback +import sys from pyspark import SparkContext, RDD @@ -147,4 +148,4 @@ def rddToFileName(prefix, suffix, timestamp): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 6837b18b7d7a5..ed1bdd0e4be83 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -22,6 +22,8 @@ __all__ = [] +import sys + def _exception_message(excp): """Return the message from an exception as either a str or unicode object. Supports both @@ -65,4 +67,4 @@ def _get_argspec(f): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 202cac350aafc..a1a4336b1e8de 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -205,7 +205,7 @@ def main(infile, outfile): boot_time = time.time() split_index = read_int(infile) if split_index == -1: # for unit tests - exit(-1) + sys.exit(-1) version = utf8_deserializer.loads(infile) if version != "%d.%d" % sys.version_info[:2]: @@ -279,7 +279,7 @@ def process(): # Write the error to stderr if it happened while serializing print("PySpark worker failed with exception:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - exit(-1) + sys.exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) write_long(shuffle.MemoryBytesSpilled, outfile) @@ -297,7 +297,7 @@ def process(): else: # write a different value to tell JVM to not reuse this worker write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) - exit(-1) + sys.exit(-1) if __name__ == '__main__': diff --git a/python/setup.py b/python/setup.py index 6a98401941d8d..794ceceae3008 100644 --- a/python/setup.py +++ b/python/setup.py @@ -26,7 +26,7 @@ if sys.version_info < (2, 7): print("Python versions prior to 2.7 are not supported for pip installed PySpark.", file=sys.stderr) - exit(-1) + sys.exit(-1) try: exec(open('pyspark/version.py').read()) @@ -98,7 +98,7 @@ def _supports_symlinks(): except: print("Temp path for symlink to parent already exists {0}".format(TEMP_PATH), file=sys.stderr) - exit(-1) + sys.exit(-1) # If you are changing the versions here, please also change ./python/pyspark/sql/utils.py and # ./python/run-tests.py. In case of Arrow, you should also check ./pom.xml. @@ -140,7 +140,7 @@ def _supports_symlinks(): if not os.path.isdir(SCRIPTS_TARGET): print(incorrect_invocation_message, file=sys.stderr) - exit(-1) + sys.exit(-1) # Scripts directive requires a list of each script path and does not take wild cards. script_names = os.listdir(SCRIPTS_TARGET) From 92e7ecbbbd6817378abdbd56541a9c13dcea8659 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 8 Mar 2018 14:18:14 +0100 Subject: [PATCH 453/774] [SPARK-23592][SQL] Add interpreted execution to DecodeUsingSerializer ## What changes were proposed in this pull request? The PR adds interpreted execution to DecodeUsingSerializer. ## How was this patch tested? added UT Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Marco Gaido Closes #20760 from mgaido91/SPARK-23592. --- .../catalyst/expressions/objects/objects.scala | 5 +++++ .../expressions/ObjectExpressionsSuite.scala | 15 +++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 7bbc3c732e782..adf9ddf327c96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1242,6 +1242,11 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) extends UnaryExpression with NonSQLExpression with SerializerSupport { + override def nullSafeEval(input: Any): Any = { + val inputBytes = java.nio.ByteBuffer.wrap(input.asInstanceOf[Array[Byte]]) + serializerInstance.deserialize(inputBytes) + } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val serializer = addImmutableSerializerIfNeeded(ctx) // Code to deserialize. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 346b13277c709..ffeec2a38c532 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.reflect.ClassTag + import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row @@ -123,4 +125,17 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) } } + + test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") { + val cls = classOf[java.lang.Integer] + val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true) + val conf = new SparkConf() + Seq(true, false).foreach { useKryo => + val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) + val input = serializer.newInstance().serialize(new Integer(1)).array() + val decodeUsingSerializer = DecodeUsingSerializer(inputObject, ClassTag(cls), useKryo) + checkEvaluation(decodeUsingSerializer, new Integer(1), InternalRow.fromSeq(Seq(input))) + checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) + } + } } From 3be4adf6485ca19cdc5db23394c3f5a660d7dc6f Mon Sep 17 00:00:00 2001 From: lucio <576632108@qq.com> Date: Thu, 8 Mar 2018 08:03:24 -0600 Subject: [PATCH 454/774] [SPARK-22751][ML] Improve ML RandomForest shuffle performance ## What changes were proposed in this pull request? As I mentioned in [SPARK-22751](https://issues.apache.org/jira/browse/SPARK-22751?jql=project%20%3D%20SPARK%20AND%20component%20%3D%20ML%20AND%20text%20~%20randomforest), there is a shuffle performance problem in ML Randomforest when train a RF in high dimensional data. The reason is that, in _org.apache.spark.tree.impl.RandomForest_, the function _findSplitsBySorting_ will actually flatmap a sparse vector into a dense vector, then in groupByKey there will be a huge shuffle write size. To avoid this, we can add a filter in flatmap, to filter out zero value. And in function _findSplitsForContinuousFeature_, we can infer the number of zero value by _metadata_. In addition, if a feature only contains zero value, _continuousSplits_ will not has the key of feature id. So I add a check when using _continuousSplits_. ## How was this patch tested? Ran model locally using spark-submit. Author: lucio <576632108@qq.com> Closes #20472 from lucio-yz/master. --- .../spark/ml/tree/impl/RandomForest.scala | 52 ++++++++++++++----- .../ml/tree/impl/RandomForestSuite.scala | 23 ++++---- 2 files changed, 50 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 8e514f11e78ea..16f32d76b9984 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -892,13 +892,7 @@ private[spark] object RandomForest extends Logging { // Sample the input only if there are continuous features. val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous) val sampledInput = if (continuousFeatures.nonEmpty) { - // Calculate the number of samples for approximate quantile calculation. - val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) - val fraction = if (requiredSamples < metadata.numExamples) { - requiredSamples.toDouble / metadata.numExamples - } else { - 1.0 - } + val fraction = samplesFractionForFindSplits(metadata) logDebug("fraction of data used for calculating quantiles = " + fraction) input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()) } else { @@ -920,8 +914,9 @@ private[spark] object RandomForest extends Logging { val numPartitions = math.min(continuousFeatures.length, input.partitions.length) input - .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx)))) - .groupByKey(numPartitions) + .flatMap { point => + continuousFeatures.map(idx => (idx, point.features(idx))).filter(_._2 != 0.0) + }.groupByKey(numPartitions) .map { case (idx, samples) => val thresholds = findSplitsForContinuousFeature(samples, metadata, idx) val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) @@ -933,7 +928,8 @@ private[spark] object RandomForest extends Logging { val numFeatures = metadata.numFeatures val splits: Array[Array[Split]] = Array.tabulate(numFeatures) { case i if metadata.isContinuous(i) => - val split = continuousSplits(i) + // some features may contain only zero, so continuousSplits will not have a record + val split = continuousSplits.getOrElse(i, Array.empty[Split]) metadata.setNumSplits(i, split.length) split @@ -1003,11 +999,22 @@ private[spark] object RandomForest extends Logging { } else { val numSplits = metadata.numSplits(featureIndex) - // get count for each distinct value - val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) { - case ((m, cnt), x) => - (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1) + // get count for each distinct value except zero value + val partNumSamples = featureSamples.size + val partValueCountMap = scala.collection.mutable.Map[Double, Int]() + featureSamples.foreach { x => + partValueCountMap(x) = partValueCountMap.getOrElse(x, 0) + 1 } + + // Calculate the expected number of samples for finding splits + val numSamples = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt + // add expected zero value count and get complete statistics + val valueCountMap: Map[Double, Int] = if (numSamples - partNumSamples > 0) { + partValueCountMap.toMap + (0.0 -> (numSamples - partNumSamples)) + } else { + partValueCountMap.toMap + } + // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray @@ -1149,4 +1156,21 @@ private[spark] object RandomForest extends Logging { 3 * totalBins } } + + /** + * Calculate the subsample fraction for finding splits + * + * @param metadata decision tree metadata + * @return subsample fraction + */ + private def samplesFractionForFindSplits( + metadata: DecisionTreeMetadata): Double = { + // Calculate the number of samples for approximate quantile calculation. + val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) + if (requiredSamples < metadata.numExamples) { + requiredSamples.toDouble / metadata.numExamples + } else { + 1.0 + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 5f0d26eb5c058..743dacf146fe7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -93,12 +93,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("find splits for a continuous feature") { // find splits for normal case { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 200000, 0, 0, Map(), Set(), Array(6), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array.fill(200000)(math.random) + val featureSamples = Array.fill(10000)(math.random).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits.length === 5) assert(fakeMetadata.numSplits(0) === 5) @@ -109,7 +109,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // SPARK-16957: Use midpoints for split values. { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 8, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -117,7 +117,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // possibleSplits <= numSplits { - val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) + val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2) assert(splits === expectedSplits) @@ -125,7 +125,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // possibleSplits > numSplits { - val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble) + val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) assert(splits === expectedSplits) @@ -135,7 +135,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 12, 0, 0, Map(), Set(), Array(5), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -150,7 +150,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the minimum { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 18, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -164,12 +164,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the maximum { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 17, 0, 0, Map(), Set(), Array(2), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) + .map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((1.0 + 2.0) / 2) assert(splits === expectedSplits) @@ -177,12 +178,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits for constant feature { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 3, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 0, 0).map(_.toDouble) + val featureSamples = Array(0, 0, 0).map(_.toDouble).filter(_ != 0.0) val featureSamplesEmpty = Array.empty[Double] val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits === Array.empty[Double]) From ea480990e726aed59750f1cea8d40adba56d991a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 8 Mar 2018 11:09:15 -0800 Subject: [PATCH 455/774] [SPARK-23628][SQL] calculateParamLength should not return 1 + num of epressions ## What changes were proposed in this pull request? There was a bug in `calculateParamLength` which caused it to return always 1 + the number of expressions. This could lead to Exceptions especially with expressions of type long. ## How was this patch tested? added UT + fixed previous UT Author: Marco Gaido Closes #20772 from mgaido91/SPARK-23628. --- .../expressions/codegen/CodeGenerator.scala | 51 ++++++++++--------- .../expressions/CodeGenerationSuite.scala | 6 +++ .../sql/execution/WholeStageCodegenExec.scala | 5 +- .../execution/WholeStageCodegenSuite.scala | 16 +++--- 4 files changed, 43 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 793824b0b0a2f..fe5e63ec0a2bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1063,31 +1063,6 @@ class CodegenContext { "" } } - - /** - * Returns the length of parameters for a Java method descriptor. `this` contributes one unit - * and a parameter of type long or double contributes two units. Besides, for nullable parameter, - * we also need to pass a boolean parameter for the null status. - */ - def calculateParamLength(params: Seq[Expression]): Int = { - def paramLengthForExpr(input: Expression): Int = { - // For a nullable expression, we need to pass in an extra boolean parameter. - (if (input.nullable) 1 else 0) + javaType(input.dataType) match { - case JAVA_LONG | JAVA_DOUBLE => 2 - case _ => 1 - } - } - // Initial value is 1 for `this`. - 1 + params.map(paramLengthForExpr(_)).sum - } - - /** - * In Java, a method descriptor is valid only if it represents method parameters with a total - * length less than a pre-defined constant. - */ - def isValidParamLength(paramLength: Int): Boolean = { - paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH - } } /** @@ -1538,4 +1513,30 @@ object CodeGenerator extends Logging { def defaultValue(dt: DataType, typedNull: Boolean = false): String = defaultValue(javaType(dt), typedNull) + + /** + * Returns the length of parameters for a Java method descriptor. `this` contributes one unit + * and a parameter of type long or double contributes two units. Besides, for nullable parameter, + * we also need to pass a boolean parameter for the null status. + */ + def calculateParamLength(params: Seq[Expression]): Int = { + def paramLengthForExpr(input: Expression): Int = { + val javaParamLength = javaType(input.dataType) match { + case JAVA_LONG | JAVA_DOUBLE => 2 + case _ => 1 + } + // For a nullable expression, we need to pass in an extra boolean parameter. + (if (input.nullable) 1 else 0) + javaParamLength + } + // Initial value is 1 for `this`. + 1 + params.map(paramLengthForExpr).sum + } + + /** + * In Java, a method descriptor is valid only if it represents method parameters with a total + * length less than a pre-defined constant. + */ + def isValidParamLength(paramLength: Int): Boolean = { + paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 1e48c7b8df9da..64c13e8972036 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -436,4 +436,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.addImmutableStateIfNotExists("String", mutableState2) assert(ctx.inlinedMutableStates.length == 2) } + + test("SPARK-23628: calculateParamLength should compute properly the param length") { + assert(CodeGenerator.calculateParamLength(Seq.range(0, 100).map(Literal(_))) == 101) + assert(CodeGenerator.calculateParamLength( + Seq.range(0, 100).map(x => Literal(x.toLong))) == 201) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index f89e3fb0e536f..6ddaacfee1a40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -174,8 +174,9 @@ trait CodegenSupport extends SparkPlan { // declaration. val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator val requireAllOutput = output.forall(parent.usedInputs.contains(_)) - val paramLength = ctx.calculateParamLength(output) + (if (row != null) 1 else 0) - val consumeFunc = if (confEnabled && requireAllOutput && ctx.isValidParamLength(paramLength)) { + val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0) + val consumeFunc = if (confEnabled && requireAllOutput + && CodeGenerator.isValidParamLength(paramLength)) { constructDoConsumeFunction(ctx, inputVars, row) } else { parent.doConsume(ctx, inputVars, rowVar) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index ef16292a8e75c..0fb9dd2017a09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed -import org.apache.spark.sql.functions.{avg, broadcast, col, max} +import org.apache.spark.sql.functions.{avg, broadcast, col, lit, max} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -249,12 +249,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } test("Skip splitting consume function when parameter number exceeds JVM limit") { - import testImplicits._ - - Seq((255, false), (254, true)).foreach { case (columnNum, hasSplit) => + // since every field is nullable we have 2 params for each input column (one for the value + // and one for the isNull variable) + Seq((128, false), (127, true)).foreach { case (columnNum, hasSplit) => withTempPath { dir => val path = dir.getCanonicalPath - spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*) + spark.range(10).select(Seq.tabulate(columnNum) {i => lit(i).as(s"c$i")} : _*) .write.mode(SaveMode.Overwrite).parquet(path) withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255", @@ -263,10 +263,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val df = spark.read.parquet(path).selectExpr(projection: _*) val plan = df.queryExecution.executedPlan - val wholeStageCodeGenExec = plan.find(p => p match { - case wp: WholeStageCodegenExec => true + val wholeStageCodeGenExec = plan.find { + case _: WholeStageCodegenExec => true case _ => false - }) + } assert(wholeStageCodeGenExec.isDefined) val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 assert(code.body.contains("project_doConsume") == hasSplit) From e7bbca88964d95593fa15eb94643ba519801e352 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 8 Mar 2018 22:02:28 +0100 Subject: [PATCH 456/774] [SPARK-23602][SQL] PrintToStderr prints value also in interpreted mode ## What changes were proposed in this pull request? `PrintToStderr` was doing what is it supposed to only when code generation is enabled. The PR adds the same behavior in interpreted mode too. ## How was this patch tested? added UT Author: Marco Gaido Closes #20773 from mgaido91/SPARK-23602. --- .../spark/sql/catalyst/expressions/misc.scala | 7 +++++- .../expressions/MiscExpressionsSuite.scala | 25 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 4b9006ab5b423..38e4fe44b15ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -31,7 +31,12 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType - protected override def nullSafeEval(input: Any): Any = input + protected override def nullSafeEval(input: Any): Any = { + // scalastyle:off println + System.err.println(outputPrefix + input) + // scalastyle:on println + input + } private val outputPrefix = s"Result of ${child.simpleString} is " diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index a21c139fe71d0..c3d08bf68c7bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.io.PrintStream + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -43,4 +45,27 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Length(Uuid()), 36) assert(evaluateWithoutCodegen(Uuid()) !== evaluateWithoutCodegen(Uuid())) } + + test("PrintToStderr") { + val inputExpr = Literal(1) + val systemErr = System.err + + val (outputEval, outputCodegen) = try { + val errorStream = new java.io.ByteArrayOutputStream() + System.setErr(new PrintStream(errorStream)) + // check without codegen + checkEvaluationWithoutCodegen(PrintToStderr(inputExpr), 1) + val outputEval = errorStream.toString + errorStream.reset() + // check with codegen + checkEvaluationWithGeneratedMutableProjection(PrintToStderr(inputExpr), 1) + val outputCodegen = errorStream.toString + (outputEval, outputCodegen) + } finally { + System.setErr(systemErr) + } + + assert(outputCodegen.contains(s"Result of $inputExpr is 1")) + assert(outputEval.contains(s"Result of $inputExpr is 1")) + } } From d90e77bd0ec19f8ba9198a24ec2ab3db7708eca8 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 8 Mar 2018 14:58:40 -0800 Subject: [PATCH 457/774] [SPARK-23271][SQL] Parquet output contains only _SUCCESS file after writing an empty dataframe ## What changes were proposed in this pull request? Below are the two cases. ``` SQL case 1 scala> List.empty[String].toDF().rdd.partitions.length res18: Int = 1 ``` When we write the above data frame as parquet, we create a parquet file containing just the schema of the data frame. Case 2 ``` SQL scala> val anySchema = StructType(StructField("anyName", StringType, nullable = false) :: Nil) anySchema: org.apache.spark.sql.types.StructType = StructType(StructField(anyName,StringType,false)) scala> spark.read.schema(anySchema).csv("/tmp/empty_folder").rdd.partitions.length res22: Int = 0 ``` For the 2nd case, since number of partitions = 0, we don't call the write task (the task has logic to create the empty metadata only parquet file) The fix is to create a dummy single partition RDD and set up the write task based on it to ensure the metadata-only file. ## How was this patch tested? A new test is added to DataframeReaderWriterSuite. Author: Dilip Biswal Closes #20525 from dilipbiswal/spark-23271. --- docs/sql-programming-guide.md | 1 + .../datasources/FileFormatWriter.scala | 15 ++++++++++++--- .../spark/sql/FileBasedDataSourceSuite.scala | 18 ++++++++++++++++++ .../sql/test/DataFrameReaderWriterSuite.scala | 1 - 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 451b814ab6c53..d2132d2ae7441 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1805,6 +1805,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. + - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 1d80a69bc5a1d..401597f967218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -190,9 +190,18 @@ object FileFormatWriter extends Logging { global = false, child = plan).execute() } - val ret = new Array[WriteTaskResult](rdd.partitions.length) + + // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single + // partition rdd to make sure we at least set up one write task to write the metadata. + val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) { + sparkSession.sparkContext.parallelize(Array.empty[InternalRow], 1) + } else { + rdd + } + + val ret = new Array[WriteTaskResult](rddWithNonEmptyPartitions.partitions.length) sparkSession.sparkContext.runJob( - rdd, + rddWithNonEmptyPartitions, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, @@ -202,7 +211,7 @@ object FileFormatWriter extends Logging { committer, iterator = iter) }, - 0 until rdd.partitions.length, + rddWithNonEmptyPartitions.partitions.indices, (index, res: WriteTaskResult) => { committer.onTaskCommit(res.commitMsg) ret(index) = res diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 73e3df3b6202e..bd3071bcf9010 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -89,6 +90,23 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } + Seq("orc", "parquet").foreach { format => + test(s"SPARK-23271 empty RDD when saved should write a metadata only file - $format") { + withTempPath { outputPath => + val df = spark.emptyDataFrame.select(lit(1).as("i")) + df.write.format(format).save(outputPath.toString) + val partFiles = outputPath.listFiles() + .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_")) + assert(partFiles.length === 1) + + // Now read the file. + val df1 = spark.read.format(format).load(outputPath.toString) + checkAnswer(df1, Seq.empty[Row]) + assert(df1.schema.equals(df.schema.asNullable)) + } + } + } + allFileBasedDataSources.foreach { format => test(s"SPARK-22146 read files containing special characters using $format") { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 8c9bb7d56a35f..a707a88dfa670 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -301,7 +301,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be intercept[AnalysisException] { spark.range(10).write.format("csv").mode("overwrite").partitionBy("id").save(path) } - spark.emptyDataFrame.write.format("parquet").mode("overwrite").save(path) } } From 2c3673680e16f88f1d1cd73a3f7445ded5b3daa8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 9 Mar 2018 10:36:38 -0800 Subject: [PATCH 458/774] [SPARK-23630][YARN] Allow user's hadoop conf customizations to take effect. This change restores functionality that was inadvertently removed as part of the fix for SPARK-22372. Also modified an existing unit test to make sure the feature works as intended. Author: Marcelo Vanzin Closes #20776 from vanzin/SPARK-23630. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 11 +++++- .../org/apache/spark/deploy/yarn/Client.scala | 14 ++++---- .../spark/deploy/yarn/YarnClusterSuite.scala | 34 ++++++++++++++----- 3 files changed, 44 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index e14f9845e6db6..177295fb7af0f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -111,7 +111,9 @@ class SparkHadoopUtil extends Logging { * subsystems. */ def newConfiguration(conf: SparkConf): Configuration = { - SparkHadoopUtil.newConfiguration(conf) + val hadoopConf = SparkHadoopUtil.newConfiguration(conf) + hadoopConf.addResource(SparkHadoopUtil.SPARK_HADOOP_CONF_FILE) + hadoopConf } /** @@ -435,6 +437,13 @@ object SparkHadoopUtil { */ private[spark] val UPDATE_INPUT_METRICS_INTERVAL_RECORDS = 1000 + /** + * Name of the file containing the gateway's Hadoop configuration, to be overlayed on top of the + * cluster's Hadoop config. It is up to the Spark code launching the application to create + * this file if it's desired. If the file doesn't exist, it will just be ignored. + */ + private[spark] val SPARK_HADOOP_CONF_FILE = "__spark_hadoop_conf__.xml" + def get: SparkHadoopUtil = instance /** diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 8cd3cd9746a3a..28087dee831d1 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -696,7 +696,13 @@ private[spark] class Client( } } - Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => + // SPARK-23630: during testing, Spark scripts filter out hadoop conf dirs so that user's + // environments do not interfere with tests. This allows a special env variable during + // tests so that custom conf dirs can be used by unit tests. + val confDirs = Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR") ++ + (if (Utils.isTesting) Seq("SPARK_TEST_HADOOP_CONF_DIR") else Nil) + + confDirs.foreach { envKey => sys.env.get(envKey).foreach { path => val dir = new File(path) if (dir.isDirectory()) { @@ -753,7 +759,7 @@ private[spark] class Client( // Save the YARN configuration into a separate file that will be overlayed on top of the // cluster's Hadoop conf. - confStream.putNextEntry(new ZipEntry(SPARK_HADOOP_CONF_FILE)) + confStream.putNextEntry(new ZipEntry(SparkHadoopUtil.SPARK_HADOOP_CONF_FILE)) hadoopConf.writeXml(confStream) confStream.closeEntry() @@ -1220,10 +1226,6 @@ private object Client extends Logging { // Name of the file in the conf archive containing Spark configuration. val SPARK_CONF_FILE = "__spark_conf__.properties" - // Name of the file containing the gateway's Hadoop configuration, to be overlayed on top of the - // cluster's Hadoop config. - val SPARK_HADOOP_CONF_FILE = "__spark_hadoop_conf__.xml" - // Subdirectory where the user's python files (not archives) will be placed. val LOCALIZED_PYTHON_DIR = "__pyfiles__" diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 5003326b440bf..33d400a5b1b2e 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -114,12 +114,25 @@ class YarnClusterSuite extends BaseYarnClusterSuite { )) } - test("yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414)") { + test("yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414, SPARK-23630)") { + // Create a custom hadoop config file, to make sure it's contents are propagated to the driver. + val customConf = Utils.createTempDir() + val coreSite = """ + | + | + | spark.test.key + | testvalue + | + | + |""".stripMargin + Files.write(coreSite, new File(customConf, "core-site.xml"), StandardCharsets.UTF_8) + val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(false, mainClassName(YarnClusterDriverUseSparkHadoopUtilConf.getClass), - appArgs = Seq("key=value", result.getAbsolutePath()), - extraConf = Map("spark.hadoop.key" -> "value")) + appArgs = Seq("key=value", "spark.test.key=testvalue", result.getAbsolutePath()), + extraConf = Map("spark.hadoop.key" -> "value"), + extraEnv = Map("SPARK_TEST_HADOOP_CONF_DIR" -> customConf.getAbsolutePath())) checkResult(finalState, result) } @@ -319,13 +332,13 @@ private object YarnClusterDriverWithFailure extends Logging with Matchers { private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matchers { def main(args: Array[String]): Unit = { - if (args.length != 2) { + if (args.length < 2) { // scalastyle:off println System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | - |Usage: YarnClusterDriverUseSparkHadoopUtilConf [hadoopConfKey=value] [result file] + |Usage: YarnClusterDriverUseSparkHadoopUtilConf [hadoopConfKey=value]+ [result file] """.stripMargin) // scalastyle:on println System.exit(1) @@ -335,11 +348,16 @@ private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matc .set("spark.extraListeners", classOf[SaveExecutorInfo].getName) .setAppName("yarn test using SparkHadoopUtil's conf")) - val kv = args(0).split("=") - val status = new File(args(1)) + val kvs = args.take(args.length - 1).map { kv => + val parsed = kv.split("=") + (parsed(0), parsed(1)) + } + val status = new File(args.last) var result = "failure" try { - SparkHadoopUtil.get.conf.get(kv(0)) should be (kv(1)) + kvs.foreach { case (k, v) => + SparkHadoopUtil.get.conf.get(k) should be (v) + } result = "success" } finally { Files.write(result, status, StandardCharsets.UTF_8) From 2ca9bb083c515511d2bfee271fc3e0269aceb9d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20=C5=9Awitakowski?= Date: Fri, 9 Mar 2018 14:29:31 -0800 Subject: [PATCH 459/774] [SPARK-23173][SQL] Avoid creating corrupt parquet files when loading data from JSON MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The from_json() function accepts an additional parameter, where the user might specify the schema. The issue is that the specified schema might not be compatible with data. In particular, the JSON data might be missing data for fields declared as non-nullable in the schema. The from_json() function does not verify the data against such errors. When data with missing fields is sent to the parquet encoder, there is no verification either. The end results is a corrupt parquet file. To avoid corruptions, make sure that all fields in the user-specified schema are set to be nullable. Since this changes the behavior of a public function, we need to include it in release notes. The behavior can be reverted by setting `spark.sql.fromJsonForceNullableSchema=false` ## How was this patch tested? Added two new tests. Author: Michał Świtakowski Closes #20694 from mswit-databricks/SPARK-23173. --- .../expressions/jsonExpressions.scala | 22 +++++++++----- .../apache/spark/sql/internal/SQLConf.scala | 8 +++++ .../expressions/JsonExpressionsSuite.scala | 30 ++++++++++++++++++- .../datasources/parquet/ParquetIOSuite.scala | 19 ++++++++++++ 4 files changed, 70 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 18b4fed597447..fdd672c416a03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -515,10 +516,15 @@ case class JsonToStructs( child: Expression, timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { - override def nullable: Boolean = true - def this(schema: DataType, options: Map[String, String], child: Expression) = - this(schema, options, child, None) + val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) + + // The JSON input data might be missing certain fields. We force the nullability + // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder + // can generate incorrect files if values are missing in columns declared as non-nullable. + val nullableSchema = if (forceNullableSchema) schema.asNullable else schema + + override def nullable: Boolean = true // Used in `FunctionRegistry` def this(child: Expression, schema: Expression) = @@ -535,22 +541,22 @@ case class JsonToStructs( child = child, timeZoneId = None) - override def checkInputDataTypes(): TypeCheckResult = schema match { + override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { case _: StructType | ArrayType(_: StructType, _) => super.checkInputDataTypes() case _ => TypeCheckResult.TypeCheckFailure( - s"Input schema ${schema.simpleString} must be a struct or an array of structs.") + s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.") } @transient - lazy val rowSchema = schema match { + lazy val rowSchema = nullableSchema match { case st: StructType => st case ArrayType(st: StructType, _) => st } // This converts parsed rows to the desired output by the given schema. @transient - lazy val converter = schema match { + lazy val converter = nullableSchema match { case _: StructType => (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null case ArrayType(_: StructType, _) => @@ -563,7 +569,7 @@ case class JsonToStructs( rowSchema, new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get)) - override def dataType: DataType = schema + override def dataType: DataType = nullableSchema override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3f96112659c11..11864bd1b1847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -493,6 +493,14 @@ object SQLConf { .stringConf .createWithDefault("_corrupt_record") + val FROM_JSON_FORCE_NULLABLE_SCHEMA = buildConf("spark.sql.fromJsonForceNullableSchema") + .internal() + .doc("When true, force the output schema of the from_json() function to be nullable " + + "(including all the fields). Otherwise, the schema might not be compatible with" + + "actual data, which leads to curruptions.") + .booleanConf + .createWithDefault(true) + val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout") .doc("Timeout in seconds for the broadcast wait time in broadcast joins.") .timeConf(TimeUnit.SECONDS) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index a0bbe02f92354..7812319756eae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -22,11 +22,13 @@ import java.util.Calendar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { +class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with PlanTestBase { val json = """ |{"store":{"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}], @@ -680,4 +682,30 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } } + + test("from_json missing fields") { + for (forceJsonNullableSchema <- Seq(false, true)) { + withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) { + val input = + """{ + | "a": 1, + | "c": "foo" + |} + |""".stripMargin + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + checkEvaluation( + JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId), + output + ) + val schema = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) + .dataType + val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema + assert(schemaToCompare == schema) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 3af80930ec807..0b3e8ca060d87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -780,6 +781,24 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { assert(option.compressionCodecClassName == "UNCOMPRESSED") } } + + test("SPARK-23173 Writing a file with data converted from JSON with and incorrect user schema") { + withTempPath { file => + val jsonData = + """{ + | "a": 1, + | "c": "foo" + |} + |""".stripMargin + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + spark.range(1).select(from_json(lit(jsonData), jsonSchema) as "input") + .write.parquet(file.getAbsolutePath) + checkAnswer(spark.read.parquet(file.getAbsolutePath), Seq(Row(Row(1, null, "foo")))) + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) From 10b0657b035641ce735055bba2c8459e71bc2400 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Fri, 9 Mar 2018 15:41:19 -0800 Subject: [PATCH 460/774] [SPARK-23624][SQL] Revise doc of method pushFilters in Datasource V2 ## What changes were proposed in this pull request? Revise doc of method pushFilters in SupportsPushDownFilters/SupportsPushDownCatalystFilters In `FileSourceStrategy`, except `partitionKeyFilters`(the references of which is subset of partition keys), all filters needs to be evaluated after scanning. Otherwise, Spark will get wrong result from data sources like Orc/Parquet. This PR is to improve the doc. Author: Wang Gengliang Closes #20769 from gengliangwang/revise_pushdown_doc. --- .../sql/sources/v2/reader/SupportsPushDownCatalystFilters.java | 2 +- .../spark/sql/sources/v2/reader/SupportsPushDownFilters.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 98224102374aa..290d614805ac7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -34,7 +34,7 @@ public interface SupportsPushDownCatalystFilters extends DataSourceReader { /** - * Pushes down filters, and returns unsupported filters. + * Pushes down filters, and returns filters that need to be evaluated after scanning. */ Expression[] pushCatalystFilters(Expression[] filters); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index f35c711b0387a..1cff024232a44 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -32,7 +32,7 @@ public interface SupportsPushDownFilters extends DataSourceReader { /** - * Pushes down filters, and returns unsupported filters. + * Pushes down filters, and returns filters that need to be evaluated after scanning. */ Filter[] pushFilters(Filter[] filters); From 1a54f48b6744032b16543594651ee6d5e3ad4bda Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 9 Mar 2018 15:54:55 -0800 Subject: [PATCH 461/774] [SPARK-23510][SQL][FOLLOW-UP] Support Hive 2.2 and Hive 2.3 metastore ## What changes were proposed in this pull request? In the PR https://github.com/apache/spark/pull/20671, I forgot to update the doc about this new support. ## How was this patch tested? N/A Author: gatorsmile Closes #20789 from gatorsmile/docUpdate. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d2132d2ae7441..0e092e0e37ccf 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2229,7 +2229,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses From b6f837c9d3cb0f76f0a52df37e34aea8944f6867 Mon Sep 17 00:00:00 2001 From: DylanGuedes Date: Sat, 10 Mar 2018 19:48:29 +0900 Subject: [PATCH 462/774] [PYTHON] Changes input variable to not conflict with built-in function Signed-off-by: DylanGuedes ## What changes were proposed in this pull request? Changes variable name conflict: [input is a built-in python function](https://stackoverflow.com/questions/20670732/is-input-a-keyword-in-python). ## How was this patch tested? I runned the example and it works fine. Author: DylanGuedes Closes #20775 from DylanGuedes/input_variable. --- examples/src/main/python/ml/dataframe_example.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/src/main/python/ml/dataframe_example.py b/examples/src/main/python/ml/dataframe_example.py index d62cf2338a1fe..cabc3de68f2f4 100644 --- a/examples/src/main/python/ml/dataframe_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -17,7 +17,7 @@ """ An example of how to use DataFrame for ML. Run with:: - bin/spark-submit examples/src/main/python/ml/dataframe_example.py + bin/spark-submit examples/src/main/python/ml/dataframe_example.py """ from __future__ import print_function @@ -35,18 +35,18 @@ print("Usage: dataframe_example.py ", file=sys.stderr) sys.exit(-1) elif len(sys.argv) == 2: - input = sys.argv[1] + input_path = sys.argv[1] else: - input = "data/mllib/sample_libsvm_data.txt" + input_path = "data/mllib/sample_libsvm_data.txt" spark = SparkSession \ .builder \ .appName("DataFrameExample") \ .getOrCreate() - # Load input data - print("Loading LIBSVM file with UDT from " + input + ".") - df = spark.read.format("libsvm").load(input).cache() + # Load an input file + print("Loading LIBSVM file with UDT from " + input_path + ".") + df = spark.read.format("libsvm").load(input_path).cache() print("Schema from LIBSVM:") df.printSchema() print("Loaded training data as a DataFrame with " + From b304e07e0671faf96530f9d8f49c55a83b07fa15 Mon Sep 17 00:00:00 2001 From: Xiayun Sun Date: Mon, 12 Mar 2018 22:13:28 +0900 Subject: [PATCH 463/774] [SPARK-23462][SQL] improve missing field error message in `StructType` ## What changes were proposed in this pull request? The error message ```s"""Field "$name" does not exist."""``` is thrown when looking up an unknown field in StructType. In the error message, we should also contain the information about which columns/fields exist in this struct. ## How was this patch tested? Added new unit tests. Note: I created a new `StructTypeSuite.scala` as I couldn't find an existing suite that's suitable to place these tests. I may be missing something so feel free to propose new locations. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Xiayun Sun Closes #20649 from xysun/SPARK-23462. --- .../apache/spark/sql/types/StructType.scala | 11 +++-- .../spark/sql/types/StructTypeSuite.scala | 40 +++++++++++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d5011c3cb87e9..362676b252126 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -271,7 +271,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru */ def apply(name: String): StructField = { nameToField.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + throw new IllegalArgumentException( + s"""Field "$name" does not exist. + |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)) } /** @@ -284,7 +286,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru val nonExistFields = names -- fieldNamesSet if (nonExistFields.nonEmpty) { throw new IllegalArgumentException( - s"Field ${nonExistFields.mkString(",")} does not exist.") + s"""Nonexistent field(s): ${nonExistFields.mkString(", ")}. + |Available fields: ${fieldNames.mkString(", ")}""".stripMargin) } // Preserve the original order of fields. StructType(fields.filter(f => names.contains(f.name))) @@ -297,7 +300,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru */ def fieldIndex(name: String): Int = { nameToIndex.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + throw new IllegalArgumentException( + s"""Field "$name" does not exist. + |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)) } private[sql] def getFieldIndex(name: String): Option[Int] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala new file mode 100644 index 0000000000000..c6ca8bb005429 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.SparkFunSuite + +class StructTypeSuite extends SparkFunSuite { + + val s = StructType.fromDDL("a INT, b STRING") + + test("lookup a single missing field should output existing fields") { + val e = intercept[IllegalArgumentException](s("c")).getMessage + assert(e.contains("Available fields: a, b")) + } + + test("lookup a set of missing fields should output existing fields") { + val e = intercept[IllegalArgumentException](s(Set("a", "c"))).getMessage + assert(e.contains("Available fields: a, b")) + } + + test("lookup fieldIndex for missing field should output existing fields") { + val e = intercept[IllegalArgumentException](s.fieldIndex("c")).getMessage + assert(e.contains("Available fields: a, b")) + } +} From d5b41aea62201cd5b1baad2f68f5fc7eb99c62c5 Mon Sep 17 00:00:00 2001 From: Jooseong Kim Date: Mon, 12 Mar 2018 11:31:34 -0700 Subject: [PATCH 464/774] [SPARK-23618][K8S][BUILD] Initialize BUILD_ARGS in docker-image-tool.sh ## What changes were proposed in this pull request? This change initializes BUILD_ARGS to an empty array when $SPARK_HOME/RELEASE exists. In function build, "local BUILD_ARGS" effectively creates an array of one element where the first and only element is an empty string, so "${BUILD_ARGS[]}" expands to "" and passes an extra argument to docker. Setting BUILD_ARGS to an empty array makes "${BUILD_ARGS[]}" expand to nothing. ## How was this patch tested? Manually tested. $ cat RELEASE Spark 2.3.0 (git revision a0d7949896) built for Hadoop 2.7.3 Build flags: -Phadoop-2.7 -Phive -Phive-thriftserver -Pkafka-0-8 -Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr -DzincPort=3036 $ ./bin/docker-image-tool.sh -m t testing build Sending build context to Docker daemon 256.4MB ... vanzin Author: Jooseong Kim Closes #20791 from jooseong/SPARK-23618. --- bin/docker-image-tool.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index 071406336d1b1..0d0f564bb8b9b 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -57,6 +57,7 @@ function build { else # Not passed as an argument to docker, but used to validate the Spark directory. IMG_PATH="kubernetes/dockerfiles" + BUILD_ARGS=() fi if [ ! -d "$IMG_PATH" ]; then From 567bd31e0ae8b632357baa93e1469b666fb06f3d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 12 Mar 2018 14:53:15 -0500 Subject: [PATCH 465/774] [SPARK-23412][ML] Add cosine distance to BisectingKMeans ## What changes were proposed in this pull request? The PR adds the option to specify a distance measure in BisectingKMeans. Moreover, it introduces the ability to use the cosine distance measure in it. ## How was this patch tested? added UTs + existing UTs Author: Marco Gaido Closes #20600 from mgaido91/SPARK-23412. --- .../spark/ml/clustering/BisectingKMeans.scala | 16 +- .../apache/spark/ml/clustering/KMeans.scala | 11 +- .../ml/param/shared/SharedParamsCodeGen.scala | 6 +- .../spark/ml/param/shared/sharedParams.scala | 19 ++ .../mllib/clustering/BisectingKMeans.scala | 139 ++++---- .../clustering/BisectingKMeansModel.scala | 115 +++++-- .../mllib/clustering/DistanceMeasure.scala | 303 ++++++++++++++++++ .../spark/mllib/clustering/KMeans.scala | 196 +---------- .../ml/clustering/BisectingKMeansSuite.scala | 44 ++- project/MimaExcludes.scala | 6 + 10 files changed, 557 insertions(+), 298 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 4c20e6563bad1..f7c422dc0faea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -26,7 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} +import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, + BisectingKMeansModel => MLlibBisectingKMeansModel} import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD @@ -38,8 +39,8 @@ import org.apache.spark.sql.types.{IntegerType, StructType} /** * Common params for BisectingKMeans and BisectingKMeansModel */ -private[clustering] trait BisectingKMeansParams extends Params - with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { +private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter + with HasFeaturesCol with HasSeed with HasPredictionCol with HasDistanceMeasure { /** * The desired number of leaf clusters. Must be > 1. Default: 4. @@ -104,6 +105,10 @@ class BisectingKMeansModel private[ml] ( @Since("2.1.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group expertSetParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -248,6 +253,10 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value) + /** @group expertSetParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): BisectingKMeansModel = { transformSchema(dataset.schema, logging = true) @@ -263,6 +272,7 @@ class BisectingKMeans @Since("2.0.0") ( .setMaxIterations($(maxIter)) .setMinDivisibleClusterSize($(minDivisibleClusterSize)) .setSeed($(seed)) + .setDistanceMeasure($(distanceMeasure)) val parentModel = bkm.run(rdd) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index c8145de564cbe..987a4285ebad4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -40,7 +40,7 @@ import org.apache.spark.util.VersionUtils.majorVersion * Common params for KMeans and KMeansModel */ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol - with HasSeed with HasPredictionCol with HasTol { + with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure { /** * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than @@ -71,15 +71,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitMode: String = $(initMode) - @Since("2.4.0") - final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " + - "Supported options: 'euclidean' and 'cosine'.", - (value: String) => MLlibKMeans.validateDistanceMeasure(value)) - - /** @group expertGetParam */ - @Since("2.4.0") - def getDistanceMeasure: String = $(distanceMeasure) - /** * Param for the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 2 is almost always enough. Must be > 0. Default: 2. diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 6ad44af9ef7eb..b9c3170cc3c28 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -91,7 +91,11 @@ private[shared] object SharedParamsCodeGen { "after fitting. If set to true, then all sub-models will be available. Warning: For " + "large models, collecting all sub-models can cause OOMs on the Spark driver", Some("false"), isExpertParam = true), - ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false) + ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false), + ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" + + " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"), + isValid = "(value: String) => " + + "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)") ) val code = genSharedParams(params) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index be8b2f273164b..282ea6ebcbf7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -504,4 +504,23 @@ trait HasLoss extends Params { /** @group getParam */ final def getLoss: String = $(loss) } + +/** + * Trait for shared param distanceMeasure (default: org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN). This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasDistanceMeasure extends Params { + + /** + * Param for The distance measure. Supported options: 'euclidean' and 'cosine'. + * @group param + */ + final val distanceMeasure: Param[String] = new Param[String](this, "distanceMeasure", "The distance measure. Supported options: 'euclidean' and 'cosine'", (value: String) => org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)) + + setDefault(distanceMeasure, org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN) + + /** @group getParam */ + final def getDistanceMeasure: String = $(distanceMeasure) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 2221f4c0edc17..98af487306dcc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -25,7 +25,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging -import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -57,7 +57,8 @@ class BisectingKMeans private ( private var k: Int, private var maxIterations: Int, private var minDivisibleClusterSize: Double, - private var seed: Long) extends Logging { + private var seed: Long, + private var distanceMeasure: String) extends Logging { import BisectingKMeans._ @@ -65,7 +66,7 @@ class BisectingKMeans private ( * Constructs with the default configuration */ @Since("1.6.0") - def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##) + def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##, DistanceMeasure.EUCLIDEAN) /** * Sets the desired number of leaf clusters (default: 4). @@ -134,6 +135,22 @@ class BisectingKMeans private ( @Since("1.6.0") def getSeed: Long = this.seed + /** + * The distance suite used by the algorithm. + */ + @Since("2.4.0") + def getDistanceMeasure: String = distanceMeasure + + /** + * Set the distance suite used by the algorithm. + */ + @Since("2.4.0") + def setDistanceMeasure(distanceMeasure: String): this.type = { + DistanceMeasure.validateDistanceMeasure(distanceMeasure) + this.distanceMeasure = distanceMeasure + this + } + /** * Runs the bisecting k-means algorithm. * @param input RDD of vectors @@ -147,11 +164,13 @@ class BisectingKMeans private ( } val d = input.map(_.size).first() logInfo(s"Feature dimension: $d.") + + val dMeasure: DistanceMeasure = DistanceMeasure.decodeFromString(this.distanceMeasure) // Compute and cache vector norms for fast distance computation. val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK) val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) } var assignments = vectors.map(v => (ROOT_INDEX, v)) - var activeClusters = summarize(d, assignments) + var activeClusters = summarize(d, assignments, dMeasure) val rootSummary = activeClusters(ROOT_INDEX) val n = rootSummary.size logInfo(s"Number of points: $n.") @@ -184,24 +203,25 @@ class BisectingKMeans private ( val divisibleIndices = divisibleClusters.keys.toSet logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.") var newClusterCenters = divisibleClusters.flatMap { case (index, summary) => - val (left, right) = splitCenter(summary.center, random) + val (left, right) = splitCenter(summary.center, random, dMeasure) Iterator((leftChildIndex(index), left), (rightChildIndex(index), right)) }.map(identity) // workaround for a Scala bug (SI-7005) that produces a not serializable map var newClusters: Map[Long, ClusterSummary] = null var newAssignments: RDD[(Long, VectorWithNorm)] = null for (iter <- 0 until maxIterations) { - newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters) + newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters, + dMeasure) .filter { case (index, _) => divisibleIndices.contains(parentIndex(index)) } - newClusters = summarize(d, newAssignments) + newClusters = summarize(d, newAssignments, dMeasure) newClusterCenters = newClusters.mapValues(_.center).map(identity) } if (preIndices != null) { preIndices.unpersist(false) } preIndices = indices - indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys + indices = updateAssignments(assignments, divisibleIndices, newClusterCenters, dMeasure).keys .persist(StorageLevel.MEMORY_AND_DISK) assignments = indices.zip(vectors) inactiveClusters ++= activeClusters @@ -222,8 +242,8 @@ class BisectingKMeans private ( } norms.unpersist(false) val clusters = activeClusters ++ inactiveClusters - val root = buildTree(clusters) - new BisectingKMeansModel(root) + val root = buildTree(clusters, dMeasure) + new BisectingKMeansModel(root, this.distanceMeasure) } /** @@ -266,8 +286,9 @@ private object BisectingKMeans extends Serializable { */ private def summarize( d: Int, - assignments: RDD[(Long, VectorWithNorm)]): Map[Long, ClusterSummary] = { - assignments.aggregateByKey(new ClusterSummaryAggregator(d))( + assignments: RDD[(Long, VectorWithNorm)], + distanceMeasure: DistanceMeasure): Map[Long, ClusterSummary] = { + assignments.aggregateByKey(new ClusterSummaryAggregator(d, distanceMeasure))( seqOp = (agg, v) => agg.add(v), combOp = (agg1, agg2) => agg1.merge(agg2) ).mapValues(_.summary) @@ -278,7 +299,8 @@ private object BisectingKMeans extends Serializable { * Cluster summary aggregator. * @param d feature dimension */ - private class ClusterSummaryAggregator(val d: Int) extends Serializable { + private class ClusterSummaryAggregator(val d: Int, val distanceMeasure: DistanceMeasure) + extends Serializable { private var n: Long = 0L private val sum: Vector = Vectors.zeros(d) private var sumSq: Double = 0.0 @@ -288,7 +310,7 @@ private object BisectingKMeans extends Serializable { n += 1L // TODO: use a numerically stable approach to estimate cost sumSq += v.norm * v.norm - BLAS.axpy(1.0, v.vector, sum) + distanceMeasure.updateClusterSum(v, sum) this } @@ -296,19 +318,15 @@ private object BisectingKMeans extends Serializable { def merge(other: ClusterSummaryAggregator): this.type = { n += other.n sumSq += other.sumSq - BLAS.axpy(1.0, other.sum, sum) + distanceMeasure.updateClusterSum(new VectorWithNorm(other.sum), sum) this } /** Returns the summary. */ def summary: ClusterSummary = { - val mean = sum.copy - if (n > 0L) { - BLAS.scal(1.0 / n, mean) - } - val center = new VectorWithNorm(mean) - val cost = math.max(sumSq - n * center.norm * center.norm, 0.0) - new ClusterSummary(n, center, cost) + val center = distanceMeasure.centroid(sum.copy, n) + val cost = distanceMeasure.clusterCost(center, new VectorWithNorm(sum), n, sumSq) + ClusterSummary(n, center, cost) } } @@ -321,16 +339,13 @@ private object BisectingKMeans extends Serializable { */ private def splitCenter( center: VectorWithNorm, - random: Random): (VectorWithNorm, VectorWithNorm) = { + random: Random, + distanceMeasure: DistanceMeasure): (VectorWithNorm, VectorWithNorm) = { val d = center.vector.size val norm = center.norm val level = 1e-4 * norm val noise = Vectors.dense(Array.fill(d)(random.nextDouble())) - val left = center.vector.copy - BLAS.axpy(-level, noise, left) - val right = center.vector.copy - BLAS.axpy(level, noise, right) - (new VectorWithNorm(left), new VectorWithNorm(right)) + distanceMeasure.symmetricCentroids(level, noise, center.vector) } /** @@ -343,16 +358,20 @@ private object BisectingKMeans extends Serializable { private def updateAssignments( assignments: RDD[(Long, VectorWithNorm)], divisibleIndices: Set[Long], - newClusterCenters: Map[Long, VectorWithNorm]): RDD[(Long, VectorWithNorm)] = { + newClusterCenters: Map[Long, VectorWithNorm], + distanceMeasure: DistanceMeasure): RDD[(Long, VectorWithNorm)] = { assignments.map { case (index, v) => if (divisibleIndices.contains(index)) { val children = Seq(leftChildIndex(index), rightChildIndex(index)) - val newClusterChildren = children.filter(newClusterCenters.contains(_)) + val newClusterChildren = children.filter(newClusterCenters.contains) + val newClusterChildrenCenterToId = + newClusterChildren.map(id => newClusterCenters(id) -> id).toMap + val newClusterChildrenCenters = newClusterChildrenCenterToId.keys.toArray if (newClusterChildren.nonEmpty) { - val selected = newClusterChildren.minBy { child => - EuclideanDistanceMeasure.fastSquaredDistance(newClusterCenters(child), v) - } - (selected, v) + val selected = distanceMeasure.findClosest(newClusterChildrenCenters, v)._1 + val center = newClusterChildrenCenters(selected) + val id = newClusterChildrenCenterToId(center) + (id, v) } else { (index, v) } @@ -367,7 +386,9 @@ private object BisectingKMeans extends Serializable { * @param clusters a map from cluster indices to corresponding cluster summaries * @return the root node of the clustering tree */ - private def buildTree(clusters: Map[Long, ClusterSummary]): ClusteringTreeNode = { + private def buildTree( + clusters: Map[Long, ClusterSummary], + distanceMeasure: DistanceMeasure): ClusteringTreeNode = { var leafIndex = 0 var internalIndex = -1 @@ -385,11 +406,11 @@ private object BisectingKMeans extends Serializable { internalIndex -= 1 val leftIndex = leftChildIndex(rawIndex) val rightIndex = rightChildIndex(rawIndex) - val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_)) - val height = math.sqrt(indexes.map { childIndex => - EuclideanDistanceMeasure.fastSquaredDistance(center, clusters(childIndex).center) - }.max) - val children = indexes.map(buildSubTree(_)).toArray + val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains) + val height = indexes.map { childIndex => + distanceMeasure.distance(center, clusters(childIndex).center) + }.max + val children = indexes.map(buildSubTree).toArray new ClusteringTreeNode(index, size, center, cost, height, children) } else { val index = leafIndex @@ -441,42 +462,45 @@ private[clustering] class ClusteringTreeNode private[clustering] ( def center: Vector = centerWithNorm.vector /** Predicts the leaf cluster node index that the input point belongs to. */ - def predict(point: Vector): Int = { - val (index, _) = predict(new VectorWithNorm(point)) + def predict(point: Vector, distanceMeasure: DistanceMeasure): Int = { + val (index, _) = predict(new VectorWithNorm(point), distanceMeasure) index } /** Returns the full prediction path from root to leaf. */ - def predictPath(point: Vector): Array[ClusteringTreeNode] = { - predictPath(new VectorWithNorm(point)).toArray + def predictPath(point: Vector, distanceMeasure: DistanceMeasure): Array[ClusteringTreeNode] = { + predictPath(new VectorWithNorm(point), distanceMeasure).toArray } /** Returns the full prediction path from root to leaf. */ - private def predictPath(pointWithNorm: VectorWithNorm): List[ClusteringTreeNode] = { + private def predictPath( + pointWithNorm: VectorWithNorm, + distanceMeasure: DistanceMeasure): List[ClusteringTreeNode] = { if (isLeaf) { this :: Nil } else { val selected = children.minBy { child => - EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm) + distanceMeasure.distance(child.centerWithNorm, pointWithNorm) } - selected :: selected.predictPath(pointWithNorm) + selected :: selected.predictPath(pointWithNorm, distanceMeasure) } } /** - * Computes the cost (squared distance to the predicted leaf cluster center) of the input point. + * Computes the cost of the input point. */ - def computeCost(point: Vector): Double = { - val (_, cost) = predict(new VectorWithNorm(point)) + def computeCost(point: Vector, distanceMeasure: DistanceMeasure): Double = { + val (_, cost) = predict(new VectorWithNorm(point), distanceMeasure) cost } /** * Predicts the cluster index and the cost of the input point. */ - private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = { - predict(pointWithNorm, - EuclideanDistanceMeasure.fastSquaredDistance(centerWithNorm, pointWithNorm)) + private def predict( + pointWithNorm: VectorWithNorm, + distanceMeasure: DistanceMeasure): (Int, Double) = { + predict(pointWithNorm, distanceMeasure.cost(centerWithNorm, pointWithNorm), distanceMeasure) } /** @@ -486,14 +510,17 @@ private[clustering] class ClusteringTreeNode private[clustering] ( * @return (predicted leaf cluster index, cost) */ @tailrec - private def predict(pointWithNorm: VectorWithNorm, cost: Double): (Int, Double) = { + private def predict( + pointWithNorm: VectorWithNorm, + cost: Double, + distanceMeasure: DistanceMeasure): (Int, Double) = { if (isLeaf) { (index, cost) } else { val (selectedChild, minCost) = children.map { child => - (child, EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) + (child, distanceMeasure.cost(child.centerWithNorm, pointWithNorm)) }.minBy(_._2) - selectedChild.predict(pointWithNorm, minCost) + selectedChild.predict(pointWithNorm, minCost, distanceMeasure) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index 633bda6aac804..9d115afcea75d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -40,9 +40,16 @@ import org.apache.spark.sql.{Row, SparkSession} */ @Since("1.6.0") class BisectingKMeansModel private[clustering] ( - private[clustering] val root: ClusteringTreeNode + private[clustering] val root: ClusteringTreeNode, + @Since("2.4.0") val distanceMeasure: String ) extends Serializable with Saveable with Logging { + @Since("1.6.0") + def this(root: ClusteringTreeNode) = this(root, DistanceMeasure.EUCLIDEAN) + + private val distanceMeasureInstance: DistanceMeasure = + DistanceMeasure.decodeFromString(distanceMeasure) + /** * Leaf cluster centers. */ @@ -59,7 +66,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def predict(point: Vector): Int = { - root.predict(point) + root.predict(point, distanceMeasureInstance) } /** @@ -67,7 +74,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def predict(points: RDD[Vector]): RDD[Int] = { - points.map { p => root.predict(p) } + points.map { p => root.predict(p, distanceMeasureInstance) } } /** @@ -82,7 +89,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def computeCost(point: Vector): Double = { - root.computeCost(point) + root.computeCost(point, distanceMeasureInstance) } /** @@ -91,7 +98,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def computeCost(data: RDD[Vector]): Double = { - data.map(root.computeCost).sum() + data.map(root.computeCost(_, distanceMeasureInstance)).sum() } /** @@ -113,18 +120,19 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { @Since("2.0.0") override def load(sc: SparkContext, path: String): BisectingKMeansModel = { - val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path) - implicit val formats = DefaultFormats - val rootId = (metadata \ "rootId").extract[Int] - val classNameV1_0 = SaveLoadV1_0.thisClassName + val (loadedClassName, formatVersion, __) = Loader.loadMetadata(sc, path) (loadedClassName, formatVersion) match { - case (classNameV1_0, "1.0") => - val model = SaveLoadV1_0.load(sc, path, rootId) + case (SaveLoadV1_0.thisClassName, SaveLoadV1_0.thisFormatVersion) => + val model = SaveLoadV1_0.load(sc, path) + model + case (SaveLoadV2_0.thisClassName, SaveLoadV2_0.thisFormatVersion) => + val model = SaveLoadV1_0.load(sc, path) model case _ => throw new Exception( s"BisectingKMeansModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $formatVersion). Supported:\n" + - s" ($classNameV1_0, 1.0)") + s" (${SaveLoadV1_0.thisClassName}, ${SaveLoadV1_0.thisClassName}\n" + + s" (${SaveLoadV2_0.thisClassName}, ${SaveLoadV2_0.thisClassName})") } } @@ -136,8 +144,28 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { r.getDouble(4), r.getDouble(5), r.getSeq[Int](6)) } + private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = { + if (node.children.isEmpty) { + Array(node) + } else { + node.children.flatMap(getNodes) ++ Array(node) + } + } + + private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = { + val root = nodes(rootId) + if (root.children.isEmpty) { + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, new Array[ClusteringTreeNode](0)) + } else { + val children = root.children.map(c => buildTree(c, nodes)) + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, children.toArray) + } + } + private[clustering] object SaveLoadV1_0 { - private val thisFormatVersion = "1.0" + private[clustering] val thisFormatVersion = "1.0" private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" @@ -155,34 +183,55 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) } - private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = { - if (node.children.isEmpty) { - Array(node) - } else { - node.children.flatMap(getNodes(_)) ++ Array(node) - } - } - - def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = { + def load(sc: SparkContext, path: String): BisectingKMeansModel = { + implicit val formats: DefaultFormats = DefaultFormats + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rootId = (metadata \ "rootId").extract[Int] val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val rows = spark.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Data](rows.schema) val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap val rootNode = buildTree(rootId, nodes) - new BisectingKMeansModel(rootNode) + new BisectingKMeansModel(rootNode, DistanceMeasure.EUCLIDEAN) } + } + + private[clustering] object SaveLoadV2_0 { + private[clustering] val thisFormatVersion = "2.0" - private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = { - val root = nodes.get(rootId).get - if (root.children.isEmpty) { - new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), - root.cost, root.height, new Array[ClusteringTreeNode](0)) - } else { - val children = root.children.map(c => buildTree(c, nodes)) - new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), - root.cost, root.height, children.toArray) - } + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" + + def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) + ~ ("rootId" -> model.root.index) ~ ("distanceMeasure" -> model.distanceMeasure))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val data = getNodes(model.root).map(node => Data(node.index, node.size, + node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height, + node.children.map(_.index))) + spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): BisectingKMeansModel = { + implicit val formats: DefaultFormats = DefaultFormats + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rootId = (metadata \ "rootId").extract[Int] + val distanceMeasure = (metadata \ "distanceMeasure").extract[String] + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val rows = spark.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Data](rows.schema) + val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") + val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap + val rootNode = buildTree(rootId, nodes) + new BisectingKMeansModel(rootNode, distanceMeasure) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala new file mode 100644 index 0000000000000..683360efabc76 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} +import org.apache.spark.mllib.util.MLUtils + +private[spark] abstract class DistanceMeasure extends Serializable { + + /** + * @return the index of the closest center to the given point, as well as the cost. + */ + def findClosest( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + val currentDistance = distance(center, point) + if (currentDistance < bestDistance) { + bestDistance = currentDistance + bestIndex = i + } + i += 1 + } + (bestIndex, bestDistance) + } + + /** + * @return the K-means cost of a given point against the given cluster centers. + */ + def pointCost( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): Double = { + findClosest(centers, point)._2 + } + + /** + * @return whether a center converged or not, given the epsilon parameter. + */ + def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + distance(oldCenter, newCenter) <= epsilon + } + + /** + * @return the distance between two points. + */ + def distance( + v1: VectorWithNorm, + v2: VectorWithNorm): Double + + /** + * @return the total cost of the cluster from its aggregated properties + */ + def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double + + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + axpy(1.0, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + new VectorWithNorm(sum) + } + + /** + * Returns two new centroids symmetric to the specified centroid applying `noise` with the + * with the specified `level`. + * + * @param level the level of `noise` to apply to the given centroid. + * @param noise a noise vector + * @param centroid the parent centroid + * @return a left and right centroid symmetric to `centroid` + */ + def symmetricCentroids( + level: Double, + noise: Vector, + centroid: Vector): (VectorWithNorm, VectorWithNorm) = { + val left = centroid.copy + axpy(-level, noise, left) + val right = centroid.copy + axpy(level, noise, right) + (new VectorWithNorm(left), new VectorWithNorm(right)) + } + + /** + * @return the cost of a point to be assigned to the cluster centroid + */ + def cost( + point: VectorWithNorm, + centroid: VectorWithNorm): Double = distance(point, centroid) +} + +@Since("2.4.0") +object DistanceMeasure { + + @Since("2.4.0") + val EUCLIDEAN = "euclidean" + @Since("2.4.0") + val COSINE = "cosine" + + private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure = + distanceMeasure match { + case EUCLIDEAN => new EuclideanDistanceMeasure + case COSINE => new CosineDistanceMeasure + case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " + + s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.") + } + + private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = { + distanceMeasure match { + case DistanceMeasure.EUCLIDEAN => true + case DistanceMeasure.COSINE => true + case _ => false + } + } +} + +private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { + /** + * @return the index of the closest center to the given point, as well as the squared distance. + */ + override def findClosest( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary + // distance computation. + var lowerBoundOfSqDist = center.norm - point.norm + lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist + if (lowerBoundOfSqDist < bestDistance) { + val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point) + if (distance < bestDistance) { + bestDistance = distance + bestIndex = i + } + } + i += 1 + } + (bestIndex, bestDistance) + } + + /** + * @return whether a center converged or not, given the epsilon parameter. + */ + override def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon + } + + /** + * @param v1: first vector + * @param v2: second vector + * @return the Euclidean distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2)) + } + + /** + * @return the total cost of the cluster from its aggregated properties + */ + override def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double = { + math.max(pointsSquaredNorm - numberOfPoints * centroid.norm * centroid.norm, 0.0) + } + + /** + * @return the cost of a point to be assigned to the cluster centroid + */ + override def cost( + point: VectorWithNorm, + centroid: VectorWithNorm): Double = { + EuclideanDistanceMeasure.fastSquaredDistance(point, centroid) + } +} + + +private[spark] object EuclideanDistanceMeasure { + /** + * @return the squared Euclidean distance between two vectors computed by + * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]]. + */ + private[clustering] def fastSquaredDistance( + v1: VectorWithNorm, + v2: VectorWithNorm): Double = { + MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) + } +} + +private[spark] class CosineDistanceMeasure extends DistanceMeasure { + /** + * @param v1: first vector + * @param v2: second vector + * @return the cosine distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.") + 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm + } + + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + assert(point.norm > 0, "Cosine distance is not defined for zero-length vectors.") + axpy(1.0 / point.norm, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + override def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + val norm = Vectors.norm(sum, 2) + scal(1.0 / norm, sum) + new VectorWithNorm(sum, 1) + } + + /** + * @return the total cost of the cluster from its aggregated properties + */ + override def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double = { + val costVector = pointsSum.vector.copy + math.max(numberOfPoints - dot(centroid.vector, costVector) / centroid.norm, 0.0) + } + + /** + * Returns two new centroids symmetric to the specified centroid applying `noise` with the + * with the specified `level`. + * + * @param level the level of `noise` to apply to the given centroid. + * @param noise a noise vector + * @param centroid the parent centroid + * @return a left and right centroid symmetric to `centroid` + */ + override def symmetricCentroids( + level: Double, + noise: Vector, + centroid: Vector): (VectorWithNorm, VectorWithNorm) = { + val (left, right) = super.symmetricCentroids(level, noise, centroid) + val leftVector = left.vector + val rightVector = right.vector + scal(1.0 / left.norm, leftVector) + scal(1.0 / right.norm, rightVector) + (new VectorWithNorm(leftVector, 1.0), new VectorWithNorm(rightVector, 1.0)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 3c4ba0bc60c7f..b5b1be3490497 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -25,8 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.clustering.{KMeans => NewKMeans} import org.apache.spark.ml.util.Instrumentation import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.linalg.BLAS.axpy import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -204,7 +203,7 @@ class KMeans private ( */ @Since("2.4.0") def setDistanceMeasure(distanceMeasure: String): this.type = { - KMeans.validateDistanceMeasure(distanceMeasure) + DistanceMeasure.validateDistanceMeasure(distanceMeasure) this.distanceMeasure = distanceMeasure this } @@ -582,14 +581,6 @@ object KMeans { case _ => false } } - - private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = { - distanceMeasure match { - case DistanceMeasure.EUCLIDEAN => true - case DistanceMeasure.COSINE => true - case _ => false - } - } } /** @@ -605,186 +596,3 @@ private[clustering] class VectorWithNorm(val vector: Vector, val norm: Double) /** Converts the vector to a dense vector. */ def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm) } - - -private[spark] abstract class DistanceMeasure extends Serializable { - - /** - * @return the index of the closest center to the given point, as well as the cost. - */ - def findClosest( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): (Int, Double) = { - var bestDistance = Double.PositiveInfinity - var bestIndex = 0 - var i = 0 - centers.foreach { center => - val currentDistance = distance(center, point) - if (currentDistance < bestDistance) { - bestDistance = currentDistance - bestIndex = i - } - i += 1 - } - (bestIndex, bestDistance) - } - - /** - * @return the K-means cost of a given point against the given cluster centers. - */ - def pointCost( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): Double = { - findClosest(centers, point)._2 - } - - /** - * @return whether a center converged or not, given the epsilon parameter. - */ - def isCenterConverged( - oldCenter: VectorWithNorm, - newCenter: VectorWithNorm, - epsilon: Double): Boolean = { - distance(oldCenter, newCenter) <= epsilon - } - - /** - * @return the cosine distance between two points. - */ - def distance( - v1: VectorWithNorm, - v2: VectorWithNorm): Double - - /** - * Updates the value of `sum` adding the `point` vector. - * @param point a `VectorWithNorm` to be added to `sum` of a cluster - * @param sum the `sum` for a cluster to be updated - */ - def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { - axpy(1.0, point.vector, sum) - } - - /** - * Returns a centroid for a cluster given its `sum` vector and its `count` of points. - * - * @param sum the `sum` for a cluster - * @param count the number of points in the cluster - * @return the centroid of the cluster - */ - def centroid(sum: Vector, count: Long): VectorWithNorm = { - scal(1.0 / count, sum) - new VectorWithNorm(sum) - } -} - -@Since("2.4.0") -object DistanceMeasure { - - @Since("2.4.0") - val EUCLIDEAN = "euclidean" - @Since("2.4.0") - val COSINE = "cosine" - - private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure = - distanceMeasure match { - case EUCLIDEAN => new EuclideanDistanceMeasure - case COSINE => new CosineDistanceMeasure - case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " + - s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.") - } -} - -private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { - /** - * @return the index of the closest center to the given point, as well as the squared distance. - */ - override def findClosest( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): (Int, Double) = { - var bestDistance = Double.PositiveInfinity - var bestIndex = 0 - var i = 0 - centers.foreach { center => - // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary - // distance computation. - var lowerBoundOfSqDist = center.norm - point.norm - lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist - if (lowerBoundOfSqDist < bestDistance) { - val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point) - if (distance < bestDistance) { - bestDistance = distance - bestIndex = i - } - } - i += 1 - } - (bestIndex, bestDistance) - } - - /** - * @return whether a center converged or not, given the epsilon parameter. - */ - override def isCenterConverged( - oldCenter: VectorWithNorm, - newCenter: VectorWithNorm, - epsilon: Double): Boolean = { - EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon - } - - /** - * @param v1: first vector - * @param v2: second vector - * @return the Euclidean distance between the two input vectors - */ - override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { - Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2)) - } -} - - -private[spark] object EuclideanDistanceMeasure { - /** - * @return the squared Euclidean distance between two vectors computed by - * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]]. - */ - private[clustering] def fastSquaredDistance( - v1: VectorWithNorm, - v2: VectorWithNorm): Double = { - MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) - } -} - -private[spark] class CosineDistanceMeasure extends DistanceMeasure { - /** - * @param v1: first vector - * @param v2: second vector - * @return the cosine distance between the two input vectors - */ - override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { - assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.") - 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm - } - - /** - * Updates the value of `sum` adding the `point` vector. - * @param point a `VectorWithNorm` to be added to `sum` of a cluster - * @param sum the `sum` for a cluster to be updated - */ - override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { - axpy(1.0 / point.norm, point.vector, sum) - } - - /** - * Returns a centroid for a cluster given its `sum` vector and its `count` of points. - * - * @param sum the `sum` for a cluster - * @param count the number of points in the cluster - * @return the centroid of the cluster - */ - override def centroid(sum: Vector, count: Long): VectorWithNorm = { - scal(1.0 / count, sum) - val norm = Vectors.norm(sum, 2) - scal(1.0 / norm, sum) - new VectorWithNorm(sum, 1) - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index fa7471fa2d658..02880f96ae6d9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.ml.clustering -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.clustering.DistanceMeasure import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -140,6 +142,46 @@ class BisectingKMeansSuite testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, BisectingKMeansSuite.allParamSettings, checkModelData) } + + test("BisectingKMeans with cosine distance is not supported for 0-length vectors") { + val model = new BisectingKMeans().setK(2).setDistanceMeasure(DistanceMeasure.COSINE).setSeed(1) + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(0.0, 0.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5) + )).map(v => TestRow(v))) + val e = intercept[SparkException](model.fit(df)) + assert(e.getCause.isInstanceOf[AssertionError]) + assert(e.getCause.getMessage.contains("Cosine distance is not defined")) + } + + test("BisectingKMeans with cosine distance") { + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(1.0, 1.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5), + Vectors.dense(10.0, 4.4), + Vectors.dense(-1.0, 1.0), + Vectors.dense(-100.0, 90.0) + )).map(v => TestRow(v))) + val model = new BisectingKMeans() + .setK(3) + .setDistanceMeasure(DistanceMeasure.COSINE) + .setSeed(1) + .fit(df) + val predictionDf = model.transform(df) + assert(predictionDf.select("prediction").distinct().count() == 3) + val predictionsMap = predictionDf.collect().map(row => + row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap + assert(predictionsMap(Vectors.dense(1.0, 1.0)) == + predictionsMap(Vectors.dense(10.0, 10.0))) + assert(predictionsMap(Vectors.dense(1.0, 0.5)) == + predictionsMap(Vectors.dense(10.0, 4.4))) + assert(predictionsMap(Vectors.dense(-1.0, 1.0)) == + predictionsMap(Vectors.dense(-100.0, 90.0))) + + model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) + } } object BisectingKMeansSuite { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 381f7b5be1ddf..1b6d1dec69d49 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,12 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-23412][ML] Add cosine distance measure to BisectingKmeans + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.org$apache$spark$ml$param$shared$HasDistanceMeasure$_setter_$distanceMeasure_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.getDistanceMeasure"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.distanceMeasure"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel#SaveLoadV1_0.load"), + // [SPARK-20659] Remove StorageStatus, or make it private ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), From 23370554d0f88b82154d4232744b874cc58c7848 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 13 Mar 2018 15:20:09 +0100 Subject: [PATCH 466/774] [SPARK-23656][TEST] Perform assertions in XXH64Suite.testKnownByteArrayInputs() on big endian platform, too ## What changes were proposed in this pull request? This PR enables assertions in `XXH64Suite.testKnownByteArrayInputs()` on big endian platform, too. The current implementation performs them only on little endian platform. This PR increase test coverage of big endian platform. ## How was this patch tested? Updated `XXH64Suite` Tested on big endian platform using JIT compiler or interpreter `-Xint`. Author: Kazuaki Ishizaki Closes #20804 from kiszk/SPARK-23656. --- .../sql/catalyst/expressions/XXH64Suite.java | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java index 711887f02832a..1baee91b3439c 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java @@ -74,9 +74,6 @@ public void testKnownByteArrayInputs() { Assert.assertEquals(0x739840CB819FA723L, XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 1, PRIME)); - // These tests currently fail in a big endian environment because the test data and expected - // answers are generated with little endian the assumptions. We could revisit this when Platform - // becomes endian aware. if (ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN) { Assert.assertEquals(0x9256E58AA397AEF1L, hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4)); @@ -94,6 +91,23 @@ public void testKnownByteArrayInputs() { hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE)); Assert.assertEquals(0xCAA65939306F1E21L, XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE, PRIME)); + } else { + Assert.assertEquals(0x7F875412350ADDDCL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4)); + Assert.assertEquals(0x564D279F524D8516L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4, PRIME)); + Assert.assertEquals(0x7D9F07E27E0EB006L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8)); + Assert.assertEquals(0x893CEF564CB7858L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8, PRIME)); + Assert.assertEquals(0xC6198C4C9CC49E17L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14)); + Assert.assertEquals(0x4E21BEF7164D4BBL, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14, PRIME)); + Assert.assertEquals(0xBCF5FAEDEE1F2B5AL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE)); + Assert.assertEquals(0x6F680C877A358FE5L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE, PRIME)); } } From 9ddd1e2ceac8155b30beebb6bbfdcd32296fab2d Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 13 Mar 2018 23:31:08 +0900 Subject: [PATCH 467/774] [MINOR][SQL][TEST] Create table using `dataSourceName` in `HadoopFsRelationTest` ## What changes were proposed in this pull request? This PR fixes a minor issue in `HadoopFsRelationTest`, that you should create table using `dataSourceName` instead of `parquet`. The issue won't affect the correctness, but it will generate wrong error message in case the test fails. ## How was this patch tested? Exsiting tests. Author: Xingbo Jiang Closes #20780 from jiangxb1987/dataSourceName. --- .../apache/spark/sql/sources/HadoopFsRelationTest.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 80aff446bc24b..53397991e59dc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -335,16 +335,17 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes test("saveAsTable()/load() - non-partitioned table - ErrorIfExists") { withTable("t") { - sql("CREATE TABLE t(i INT) USING parquet") - intercept[AnalysisException] { + sql(s"CREATE TABLE t(i INT) USING $dataSourceName") + val msg = intercept[AnalysisException] { testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).saveAsTable("t") - } + }.getMessage + assert(msg.contains("Table `t` already exists")) } } test("saveAsTable()/load() - non-partitioned table - Ignore") { withTable("t") { - sql("CREATE TABLE t(i INT) USING parquet") + sql(s"CREATE TABLE t(i INT) USING $dataSourceName") testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") assert(spark.table("t").collect().isEmpty) } From 918fb9beee6a2fd499b8f18dfe0d460f078f5290 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Tue, 13 Mar 2018 11:31:32 -0700 Subject: [PATCH 468/774] [SPARK-23547][SQL] Cleanup the .pipeout file when the Hive Session closed ## What changes were proposed in this pull request? ![2018-03-07_121010](https://user-images.githubusercontent.com/24823338/37073232-922e10d2-2200-11e8-8172-6e03aa984b39.png) when the hive session closed, we should also cleanup the .pipeout file. ## How was this patch tested? Added test cases. Author: zuotingbing Closes #20702 from zuotingbing/SPARK-23547. --- .../service/cli/session/HiveSessionImpl.java | 18 +++++++++++ .../HiveThriftServer2Suites.scala | 32 ++++++++++++++++++- 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index fc818bc69c761..f59cdcd3188e6 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -641,6 +641,8 @@ public void close() throws HiveSQLException { opHandleSet.clear(); // Cleanup session log directory. cleanupSessionLogDir(); + // Cleanup pipeout file. + cleanupPipeoutFile(); HiveHistory hiveHist = sessionState.getHiveHistory(); if (null != hiveHist) { hiveHist.closeStream(); @@ -665,6 +667,22 @@ public void close() throws HiveSQLException { } } + private void cleanupPipeoutFile() { + String lScratchDir = hiveConf.getVar(ConfVars.LOCALSCRATCHDIR); + String sessionID = hiveConf.getVar(ConfVars.HIVESESSIONID); + + File[] fileAry = new File(lScratchDir).listFiles( + (dir, name) -> name.startsWith(sessionID) && name.endsWith(".pipeout")); + + for (File file : fileAry) { + try { + FileUtils.forceDelete(file); + } catch (Exception e) { + LOG.error("Failed to cleanup pipeout file: " + file, e); + } + } + } + private void cleanupSessionLogDir() { if (isOperationLogEnabled) { try { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index b32c547cefefe..192f33a45e273 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.hive.thriftserver -import java.io.File +import java.io.{File, FilenameFilter} import java.net.URL import java.nio.charset.StandardCharsets import java.sql.{Date, DriverManager, SQLException, Statement} +import java.util.UUID import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -613,6 +614,28 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { bufferSrc.close() } } + + test("SPARK-23547 Cleanup the .pipeout file when the Hive Session closed") { + def pipeoutFileList(sessionID: UUID): Array[File] = { + lScratchDir.listFiles(new FilenameFilter { + override def accept(dir: File, name: String): Boolean = { + name.startsWith(sessionID.toString) && name.endsWith(".pipeout") + } + }) + } + + withCLIServiceClient { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + val sessionID = sessionHandle.getSessionId + + assert(pipeoutFileList(sessionID).length == 1) + + client.closeSession(sessionHandle) + + assert(pipeoutFileList(sessionID).length == 0) + } + } } class SingleSessionSuite extends HiveThriftJdbcTest { @@ -807,6 +830,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl private val pidDir: File = Utils.createTempDir(namePrefix = "thriftserver-pid") protected var logPath: File = _ protected var operationLogPath: File = _ + protected var lScratchDir: File = _ private var logTailingProcess: Process = _ private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] @@ -844,6 +868,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode | --hiveconf ${ConfVars.HIVE_SERVER2_LOGGING_OPERATION_LOG_LOCATION}=$operationLogPath + | --hiveconf ${ConfVars.LOCALSCRATCHDIR}=$lScratchDir | --hiveconf $portConf=$port | --driver-class-path $driverClassPath | --driver-java-options -Dlog4j.debug @@ -873,6 +898,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl metastorePath.delete() operationLogPath = Utils.createTempDir() operationLogPath.delete() + lScratchDir = Utils.createTempDir() + lScratchDir.delete() logPath = null logTailingProcess = null @@ -956,6 +983,9 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl operationLogPath.delete() operationLogPath = null + lScratchDir.delete() + lScratchDir = null + Option(logPath).foreach(_.delete()) logPath = null From 1098933b0ac5cdb18101d3aebefa773c2ce05a50 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 13 Mar 2018 23:04:16 +0100 Subject: [PATCH 469/774] [SPARK-23598][SQL] Make methods in BufferedRowIterator public to avoid runtime error for a large query ## What changes were proposed in this pull request? This PR fixes runtime error regarding a large query when a generated code has split classes. The issue is `append()`, `stopEarly()`, and other methods are not accessible from split classes that are not subclasses of `BufferedRowIterator`. This PR fixes this issue by making them `public`. Before applying the PR, we see the following exception by running the attached program with `CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD=-1`. ``` test("SPARK-23598") { // When set -1 to CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD, an exception is thrown val df_pet_age = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name") df_pet_age.groupBy("name").avg("age").show() } ``` Exception: ``` 19:40:52.591 WARN org.apache.hadoop.util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable 19:41:32.319 ERROR org.apache.spark.executor.Executor: Exception in task 0.0 in stage 0.0 (TID 0) java.lang.IllegalAccessError: tried to access method org.apache.spark.sql.execution.BufferedRowIterator.shouldStop()Z from class org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1$agg_NestedClass1 at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1$agg_NestedClass1.agg_doAggregateWithKeys$(generated.java:203) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(generated.java:160) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:616) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408) at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:125) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53) at org.apache.spark.scheduler.Task.run(Task.scala:109) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ... ``` Generated code (line 195 calles `stopEarly()`). ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage1(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=1 /* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; /* 008 */ private scala.collection.Iterator[] inputs; /* 009 */ private boolean agg_initAgg; /* 010 */ private boolean agg_bufIsNull; /* 011 */ private double agg_bufValue; /* 012 */ private boolean agg_bufIsNull1; /* 013 */ private long agg_bufValue1; /* 014 */ private agg_FastHashMap agg_fastHashMap; /* 015 */ private org.apache.spark.unsafe.KVIterator agg_fastHashMapIter; /* 016 */ private org.apache.spark.unsafe.KVIterator agg_mapIter; /* 017 */ private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap; /* 018 */ private org.apache.spark.sql.execution.UnsafeKVExternalSorter agg_sorter; /* 019 */ private scala.collection.Iterator inputadapter_input; /* 020 */ private boolean agg_agg_isNull11; /* 021 */ private boolean agg_agg_isNull25; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder[] agg_mutableStateArray1 = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder[2]; /* 023 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] agg_mutableStateArray2 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2]; /* 024 */ private UnsafeRow[] agg_mutableStateArray = new UnsafeRow[2]; /* 025 */ /* 026 */ public GeneratedIteratorForCodegenStage1(Object[] references) { /* 027 */ this.references = references; /* 028 */ } /* 029 */ /* 030 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 031 */ partitionIndex = index; /* 032 */ this.inputs = inputs; /* 033 */ /* 034 */ agg_fastHashMap = new agg_FastHashMap(((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).getTaskMemoryManager(), ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).getEmptyAggregationBuffer()); /* 035 */ agg_hashMap = ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).createHashMap(); /* 036 */ inputadapter_input = inputs[0]; /* 037 */ agg_mutableStateArray[0] = new UnsafeRow(1); /* 038 */ agg_mutableStateArray1[0] = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_mutableStateArray[0], 32); /* 039 */ agg_mutableStateArray2[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_mutableStateArray1[0], 1); /* 040 */ agg_mutableStateArray[1] = new UnsafeRow(3); /* 041 */ agg_mutableStateArray1[1] = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_mutableStateArray[1], 32); /* 042 */ agg_mutableStateArray2[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_mutableStateArray1[1], 3); /* 043 */ /* 044 */ } /* 045 */ /* 046 */ public class agg_FastHashMap { /* 047 */ private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch; /* 048 */ private int[] buckets; /* 049 */ private int capacity = 1 << 16; /* 050 */ private double loadFactor = 0.5; /* 051 */ private int numBuckets = (int) (capacity / loadFactor); /* 052 */ private int maxSteps = 2; /* 053 */ private int numRows = 0; /* 054 */ private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[1] /* keyName */), org.apache.spark.sql.types.DataTypes.StringType); /* 055 */ private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[2] /* keyName */), org.apache.spark.sql.types.DataTypes.DoubleType) /* 056 */ .add(((java.lang.String) references[3] /* keyName */), org.apache.spark.sql.types.DataTypes.LongType); /* 057 */ private Object emptyVBase; /* 058 */ private long emptyVOff; /* 059 */ private int emptyVLen; /* 060 */ private boolean isBatchFull = false; /* 061 */ /* 062 */ public agg_FastHashMap( /* 063 */ org.apache.spark.memory.TaskMemoryManager taskMemoryManager, /* 064 */ InternalRow emptyAggregationBuffer) { /* 065 */ batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch /* 066 */ .allocate(keySchema, valueSchema, taskMemoryManager, capacity); /* 067 */ /* 068 */ final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema); /* 069 */ final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); /* 070 */ /* 071 */ emptyVBase = emptyBuffer; /* 072 */ emptyVOff = Platform.BYTE_ARRAY_OFFSET; /* 073 */ emptyVLen = emptyBuffer.length; /* 074 */ /* 075 */ buckets = new int[numBuckets]; /* 076 */ java.util.Arrays.fill(buckets, -1); /* 077 */ } /* 078 */ /* 079 */ public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(UTF8String agg_key) { /* 080 */ long h = hash(agg_key); /* 081 */ int step = 0; /* 082 */ int idx = (int) h & (numBuckets - 1); /* 083 */ while (step < maxSteps) { /* 084 */ // Return bucket index if it's either an empty slot or already contains the key /* 085 */ if (buckets[idx] == -1) { /* 086 */ if (numRows < capacity && !isBatchFull) { /* 087 */ // creating the unsafe for new entry /* 088 */ UnsafeRow agg_result = new UnsafeRow(1); /* 089 */ org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder /* 090 */ = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, /* 091 */ 32); /* 092 */ org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter /* 093 */ = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( /* 094 */ agg_holder, /* 095 */ 1); /* 096 */ agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed /* 097 */ agg_rowWriter.zeroOutNullBytes(); /* 098 */ agg_rowWriter.write(0, agg_key); /* 099 */ agg_result.setTotalSize(agg_holder.totalSize()); /* 100 */ Object kbase = agg_result.getBaseObject(); /* 101 */ long koff = agg_result.getBaseOffset(); /* 102 */ int klen = agg_result.getSizeInBytes(); /* 103 */ /* 104 */ UnsafeRow vRow /* 105 */ = batch.appendRow(kbase, koff, klen, emptyVBase, emptyVOff, emptyVLen); /* 106 */ if (vRow == null) { /* 107 */ isBatchFull = true; /* 108 */ } else { /* 109 */ buckets[idx] = numRows++; /* 110 */ } /* 111 */ return vRow; /* 112 */ } else { /* 113 */ // No more space /* 114 */ return null; /* 115 */ } /* 116 */ } else if (equals(idx, agg_key)) { /* 117 */ return batch.getValueRow(buckets[idx]); /* 118 */ } /* 119 */ idx = (idx + 1) & (numBuckets - 1); /* 120 */ step++; /* 121 */ } /* 122 */ // Didn't find it /* 123 */ return null; /* 124 */ } /* 125 */ /* 126 */ private boolean equals(int idx, UTF8String agg_key) { /* 127 */ UnsafeRow row = batch.getKeyRow(buckets[idx]); /* 128 */ return (row.getUTF8String(0).equals(agg_key)); /* 129 */ } /* 130 */ /* 131 */ private long hash(UTF8String agg_key) { /* 132 */ long agg_hash = 0; /* 133 */ /* 134 */ int agg_result = 0; /* 135 */ byte[] agg_bytes = agg_key.getBytes(); /* 136 */ for (int i = 0; i < agg_bytes.length; i++) { /* 137 */ int agg_hash1 = agg_bytes[i]; /* 138 */ agg_result = (agg_result ^ (0x9e3779b9)) + agg_hash1 + (agg_result << 6) + (agg_result >>> 2); /* 139 */ } /* 140 */ /* 141 */ agg_hash = (agg_hash ^ (0x9e3779b9)) + agg_result + (agg_hash << 6) + (agg_hash >>> 2); /* 142 */ /* 143 */ return agg_hash; /* 144 */ } /* 145 */ /* 146 */ public org.apache.spark.unsafe.KVIterator rowIterator() { /* 147 */ return batch.rowIterator(); /* 148 */ } /* 149 */ /* 150 */ public void close() { /* 151 */ batch.close(); /* 152 */ } /* 153 */ /* 154 */ } /* 155 */ /* 156 */ protected void processNext() throws java.io.IOException { /* 157 */ if (!agg_initAgg) { /* 158 */ agg_initAgg = true; /* 159 */ long wholestagecodegen_beforeAgg = System.nanoTime(); /* 160 */ agg_nestedClassInstance1.agg_doAggregateWithKeys(); /* 161 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[8] /* aggTime */).add((System.nanoTime() - wholestagecodegen_beforeAgg) / 1000000); /* 162 */ } /* 163 */ /* 164 */ // output the result /* 165 */ /* 166 */ while (agg_fastHashMapIter.next()) { /* 167 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_fastHashMapIter.getKey(); /* 168 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_fastHashMapIter.getValue(); /* 169 */ wholestagecodegen_nestedClassInstance.agg_doAggregateWithKeysOutput(agg_aggKey, agg_aggBuffer); /* 170 */ /* 171 */ if (shouldStop()) return; /* 172 */ } /* 173 */ agg_fastHashMap.close(); /* 174 */ /* 175 */ while (agg_mapIter.next()) { /* 176 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 177 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue(); /* 178 */ wholestagecodegen_nestedClassInstance.agg_doAggregateWithKeysOutput(agg_aggKey, agg_aggBuffer); /* 179 */ /* 180 */ if (shouldStop()) return; /* 181 */ } /* 182 */ /* 183 */ agg_mapIter.close(); /* 184 */ if (agg_sorter == null) { /* 185 */ agg_hashMap.free(); /* 186 */ } /* 187 */ } /* 188 */ /* 189 */ private wholestagecodegen_NestedClass wholestagecodegen_nestedClassInstance = new wholestagecodegen_NestedClass(); /* 190 */ private agg_NestedClass1 agg_nestedClassInstance1 = new agg_NestedClass1(); /* 191 */ private agg_NestedClass agg_nestedClassInstance = new agg_NestedClass(); /* 192 */ /* 193 */ private class agg_NestedClass1 { /* 194 */ private void agg_doAggregateWithKeys() throws java.io.IOException { /* 195 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 196 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 197 */ int inputadapter_value = inputadapter_row.getInt(0); /* 198 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1); /* 199 */ UTF8String inputadapter_value1 = inputadapter_isNull1 ? /* 200 */ null : (inputadapter_row.getUTF8String(1)); /* 201 */ /* 202 */ agg_nestedClassInstance.agg_doConsume(inputadapter_row, inputadapter_value, inputadapter_value1, inputadapter_isNull1); /* 203 */ if (shouldStop()) return; /* 204 */ } /* 205 */ /* 206 */ agg_fastHashMapIter = agg_fastHashMap.rowIterator(); /* 207 */ agg_mapIter = ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).finishAggregate(agg_hashMap, agg_sorter, ((org.apache.spark.sql.execution.metric.SQLMetric) references[4] /* peakMemory */), ((org.apache.spark.sql.execution.metric.SQLMetric) references[5] /* spillSize */), ((org.apache.spark.sql.execution.metric.SQLMetric) references[6] /* avgHashProbe */)); /* 208 */ /* 209 */ } /* 210 */ /* 211 */ } /* 212 */ /* 213 */ private class wholestagecodegen_NestedClass { /* 214 */ private void agg_doAggregateWithKeysOutput(UnsafeRow agg_keyTerm, UnsafeRow agg_bufferTerm) /* 215 */ throws java.io.IOException { /* 216 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[7] /* numOutputRows */).add(1); /* 217 */ /* 218 */ boolean agg_isNull35 = agg_keyTerm.isNullAt(0); /* 219 */ UTF8String agg_value37 = agg_isNull35 ? /* 220 */ null : (agg_keyTerm.getUTF8String(0)); /* 221 */ boolean agg_isNull36 = agg_bufferTerm.isNullAt(0); /* 222 */ double agg_value38 = agg_isNull36 ? /* 223 */ -1.0 : (agg_bufferTerm.getDouble(0)); /* 224 */ boolean agg_isNull37 = agg_bufferTerm.isNullAt(1); /* 225 */ long agg_value39 = agg_isNull37 ? /* 226 */ -1L : (agg_bufferTerm.getLong(1)); /* 227 */ /* 228 */ agg_mutableStateArray1[1].reset(); /* 229 */ /* 230 */ agg_mutableStateArray2[1].zeroOutNullBytes(); /* 231 */ /* 232 */ if (agg_isNull35) { /* 233 */ agg_mutableStateArray2[1].setNullAt(0); /* 234 */ } else { /* 235 */ agg_mutableStateArray2[1].write(0, agg_value37); /* 236 */ } /* 237 */ /* 238 */ if (agg_isNull36) { /* 239 */ agg_mutableStateArray2[1].setNullAt(1); /* 240 */ } else { /* 241 */ agg_mutableStateArray2[1].write(1, agg_value38); /* 242 */ } /* 243 */ /* 244 */ if (agg_isNull37) { /* 245 */ agg_mutableStateArray2[1].setNullAt(2); /* 246 */ } else { /* 247 */ agg_mutableStateArray2[1].write(2, agg_value39); /* 248 */ } /* 249 */ agg_mutableStateArray[1].setTotalSize(agg_mutableStateArray1[1].totalSize()); /* 250 */ append(agg_mutableStateArray[1]); /* 251 */ /* 252 */ } /* 253 */ /* 254 */ } /* 255 */ /* 256 */ private class agg_NestedClass { /* 257 */ private void agg_doConsume(InternalRow inputadapter_row, int agg_expr_0, UTF8String agg_expr_1, boolean agg_exprIsNull_1) throws java.io.IOException { /* 258 */ UnsafeRow agg_unsafeRowAggBuffer = null; /* 259 */ UnsafeRow agg_fastAggBuffer = null; /* 260 */ /* 261 */ if (true) { /* 262 */ if (!agg_exprIsNull_1) { /* 263 */ agg_fastAggBuffer = agg_fastHashMap.findOrInsert( /* 264 */ agg_expr_1); /* 265 */ } /* 266 */ } /* 267 */ // Cannot find the key in fast hash map, try regular hash map. /* 268 */ if (agg_fastAggBuffer == null) { /* 269 */ // generate grouping key /* 270 */ agg_mutableStateArray1[0].reset(); /* 271 */ /* 272 */ agg_mutableStateArray2[0].zeroOutNullBytes(); /* 273 */ /* 274 */ if (agg_exprIsNull_1) { /* 275 */ agg_mutableStateArray2[0].setNullAt(0); /* 276 */ } else { /* 277 */ agg_mutableStateArray2[0].write(0, agg_expr_1); /* 278 */ } /* 279 */ agg_mutableStateArray[0].setTotalSize(agg_mutableStateArray1[0].totalSize()); /* 280 */ int agg_value7 = 42; /* 281 */ /* 282 */ if (!agg_exprIsNull_1) { /* 283 */ agg_value7 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(agg_expr_1.getBaseObject(), agg_expr_1.getBaseOffset(), agg_expr_1.numBytes(), agg_value7); /* 284 */ } /* 285 */ if (true) { /* 286 */ // try to get the buffer from hash map /* 287 */ agg_unsafeRowAggBuffer = /* 288 */ agg_hashMap.getAggregationBufferFromUnsafeRow(agg_mutableStateArray[0], agg_value7); /* 289 */ } /* 290 */ // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based /* 291 */ // aggregation after processing all input rows. /* 292 */ if (agg_unsafeRowAggBuffer == null) { /* 293 */ if (agg_sorter == null) { /* 294 */ agg_sorter = agg_hashMap.destructAndCreateExternalSorter(); /* 295 */ } else { /* 296 */ agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter()); /* 297 */ } /* 298 */ /* 299 */ // the hash map had be spilled, it should have enough memory now, /* 300 */ // try to allocate buffer again. /* 301 */ agg_unsafeRowAggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow( /* 302 */ agg_mutableStateArray[0], agg_value7); /* 303 */ if (agg_unsafeRowAggBuffer == null) { /* 304 */ // failed to allocate the first page /* 305 */ throw new OutOfMemoryError("No enough memory for aggregation"); /* 306 */ } /* 307 */ } /* 308 */ /* 309 */ } /* 310 */ /* 311 */ if (agg_fastAggBuffer != null) { /* 312 */ // common sub-expressions /* 313 */ boolean agg_isNull21 = false; /* 314 */ long agg_value23 = -1L; /* 315 */ if (!false) { /* 316 */ agg_value23 = (long) agg_expr_0; /* 317 */ } /* 318 */ // evaluate aggregate function /* 319 */ boolean agg_isNull23 = true; /* 320 */ double agg_value25 = -1.0; /* 321 */ /* 322 */ boolean agg_isNull24 = agg_fastAggBuffer.isNullAt(0); /* 323 */ double agg_value26 = agg_isNull24 ? /* 324 */ -1.0 : (agg_fastAggBuffer.getDouble(0)); /* 325 */ if (!agg_isNull24) { /* 326 */ agg_agg_isNull25 = true; /* 327 */ double agg_value27 = -1.0; /* 328 */ do { /* 329 */ boolean agg_isNull26 = agg_isNull21; /* 330 */ double agg_value28 = -1.0; /* 331 */ if (!agg_isNull21) { /* 332 */ agg_value28 = (double) agg_value23; /* 333 */ } /* 334 */ if (!agg_isNull26) { /* 335 */ agg_agg_isNull25 = false; /* 336 */ agg_value27 = agg_value28; /* 337 */ continue; /* 338 */ } /* 339 */ /* 340 */ boolean agg_isNull27 = false; /* 341 */ double agg_value29 = -1.0; /* 342 */ if (!false) { /* 343 */ agg_value29 = (double) 0; /* 344 */ } /* 345 */ if (!agg_isNull27) { /* 346 */ agg_agg_isNull25 = false; /* 347 */ agg_value27 = agg_value29; /* 348 */ continue; /* 349 */ } /* 350 */ /* 351 */ } while (false); /* 352 */ /* 353 */ agg_isNull23 = false; // resultCode could change nullability. /* 354 */ agg_value25 = agg_value26 + agg_value27; /* 355 */ /* 356 */ } /* 357 */ boolean agg_isNull29 = false; /* 358 */ long agg_value31 = -1L; /* 359 */ if (!false && agg_isNull21) { /* 360 */ boolean agg_isNull31 = agg_fastAggBuffer.isNullAt(1); /* 361 */ long agg_value33 = agg_isNull31 ? /* 362 */ -1L : (agg_fastAggBuffer.getLong(1)); /* 363 */ agg_isNull29 = agg_isNull31; /* 364 */ agg_value31 = agg_value33; /* 365 */ } else { /* 366 */ boolean agg_isNull32 = true; /* 367 */ long agg_value34 = -1L; /* 368 */ /* 369 */ boolean agg_isNull33 = agg_fastAggBuffer.isNullAt(1); /* 370 */ long agg_value35 = agg_isNull33 ? /* 371 */ -1L : (agg_fastAggBuffer.getLong(1)); /* 372 */ if (!agg_isNull33) { /* 373 */ agg_isNull32 = false; // resultCode could change nullability. /* 374 */ agg_value34 = agg_value35 + 1L; /* 375 */ /* 376 */ } /* 377 */ agg_isNull29 = agg_isNull32; /* 378 */ agg_value31 = agg_value34; /* 379 */ } /* 380 */ // update fast row /* 381 */ if (!agg_isNull23) { /* 382 */ agg_fastAggBuffer.setDouble(0, agg_value25); /* 383 */ } else { /* 384 */ agg_fastAggBuffer.setNullAt(0); /* 385 */ } /* 386 */ /* 387 */ if (!agg_isNull29) { /* 388 */ agg_fastAggBuffer.setLong(1, agg_value31); /* 389 */ } else { /* 390 */ agg_fastAggBuffer.setNullAt(1); /* 391 */ } /* 392 */ } else { /* 393 */ // common sub-expressions /* 394 */ boolean agg_isNull7 = false; /* 395 */ long agg_value9 = -1L; /* 396 */ if (!false) { /* 397 */ agg_value9 = (long) agg_expr_0; /* 398 */ } /* 399 */ // evaluate aggregate function /* 400 */ boolean agg_isNull9 = true; /* 401 */ double agg_value11 = -1.0; /* 402 */ /* 403 */ boolean agg_isNull10 = agg_unsafeRowAggBuffer.isNullAt(0); /* 404 */ double agg_value12 = agg_isNull10 ? /* 405 */ -1.0 : (agg_unsafeRowAggBuffer.getDouble(0)); /* 406 */ if (!agg_isNull10) { /* 407 */ agg_agg_isNull11 = true; /* 408 */ double agg_value13 = -1.0; /* 409 */ do { /* 410 */ boolean agg_isNull12 = agg_isNull7; /* 411 */ double agg_value14 = -1.0; /* 412 */ if (!agg_isNull7) { /* 413 */ agg_value14 = (double) agg_value9; /* 414 */ } /* 415 */ if (!agg_isNull12) { /* 416 */ agg_agg_isNull11 = false; /* 417 */ agg_value13 = agg_value14; /* 418 */ continue; /* 419 */ } /* 420 */ /* 421 */ boolean agg_isNull13 = false; /* 422 */ double agg_value15 = -1.0; /* 423 */ if (!false) { /* 424 */ agg_value15 = (double) 0; /* 425 */ } /* 426 */ if (!agg_isNull13) { /* 427 */ agg_agg_isNull11 = false; /* 428 */ agg_value13 = agg_value15; /* 429 */ continue; /* 430 */ } /* 431 */ /* 432 */ } while (false); /* 433 */ /* 434 */ agg_isNull9 = false; // resultCode could change nullability. /* 435 */ agg_value11 = agg_value12 + agg_value13; /* 436 */ /* 437 */ } /* 438 */ boolean agg_isNull15 = false; /* 439 */ long agg_value17 = -1L; /* 440 */ if (!false && agg_isNull7) { /* 441 */ boolean agg_isNull17 = agg_unsafeRowAggBuffer.isNullAt(1); /* 442 */ long agg_value19 = agg_isNull17 ? /* 443 */ -1L : (agg_unsafeRowAggBuffer.getLong(1)); /* 444 */ agg_isNull15 = agg_isNull17; /* 445 */ agg_value17 = agg_value19; /* 446 */ } else { /* 447 */ boolean agg_isNull18 = true; /* 448 */ long agg_value20 = -1L; /* 449 */ /* 450 */ boolean agg_isNull19 = agg_unsafeRowAggBuffer.isNullAt(1); /* 451 */ long agg_value21 = agg_isNull19 ? /* 452 */ -1L : (agg_unsafeRowAggBuffer.getLong(1)); /* 453 */ if (!agg_isNull19) { /* 454 */ agg_isNull18 = false; // resultCode could change nullability. /* 455 */ agg_value20 = agg_value21 + 1L; /* 456 */ /* 457 */ } /* 458 */ agg_isNull15 = agg_isNull18; /* 459 */ agg_value17 = agg_value20; /* 460 */ } /* 461 */ // update unsafe row buffer /* 462 */ if (!agg_isNull9) { /* 463 */ agg_unsafeRowAggBuffer.setDouble(0, agg_value11); /* 464 */ } else { /* 465 */ agg_unsafeRowAggBuffer.setNullAt(0); /* 466 */ } /* 467 */ /* 468 */ if (!agg_isNull15) { /* 469 */ agg_unsafeRowAggBuffer.setLong(1, agg_value17); /* 470 */ } else { /* 471 */ agg_unsafeRowAggBuffer.setNullAt(1); /* 472 */ } /* 473 */ /* 474 */ } /* 475 */ /* 476 */ } /* 477 */ /* 478 */ } /* 479 */ /* 480 */ } ``` ## How was this patch tested? Added UT into `WholeStageCodegenSuite` Author: Kazuaki Ishizaki Closes #20779 from kiszk/SPARK-23598. --- .../spark/sql/execution/BufferedRowIterator.java | 12 ++++++++---- .../spark/sql/execution/WholeStageCodegenSuite.scala | 12 ++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java index 730a4ae8d5605..74c9c05992719 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -62,10 +62,14 @@ public long durationMs() { */ public abstract void init(int index, Iterator[] iters); + /* + * Attributes of the following four methods are public. Thus, they can be also accessed from + * methods in inner classes. See SPARK-23598 + */ /** * Append a row to currentRows. */ - protected void append(InternalRow row) { + public void append(InternalRow row) { currentRows.add(row); } @@ -75,7 +79,7 @@ protected void append(InternalRow row) { * If it returns true, the caller should exit the loop that [[InputAdapter]] generates. * This interface is mainly used to limit the number of input rows. */ - protected boolean stopEarly() { + public boolean stopEarly() { return false; } @@ -84,14 +88,14 @@ protected boolean stopEarly() { * * If it returns true, the caller should exit the loop (return from processNext()). */ - protected boolean shouldStop() { + public boolean shouldStop() { return !currentRows.isEmpty(); } /** * Increase the peak execution memory for current task. */ - protected void incPeakExecutionMemory(long size) { + public void incPeakExecutionMemory(long size) { TaskContext.get().taskMetrics().incPeakExecutionMemory(size); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 0fb9dd2017a09..4b40e4ef7571c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + test("range/filter should be combined") { val df = spark.range(10).filter("id = 1").selectExpr("id + 1") val plan = df.queryExecution.executedPlan @@ -307,4 +309,14 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { // a different query can result in codegen cache miss, that's by design } } + + test("SPARK-23598: Codegen working for lots of aggregation operations without runtime errors") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + var df = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name") + for (i <- 0 until 70) { + df = df.groupBy("name").agg(avg("age").alias("age")) + } + assert(df.limit(1).collect() === Array(Row("bat", 8.0))) + } + } } From 279b3db8970809104c30941254e57e3d62da5041 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Wed, 14 Mar 2018 18:36:31 -0700 Subject: [PATCH 470/774] [SPARK-22915][MLLIB] Streaming tests for spark.ml.feature, from N to Z MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What changes were proposed in this pull request? Adds structured streaming tests using testTransformer for these suites: - NGramSuite - NormalizerSuite - OneHotEncoderEstimatorSuite - OneHotEncoderSuite - PCASuite - PolynomialExpansionSuite - QuantileDiscretizerSuite - RFormulaSuite - SQLTransformerSuite - StandardScalerSuite - StopWordsRemoverSuite - StringIndexerSuite - TokenizerSuite - RegexTokenizerSuite - VectorAssemblerSuite - VectorIndexerSuite - VectorSizeHintSuite - VectorSlicerSuite - Word2VecSuite # How was this patch tested? They are unit test. Author: “attilapiros” Closes #20686 from attilapiros/SPARK-22915. --- .../apache/spark/ml/feature/NGramSuite.scala | 23 +- .../spark/ml/feature/NormalizerSuite.scala | 57 ++--- .../feature/OneHotEncoderEstimatorSuite.scala | 193 ++++++++--------- .../spark/ml/feature/OneHotEncoderSuite.scala | 124 ++++++----- .../apache/spark/ml/feature/PCASuite.scala | 14 +- .../ml/feature/PolynomialExpansionSuite.scala | 62 +++--- .../ml/feature/QuantileDiscretizerSuite.scala | 198 +++++++++-------- .../spark/ml/feature/RFormulaSuite.scala | 158 +++++++------- .../ml/feature/SQLTransformerSuite.scala | 35 +-- .../ml/feature/StandardScalerSuite.scala | 33 +-- .../ml/feature/StopWordsRemoverSuite.scala | 37 ++-- .../spark/ml/feature/StringIndexerSuite.scala | 204 +++++++++--------- .../spark/ml/feature/TokenizerSuite.scala | 30 +-- .../spark/ml/feature/VectorIndexerSuite.scala | 183 +++++++++------- .../ml/feature/VectorSizeHintSuite.scala | 88 +++++--- .../spark/ml/feature/VectorSlicerSuite.scala | 27 +-- .../spark/ml/feature/Word2VecSuite.scala | 28 +-- .../org/apache/spark/ml/util/MLTest.scala | 33 ++- 18 files changed, 809 insertions(+), 718 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index d4975c0b4e20e..e5956ee9942aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -19,17 +19,15 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} + @BeanInfo case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) -class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NGramSuite extends MLTest with DefaultReadWriteTest { - import org.apache.spark.ml.feature.NGramSuite._ import testImplicits._ test("default behavior yields bigram features") { @@ -83,16 +81,11 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setN(3) testDefaultReadWrite(t) } -} - -object NGramSuite extends SparkFunSuite { - def testNGram(t: NGram, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("nGrams", "wantedNGrams") - .collect() - .foreach { case Row(actualNGrams, wantedNGrams) => + def testNGram(t: NGram, dataFrame: DataFrame): Unit = { + testTransformer[(Seq[String], Seq[String])](dataFrame, t, "nGrams", "wantedNGrams") { + case Row(actualNGrams : Seq[String], wantedNGrams: Seq[String]) => assert(actualNGrams === wantedNGrams) - } + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index c75027fb4553d..eff57f1223af4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -17,21 +17,17 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NormalizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @transient var data: Array[Vector] = _ - @transient var dataFrame: DataFrame = _ - @transient var normalizer: Normalizer = _ @transient var l1Normalized: Array[Vector] = _ @transient var l2Normalized: Array[Vector] = _ @@ -62,49 +58,40 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Vectors.dense(0.897906166, 0.113419726, 0.42532397), Vectors.sparse(3, Seq()) ) - - dataFrame = data.map(NormalizerSuite.FeatureData).toSeq.toDF() - normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normalized_features") - } - - def collectResult(result: DataFrame): Array[Vector] = { - result.select("normalized_features").collect().map { - case Row(features: Vector) => features - } } - def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = { - assert((lhs, rhs).zipped.forall { + def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = { + assert((lhs, rhs) match { case (v1: DenseVector, v2: DenseVector) => true case (v1: SparseVector, v2: SparseVector) => true case _ => false }, "The vector type should be preserved after normalization.") } - def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = { - assert((lhs, rhs).zipped.forall { (vector1, vector2) => - vector1 ~== vector2 absTol 1E-5 - }, "The vector value is not correct after normalization.") + def assertValues(lhs: Vector, rhs: Vector): Unit = { + assert(lhs ~== rhs absTol 1E-5, "The vector value is not correct after normalization.") } test("Normalization with default parameter") { - val result = collectResult(normalizer.transform(dataFrame)) - - assertTypeOfVector(data, result) + val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized") + val dataFrame: DataFrame = data.zip(l2Normalized).seq.toDF("features", "expected") - assertValues(result, l2Normalized) + testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") { + case Row(features: Vector, normalized: Vector, expected: Vector) => + assertTypeOfVector(normalized, features) + assertValues(normalized, expected) + } } test("Normalization with setter") { - normalizer.setP(1) + val dataFrame: DataFrame = data.zip(l1Normalized).seq.toDF("features", "expected") + val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized").setP(1) - val result = collectResult(normalizer.transform(dataFrame)) - - assertTypeOfVector(data, result) - - assertValues(result, l1Normalized) + testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") { + case Row(features: Vector, normalized: Vector, expected: Vector) => + assertTypeOfVector(normalized, features) + assertValues(normalized, expected) + } } test("read/write") { @@ -115,7 +102,3 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa testDefaultReadWrite(t) } } - -private object NormalizerSuite { - case class FeatureData(features: Vector) -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala index 1d3f845586426..d549e13262273 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala @@ -17,18 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{Encoder, Row} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ -class OneHotEncoderEstimatorSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class OneHotEncoderEstimatorSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -57,13 +55,10 @@ class OneHotEncoderEstimatorSuite assert(encoder.getDropLast === true) encoder.setDropLast(false) assert(encoder.getDropLast === false) - val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](df, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -87,11 +82,9 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("output")) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](df, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -103,11 +96,12 @@ class OneHotEncoderEstimatorSuite .setInputCols(Array("size")) .setOutputCols(Array("encoded")) val model = encoder.fit(df) - val output = model.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } } test("input column without ML attribute") { @@ -116,11 +110,12 @@ class OneHotEncoderEstimatorSuite .setInputCols(Array("index")) .setOutputCols(Array("encoded")) val model = encoder.fit(df) - val output = model.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + } } test("read/write") { @@ -151,29 +146,30 @@ class OneHotEncoderEstimatorSuite val df = spark.createDataFrame(sc.parallelize(data), schema) - val dfWithTypes = df - .withColumn("shortInput", df("input").cast(ShortType)) - .withColumn("longInput", df("input").cast(LongType)) - .withColumn("intInput", df("input").cast(IntegerType)) - .withColumn("floatInput", df("input").cast(FloatType)) - .withColumn("decimalInput", df("input").cast(DecimalType(10, 0))) - - val cols = Array("input", "shortInput", "longInput", "intInput", - "floatInput", "decimalInput") - for (col <- cols) { - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array(col)) + class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Vector)]) + + val types = Seq( + new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())) + + for (t <- types) { + val dfWithTypes = df.select(col("input").cast(t.numericType), col("expected")) + val estimator = new OneHotEncoderEstimator() + .setInputCols(Array("input")) .setOutputCols(Array("output")) .setDropLast(false) - val model = encoder.fit(dfWithTypes) - val encoded = model.transform(dfWithTypes) - - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) - } + val model = estimator.fit(dfWithTypes) + testTransformer(dfWithTypes, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + }(t.encoder) } } @@ -202,12 +198,16 @@ class OneHotEncoderEstimatorSuite assert(encoder.getDropLast === false) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) - }.collect().foreach { case (vec1, vec2, vec3, vec4) => - assert(vec1 === vec2) - assert(vec3 === vec4) + testTransformer[(Double, Vector, Double, Vector)]( + df, + model, + "output1", + "output2", + "expected1", + "expected2") { + case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) => + assert(output1 === expected1) + assert(output2 === expected2) } } @@ -233,12 +233,16 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("output1", "output2")) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) - }.collect().foreach { case (vec1, vec2, vec3, vec4) => - assert(vec1 === vec2) - assert(vec3 === vec4) + testTransformer[(Double, Vector, Double, Vector)]( + df, + model, + "output1", + "output2", + "expected1", + "expected2") { + case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) => + assert(output1 === expected1) + assert(output2 === expected2) } } @@ -253,10 +257,12 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("encoded")) val model = encoder.fit(trainingDF) - val err = intercept[SparkException] { - model.transform(testDF).show - } - err.getMessage.contains("Unseen value: 3.0. To handle unseen values") + testTransformerByInterceptingException[(Int, Int)]( + testDF, + model, + expectedMessagePart = "Unseen value: 3.0. To handle unseen values", + firstResultCol = "encoded") + } test("Can't transform on negative input") { @@ -268,10 +274,11 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("encoded")) val model = encoder.fit(trainingDF) - val err = intercept[SparkException] { - model.transform(testDF).collect() - } - err.getMessage.contains("Negative value: -1.0. Input can't be negative") + testTransformerByInterceptingException[(Int, Int)]( + testDF, + model, + expectedMessagePart = "Negative value: -1.0. Input can't be negative", + firstResultCol = "encoded") } test("Keep on invalid values: dropLast = false") { @@ -295,11 +302,9 @@ class OneHotEncoderEstimatorSuite .setDropLast(false) val model = encoder.fit(trainingDF) - val encoded = model.transform(testDF) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](testDF, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -324,11 +329,9 @@ class OneHotEncoderEstimatorSuite .setDropLast(true) val model = encoder.fit(trainingDF) - val encoded = model.transform(testDF) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](testDF, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -355,19 +358,15 @@ class OneHotEncoderEstimatorSuite val model = encoder.fit(df) model.setDropLast(false) - val encoded1 = model.transform(df) - encoded1.select("output", "expected1").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector, Vector)](df, model, "output", "expected1") { + case Row(output: Vector, expected1: Vector) => + assert(output === expected1) } model.setDropLast(true) - val encoded2 = model.transform(df) - encoded2.select("output", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector, Vector)](df, model, "output", "expected2") { + case Row(output: Vector, expected2: Vector) => + assert(output === expected2) } } @@ -392,13 +391,14 @@ class OneHotEncoderEstimatorSuite val model = encoder.fit(trainingDF) model.setHandleInvalid("error") - val err = intercept[SparkException] { - model.transform(testDF).collect() - } - err.getMessage.contains("Unseen value: 3.0. To handle unseen values") + testTransformerByInterceptingException[(Double, Vector)]( + testDF, + model, + expectedMessagePart = "Unseen value: 3.0. To handle unseen values", + firstResultCol = "output") model.setHandleInvalid("keep") - model.transform(testDF).collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](testDF, model, "output") { _ => } } test("Transforming on mismatched attributes") { @@ -413,9 +413,10 @@ class OneHotEncoderEstimatorSuite val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large") val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size") .select(col("size").as("size", testAttr.toMetadata())) - val err = intercept[Exception] { - model.transform(testDF).collect() - } - err.getMessage.contains("OneHotEncoderModel expected 2 categorical values") + testTransformerByInterceptingException[(Double)]( + testDF, + model, + expectedMessagePart = "OneHotEncoderModel expected 2 categorical values", + firstResultCol = "encoded") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index c44c6813a94be..41b32b2ffa096 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -17,18 +17,18 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Encoder, Row} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ class OneHotEncoderSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -54,16 +54,19 @@ class OneHotEncoderSuite assert(encoder.getDropLast === true) encoder.setDropLast(false) assert(encoder.getDropLast === false) - val encoded = encoder.transform(transformed) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), - (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) - assert(output === expected) + val expected = Seq( + (0, Vectors.sparse(3, Seq((0, 1.0)))), + (1, Vectors.sparse(3, Seq((2, 1.0)))), + (2, Vectors.sparse(3, Seq((1, 1.0)))), + (3, Vectors.sparse(3, Seq((0, 1.0)))), + (4, Vectors.sparse(3, Seq((0, 1.0)))), + (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = transformed.join(expected, "id") + testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + } } test("OneHotEncoder dropLast = true") { @@ -71,16 +74,19 @@ class OneHotEncoderSuite val encoder = new OneHotEncoder() .setInputCol("labelIndex") .setOutputCol("labelVec") - val encoded = encoder.transform(transformed) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), - (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) - assert(output === expected) + val expected = Seq( + (0, Vectors.sparse(2, Seq((0, 1.0)))), + (1, Vectors.sparse(2, Seq())), + (2, Vectors.sparse(2, Seq((1, 1.0)))), + (3, Vectors.sparse(2, Seq((0, 1.0)))), + (4, Vectors.sparse(2, Seq((0, 1.0)))), + (5, Vectors.sparse(2, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = transformed.join(expected, "id") + testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + } } test("input column with ML attribute") { @@ -90,20 +96,22 @@ class OneHotEncoderSuite val encoder = new OneHotEncoder() .setInputCol("size") .setOutputCol("encoded") - val output = encoder.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, encoder, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } } + test("input column without ML attribute") { val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index") val encoder = new OneHotEncoder() .setInputCol("index") .setOutputCol("encoded") - val output = encoder.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) + val rows = encoder.transform(df).select("encoded").collect() + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) assert(group.size === 2) assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) @@ -119,29 +127,41 @@ class OneHotEncoderSuite test("OneHotEncoder with varying types") { val df = stringIndexed() - val dfWithTypes = df - .withColumn("shortLabel", df("labelIndex").cast(ShortType)) - .withColumn("longLabel", df("labelIndex").cast(LongType)) - .withColumn("intLabel", df("labelIndex").cast(IntegerType)) - .withColumn("floatLabel", df("labelIndex").cast(FloatType)) - .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0))) - val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel", - "floatLabel", "decimalLabel") - for (col <- cols) { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val expected = Seq( + (0, Vectors.sparse(3, Seq((0, 1.0)))), + (1, Vectors.sparse(3, Seq((2, 1.0)))), + (2, Vectors.sparse(3, Seq((1, 1.0)))), + (3, Vectors.sparse(3, Seq((0, 1.0)))), + (4, Vectors.sparse(3, Seq((0, 1.0)))), + (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = df.join(expected, "id") + + class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Vector)]) + + val types = Seq( + new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())) + + for (t <- types) { + val dfWithTypes = withExpected.select(col("labelIndex") + .cast(t.numericType).as("labelIndex", attr.toMetadata()), col("expected")) val encoder = new OneHotEncoder() - .setInputCol(col) + .setInputCol("labelIndex") .setOutputCol("labelVec") .setDropLast(false) - val encoded = encoder.transform(dfWithTypes) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), - (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) - assert(output === expected) + + testTransformer(dfWithTypes, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + }(t.encoder) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 3067a52a4df76..531b1d7c4d9f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class PCASuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -62,10 +60,10 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val pcaModel = pca.fit(df) MLTestingUtils.checkCopyAndUids(pca, pcaModel) - - pcaModel.transform(df).select("pca_features", "expected").collect().foreach { - case Row(x: Vector, y: Vector) => - assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + testTransformer[(Vector, Vector)](df, pcaModel, "pca_features", "expected") { + case Row(result: Vector, expected: Vector) => + assert(result ~== expected absTol 1e-5, + "Transformed vector is different with expected vector.") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index e4b0ddf98bfad..0be7aa6c83f29 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,18 +17,13 @@ package org.apache.spark.ml.feature -import org.scalatest.exceptions.TestFailedException - -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class PolynomialExpansionSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class PolynomialExpansionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -60,6 +55,18 @@ class PolynomialExpansionSuite -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), Vectors.sparse(19, Array.empty, Array.empty)) + def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = { + assert((lhs, rhs) match { + case (v1: DenseVector, v2: DenseVector) => true + case (v1: SparseVector, v2: SparseVector) => true + case _ => false + }, "The vector type should be preserved after polynomial expansion.") + } + + def assertValues(lhs: Vector, rhs: Vector): Unit = { + assert(lhs ~== rhs absTol 1e-1, "The vector value is not correct after polynomial expansion.") + } + test("Polynomial expansion with default parameter") { val df = data.zip(twoDegreeExpansion).toSeq.toDF("features", "expected") @@ -67,13 +74,10 @@ class PolynomialExpansionSuite .setInputCol("features") .setOutputCol("polyFeatures") - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { - case Row(expanded: DenseVector, expected: DenseVector) => - assert(expanded ~== expected absTol 1e-1) - case Row(expanded: SparseVector, expected: SparseVector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { + case Row(expanded: Vector, expected: Vector) => + assertTypeOfVector(expanded, expected) + assertValues(expanded, expected) } } @@ -85,13 +89,10 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") .setDegree(3) - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { - case Row(expanded: DenseVector, expected: DenseVector) => - assert(expanded ~== expected absTol 1e-1) - case Row(expanded: SparseVector, expected: SparseVector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { + case Row(expanded: Vector, expected: Vector) => + assertTypeOfVector(expanded, expected) + assertValues(expanded, expected) } } @@ -103,11 +104,9 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") .setDegree(1) - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { case Row(expanded: Vector, expected: Vector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + assertValues(expanded, expected) } } @@ -133,12 +132,13 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") for (i <- Seq(10, 11)) { - val transformed = t.setDegree(i) - .transform(df) - .select(s"expectedPoly${i}size", "polyFeatures") - .rdd.map { case Row(expected: Int, v: Vector) => expected == v.size } - - assert(transformed.collect.forall(identity)) + testTransformer[(Vector, Int, Int)]( + df, + t.setDegree(i), + s"expectedPoly${i}size", + "polyFeatures") { case Row(size: Int, expected: Vector) => + assert(size === expected.size) + } } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 6c363799dd300..b009038bbd833 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,15 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql._ -import org.apache.spark.sql.functions.udf -class QuantileDiscretizerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -40,19 +36,19 @@ class QuantileDiscretizerSuite .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets === numBuckets, - "Observed number of buckets does not equal expected number of buckets.") + val model = discretizer.fit(df) - val relativeError = discretizer.getRelativeError - val isGoodBucket = udf { - (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) + testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + val relativeError = discretizer.getRelativeError + val numGoodBuckets = result.groupBy("result").count + .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}").count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") } - val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count - assert(numGoodBuckets === numBuckets, - "Bucket sizes are not within expected relative error tolerance.") } test("Test on data with high proportion of duplicated values") { @@ -67,11 +63,14 @@ class QuantileDiscretizerSuite .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets == expectedNumBuckets, - s"Observed number of buckets are not correct." + - s" Expected $expectedNumBuckets but found $observedNumBuckets") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets == expectedNumBuckets, + s"Observed number of buckets are not correct." + + s" Expected $expectedNumBuckets but found $observedNumBuckets") + } } test("Test transform on data with NaN value") { @@ -90,17 +89,20 @@ class QuantileDiscretizerSuite withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { val dataFrame: DataFrame = validData.toSeq.toDF("input") - intercept[SparkException] { - discretizer.fit(dataFrame).transform(dataFrame).collect() - } + val model = discretizer.fit(dataFrame) + testTransformerByInterceptingException[(Double)]( + dataFrame, + model, + expectedMessagePart = "Bucketizer encountered NaN value.", + firstResultCol = "result") } List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{ case(u, v) => discretizer.setHandleInvalid(u) val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected") - val result = discretizer.fit(dataFrame).transform(dataFrame) - result.select("result", "expected").collect().foreach { + val model = discretizer.fit(dataFrame) + testTransformer[(Double, Double)](dataFrame, model, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -119,14 +121,17 @@ class QuantileDiscretizerSuite .setOutputCol("result") .setNumBuckets(5) - val result = discretizer.fit(trainDF).transform(testDF) - val firstBucketSize = result.filter(result("result") === 0.0).count - val lastBucketSize = result.filter(result("result") === 4.0).count + val model = discretizer.fit(trainDF) + testTransformerByGlobalCheckFunc[(Double)](testDF, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val firstBucketSize = result.filter(result("result") === 0.0).count + val lastBucketSize = result.filter(result("result") === 4.0).count - assert(firstBucketSize === 30L, - s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.") - assert(lastBucketSize === 31L, - s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.") + assert(firstBucketSize === 30L, + s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.") + assert(lastBucketSize === 31L, + s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.") + } } test("read/write") { @@ -167,21 +172,24 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - - val relativeError = discretizer.getRelativeError - val isGoodBucket = udf { - (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) - } - - for (i <- 1 to 2) { - val observedNumBuckets = result.select("result" + i).distinct.count - assert(observedNumBuckets === numBuckets, - "Observed number of buckets does not equal expected number of buckets.") - - val numGoodBuckets = result.groupBy("result" + i).count.filter(isGoodBucket($"count")).count - assert(numGoodBuckets === numBuckets, - "Bucket sizes are not within expected relative error tolerance.") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows => + val result = + rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2") + val relativeError = discretizer.getRelativeError + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + + val numGoodBuckets = result + .groupBy("result" + i) + .count + .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}") + .count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") + } } } @@ -198,12 +206,16 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - for (i <- 1 to 2) { - val observedNumBuckets = result.select("result" + i).distinct.count - assert(observedNumBuckets == expectedNumBucket, - s"Observed number of buckets are not correct." + - s" Expected $expectedNumBucket but found ($observedNumBuckets") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows => + val result = + rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2") + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets == expectedNumBucket, + s"Observed number of buckets are not correct." + + s" Expected $expectedNumBucket but found ($observedNumBuckets") + } } } @@ -226,9 +238,12 @@ class QuantileDiscretizerSuite withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { val dataFrame: DataFrame = validData1.zip(validData2).toSeq.toDF("input1", "input2") - intercept[SparkException] { - discretizer.fit(dataFrame).transform(dataFrame).collect() - } + val model = discretizer.fit(dataFrame) + testTransformerByInterceptingException[(Double, Double)]( + dataFrame, + model, + expectedMessagePart = "Bucketizer encountered NaN value.", + firstResultCol = "result1") } List(("keep", expectedKeep1, expectedKeep2), ("skip", expectedSkip1, expectedSkip2)).foreach { @@ -237,8 +252,14 @@ class QuantileDiscretizerSuite val dataFrame: DataFrame = validData1.zip(validData2).zip(v).zip(w).map { case (((a, b), c), d) => (a, b, c, d) }.toSeq.toDF("input1", "input2", "expected1", "expected2") - val result = discretizer.fit(dataFrame).transform(dataFrame) - result.select("result1", "expected1", "result2", "expected2").collect().foreach { + val model = discretizer.fit(dataFrame) + testTransformer[(Double, Double, Double, Double)]( + dataFrame, + model, + "result1", + "expected1", + "result2", + "expected2") { case Row(x: Double, y: Double, z: Double, w: Double) => assert(x === y && w === z) } @@ -270,9 +291,16 @@ class QuantileDiscretizerSuite .setOutputCols(Array("result1", "result2", "result3")) .setNumBucketsArray(numBucketsArray) - discretizer.fit(df).transform(df). - select("result1", "expected1", "result2", "expected2", "result3", "expected3") - .collect().foreach { + val model = discretizer.fit(df) + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df, + model, + "result1", + "expected1", + "result2", + "expected2", + "result3", + "expected3") { case Row(r1: Double, e1: Double, r2: Double, e2: Double, r3: Double, e3: Double) => assert(r1 === e1, s"The result value is not correct after bucketing. Expected $e1 but found $r1") @@ -324,20 +352,16 @@ class QuantileDiscretizerSuite .setStages(Array(discretizerForCol1, discretizerForCol2, discretizerForCol3)) .fit(df) - val resultForMultiCols = plForMultiCols.transform(df) - .select("result1", "result2", "result3") - .collect() - - val resultForSingleCol = plForSingleCol.transform(df) - .select("result1", "result2", "result3") - .collect() + val expected = plForSingleCol.transform(df).select("result1", "result2", "result3").collect() - resultForSingleCol.zip(resultForMultiCols).foreach { - case (rowForSingle, rowForMultiCols) => - assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) && - rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) && - rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2)) - } + testTransformerByGlobalCheckFunc[(Double, Double, Double)]( + df, + plForMultiCols, + "result1", + "result2", + "result3") { rows => + assert(rows === expected) + } } test("Multiple Columns: Comparing setting numBuckets with setting numBucketsArray " + @@ -364,18 +388,16 @@ class QuantileDiscretizerSuite .setOutputCols(Array("result1", "result2", "result3")) .setNumBucketsArray(Array(10, 10, 10)) - val result1 = discretizerSingleNumBuckets.fit(df).transform(df) - .select("result1", "result2", "result3") - .collect() - val result2 = discretizerNumBucketsArray.fit(df).transform(df) - .select("result1", "result2", "result3") - .collect() - - result1.zip(result2).foreach { - case (row1, row2) => - assert(row1.getDouble(0) == row2.getDouble(0) && - row1.getDouble(1) == row2.getDouble(1) && - row1.getDouble(2) == row2.getDouble(2)) + val model = discretizerSingleNumBuckets.fit(df) + val expected = model.transform(df).select("result1", "result2", "result3").collect() + + testTransformerByGlobalCheckFunc[(Double, Double, Double)]( + df, + discretizerNumBucketsArray.fit(df), + "result1", + "result2", + "result3") { rows => + assert(rows === expected) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index bfe38d32dd77d..27d570f0b68ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite @@ -32,10 +31,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { def testRFormulaTransform[A: Encoder]( dataframe: DataFrame, formulaModel: RFormulaModel, - expected: DataFrame): Unit = { + expected: DataFrame, + expectedAttributes: AttributeGroup*): Unit = { + val resultSchema = formulaModel.transformSchema(dataframe.schema) + assert(resultSchema.json === expected.schema.json) + assert(resultSchema === expected.schema) val (first +: rest) = expected.schema.fieldNames.toSeq val expectedRows = expected.collect() testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first, rest: _*) { rows => + assert(rows.head.schema.toString() == resultSchema.toString()) + for (expectedAttributeGroup <- expectedAttributes) { + val attributeGroup = + AttributeGroup.fromStructField(rows.head.schema(expectedAttributeGroup.name)) + assert(attributeGroup === expectedAttributeGroup) + } assert(rows === expectedRows) } } @@ -49,15 +58,10 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val model = formula.fit(original) MLTestingUtils.checkCopyAndUids(formula, model) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0) ).toDF("id", "v1", "v2", "features", "label") - // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString - assert(result.schema.toString == resultSchema.toString) - assert(resultSchema == expected.schema) testRFormulaTransform[(Int, Double, Double)](original, model, expected) } @@ -73,9 +77,13 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y") val model = formula.fit(original) + val expected = Seq( + (0, 1.0, Vectors.dense(0.0)), + (2, 2.0, Vectors.dense(2.0)) + ).toDF("x", "y", "features") val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) - assert(resultSchema.toString == model.transform(original).schema.toString) + testRFormulaTransform[(Int, Double)](original, model, expected) } test("label column already exists but forceIndexLabel was set with true") { @@ -93,9 +101,11 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { intercept[IllegalArgumentException] { model.transformSchema(original.schema) } - intercept[IllegalArgumentException] { - model.transform(original) - } + testTransformerByInterceptingException[(Int, Boolean)]( + original, + model, + "Label column already exists and is not of type NumericType.", + "x") } test("allow missing label column for test datasets") { @@ -105,21 +115,22 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) assert(!resultSchema.exists(_.name == "label")) - assert(resultSchema.toString == model.transform(original).schema.toString) + val expected = Seq( + (0, 1.0, Vectors.dense(0.0)), + (2, 2.0, Vectors.dense(2.0)) + ).toDF("x", "_not_y", "features") + testRFormulaTransform[(Int, Double)](original, model, expected) } test("allow empty label") { val original = Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0)).toDF("id", "a", "b") val formula = new RFormula().setFormula("~ a + b") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)), (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)), (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)) ).toDF("id", "a", "b", "features") - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, Double, Double)](original, model, expected) } @@ -128,15 +139,12 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) ).toDF("id", "a", "b", "features", "label") - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, String, Int)](original, model, expected) } @@ -175,9 +183,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { var idx = 0 for (orderType <- StringIndexer.supportedStringOrderType) { val model = formula.setStringIndexerOrderType(orderType).fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, String, Int)](original, model, expected(idx)) idx += 1 } @@ -218,9 +223,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { ).toDF("id", "a", "b", "features", "label") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, String, Int)](original, model, expected) } @@ -254,19 +256,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val formula1 = new RFormula().setFormula("id ~ a + b + c - 1") .setStringIndexerOrderType(StringIndexer.alphabetDesc) val model1 = formula1.fit(original) - val result1 = model1.transform(original) - val resultSchema1 = model1.transformSchema(original.schema) - // Note the column order is different between R and Spark. - val expected1 = Seq( - (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0), - (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0), - (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0), - (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) - ).toDF("id", "a", "b", "c", "features", "label") - assert(result1.schema.toString == resultSchema1.toString) - testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1) - - val attrs1 = AttributeGroup.fromStructField(result1.schema("features")) val expectedAttrs1 = new AttributeGroup( "features", Array[Attribute]( @@ -275,14 +264,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { new BinaryAttribute(Some("a_bar"), Some(3)), new BinaryAttribute(Some("b_zz"), Some(4)), new NumericAttribute(Some("c"), Some(5)))) - assert(attrs1 === expectedAttrs1) + // Note the column order is different between R and Spark. + val expected1 = Seq( + (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0), + (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0), + (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0), + (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) + ).toDF("id", "a", "b", "c", "features", "label") + + testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1, expectedAttrs1) // There is no impact for string terms interaction. val formula2 = new RFormula().setFormula("id ~ a:b + c - 1") .setStringIndexerOrderType(StringIndexer.alphabetDesc) val model2 = formula2.fit(original) - val result2 = model2.transform(original) - val resultSchema2 = model2.transformSchema(original.schema) // Note the column order is different between R and Spark. val expected2 = Seq( (1, "foo", "zq", 4, Vectors.sparse(7, Array(1, 6), Array(1.0, 4.0)), 1.0), @@ -290,10 +285,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (3, "bar", "zz", 5, Vectors.sparse(7, Array(4, 6), Array(1.0, 5.0)), 3.0), (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0) ).toDF("id", "a", "b", "c", "features", "label") - assert(result2.schema.toString == resultSchema2.toString) - testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2) - - val attrs2 = AttributeGroup.fromStructField(result2.schema("features")) val expectedAttrs2 = new AttributeGroup( "features", Array[Attribute]( @@ -304,7 +295,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { new NumericAttribute(Some("a_bar:b_zz"), Some(5)), new NumericAttribute(Some("a_bar:b_zq"), Some(6)), new NumericAttribute(Some("c"), Some(7)))) - assert(attrs2 === expectedAttrs2) + + testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2, expectedAttrs2) } test("index string label") { @@ -313,13 +305,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) + val attr = NominalAttribute.defaultAttr val expected = Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0), ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0) ).toDF("id", "a", "b", "features", "label") - // assert(result.schema.toString == resultSchema.toString) + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) testRFormulaTransform[(String, String, Int)](original, model, expected) } @@ -329,13 +322,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) - val expected = spark.createDataFrame( - Seq( + val attr = NominalAttribute.defaultAttr + val expected = Seq( (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0), (1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0), (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0)) - ).toDF("id", "a", "b", "features", "label") + .toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) testRFormulaTransform[(Double, String, Int)](original, model, expected) } @@ -344,15 +338,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expected = Seq( + (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + .toDF("id", "a", "b", "features", "label") val expectedAttrs = new AttributeGroup( "features", Array( new BinaryAttribute(Some("a_bar"), Some(1)), new BinaryAttribute(Some("a_foo"), Some(2)), new NumericAttribute(Some("b"), Some(3)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, Int)](original, model, expected, expectedAttrs) + } test("vector attribute generation") { @@ -360,14 +359,19 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) .toDF("id", "vec") val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val attrs = new AttributeGroup("vec", 2) + val expected = Seq( + (1, Vectors.dense(0.0, 1.0), Vectors.dense(0.0, 1.0), 1.0), + (2, Vectors.dense(1.0, 2.0), Vectors.dense(1.0, 2.0), 2.0)) + .toDF("id", "vec", "features", "label") + .select($"id", $"vec".as("vec", attrs.toMetadata()), $"features", $"label") val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("vec_0"), Some(1)), new NumericAttribute(Some("vec_1"), Some(2)))) - assert(attrs === expectedAttrs) + + testRFormulaTransform[(Int, Vector)](original, model, expected, expectedAttrs) } test("vector attribute generation with unnamed input attrs") { @@ -381,31 +385,31 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { NumericAttribute.defaultAttr)).toMetadata() val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expected = Seq( + (1, Vectors.dense(0.0, 1.0), Vectors.dense(0.0, 1.0), 1.0), + (2, Vectors.dense(1.0, 2.0), Vectors.dense(1.0, 2.0), 2.0) + ).toDF("id", "vec2", "features", "label") + .select($"id", $"vec2".as("vec2", metadata), $"features", $"label") val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("vec2_0"), Some(1)), new NumericAttribute(Some("vec2_1"), Some(2)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, Vector)](original, model, expected, expectedAttrs) } test("numeric interaction") { val formula = new RFormula().setFormula("a ~ b:c:d") val original = Seq((1, 2, 4, 2), (2, 3, 4, 1)).toDF("a", "b", "c", "d") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, 2, 4, 2, Vectors.dense(16.0), 1.0), (2, 3, 4, 1, Vectors.dense(12.0), 2.0) ).toDF("a", "b", "c", "d", "features", "label") - testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected) - val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected, expectedAttrs) } test("factor numeric interaction") { @@ -414,7 +418,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), @@ -423,15 +426,13 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0) ).toDF("id", "a", "b", "features", "label") - testRFormulaTransform[(Int, String, Int)](original, model, expected) - val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("a_baz:b"), Some(1)), new NumericAttribute(Some("a_bar:b"), Some(2)), new NumericAttribute(Some("a_foo:b"), Some(3)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, Int)](original, model, expected, expectedAttrs) } test("factor factor interaction") { @@ -439,14 +440,12 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0) ).toDF("id", "a", "b", "features", "label") testRFormulaTransform[(Int, String, String)](original, model, expected) - val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( @@ -454,7 +453,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { new NumericAttribute(Some("a_bar:b_zz"), Some(2)), new NumericAttribute(Some("a_foo:b_zq"), Some(3)), new NumericAttribute(Some("a_foo:b_zz"), Some(4)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, String)](original, model, expected, expectedAttrs) } test("read/write: RFormula") { @@ -517,9 +516,11 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { // Handle unseen features. val formula1 = new RFormula().setFormula("id ~ a + b") - intercept[SparkException] { - formula1.fit(df1).transform(df2).collect() - } + testTransformerByInterceptingException[(Int, String, String)]( + df2, + formula1.fit(df1), + "Unseen label:", + "features") val model1 = formula1.setHandleInvalid("skip").fit(df1) val model2 = formula1.setHandleInvalid("keep").fit(df1) @@ -538,21 +539,28 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { // Handle unseen labels. val formula2 = new RFormula().setFormula("b ~ a + id") - intercept[SparkException] { - formula2.fit(df1).transform(df2).collect() - } + testTransformerByInterceptingException[(Int, String, String)]( + df2, + formula2.fit(df1), + "Unseen label:", + "label") + val model3 = formula2.setHandleInvalid("skip").fit(df1) val model4 = formula2.setHandleInvalid("keep").fit(df1) + val attr = NominalAttribute.defaultAttr val expected3 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0), (2, "bar", "zq", Vectors.dense(1.0, 2.0), 0.0) ).toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) + val expected4 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0), 0.0), (2, "bar", "zq", Vectors.dense(1.0, 0.0, 2.0), 0.0), (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0) ).toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) testRFormulaTransform[(Int, String, String)](df2, model3, expected3) testRFormulaTransform[(Int, String, String)](df2, model4, expected4) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 673a146e619f2..cf09418d8e0a2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.storage.StorageLevel -class SQLTransformerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class SQLTransformerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -37,14 +34,22 @@ class SQLTransformerSuite val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val sqlTrans = new SQLTransformer().setStatement( "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") - val result = sqlTrans.transform(original) - val resultSchema = sqlTrans.transformSchema(original.schema) - val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)) + val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)) .toDF("id", "v1", "v2", "v3", "v4") - assert(result.schema.toString == resultSchema.toString) - assert(resultSchema == expected.schema) - assert(result.collect().toSeq == expected.collect().toSeq) - assert(original.sparkSession.catalog.listTables().count() == 0) + val resultSchema = sqlTrans.transformSchema(original.schema) + testTransformerByGlobalCheckFunc[(Int, Double, Double)]( + original, + sqlTrans, + "id", + "v1", + "v2", + "v3", + "v4") { rows => + assert(rows.head.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(rows == expected.collect().toSeq) + assert(original.sparkSession.catalog.listTables().count() == 0) + } } test("read/write") { @@ -63,13 +68,13 @@ class SQLTransformerSuite } test("SPARK-22538: SQLTransformer should not unpersist given dataset") { - val df = spark.range(10) + val df = spark.range(10).toDF() df.cache() df.count() assert(df.storageLevel != StorageLevel.NONE) - new SQLTransformer() + val sqlTrans = new SQLTransformer() .setStatement("SELECT id + 1 AS id1 FROM __THIS__") - .transform(df) + testTransformerByGlobalCheckFunc[Long](df, sqlTrans, "id1") { _ => } assert(df.storageLevel != StorageLevel.NONE) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 350ba44baa1eb..c5c49d67194e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class StandardScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -60,12 +57,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext ) } - def assertResult(df: DataFrame): Unit = { - df.select("standardized_features", "expected").collect().foreach { - case Row(vector1: Vector, vector2: Vector) => - assert(vector1 ~== vector2 absTol 1E-5, - "The vector value is not correct after standardization.") - } + def assertResult: Row => Unit = { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 ~== vector2 absTol 1E-5, + "The vector value is not correct after standardization.") } test("params") { @@ -83,7 +78,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext val standardScaler0 = standardScalerEst0.fit(df0) MLTestingUtils.checkCopyAndUids(standardScalerEst0, standardScaler0) - assertResult(standardScaler0.transform(df0)) + testTransformer[(Vector, Vector)](df0, standardScaler0, "standardized_features", "expected")( + assertResult) } test("Standardization with setter") { @@ -112,9 +108,12 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext .setWithStd(false) .fit(df3) - assertResult(standardScaler1.transform(df1)) - assertResult(standardScaler2.transform(df2)) - assertResult(standardScaler3.transform(df3)) + testTransformer[(Vector, Vector)](df1, standardScaler1, "standardized_features", "expected")( + assertResult) + testTransformer[(Vector, Vector)](df2, standardScaler2, "standardized_features", "expected")( + assertResult) + testTransformer[(Vector, Vector)](df3, standardScaler3, "standardized_features", "expected")( + assertResult) } test("sparse data and withMean") { @@ -130,7 +129,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext .setWithMean(true) .setWithStd(false) .fit(df) - assertResult(standardScaler.transform(df)) + testTransformer[(Vector, Vector)](df, standardScaler, "standardized_features", "expected")( + assertResult) } test("StandardScaler read/write") { @@ -149,4 +149,5 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext assert(newInstance.std === instance.std) assert(newInstance.mean === instance.mean) } + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 5262b146b184e..21259a50916d2 100755 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -17,28 +17,20 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} - -object StopWordsRemoverSuite extends SparkFunSuite { - def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("filtered", "expected") - .collect() - .foreach { case Row(tokens, wantedTokens) => - assert(tokens === wantedTokens) - } - } -} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} -class StopWordsRemoverSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { - import StopWordsRemoverSuite._ import testImplicits._ + def testStopWordsRemover(t: StopWordsRemover, dataFrame: DataFrame): Unit = { + testTransformer[(Array[String], Array[String])](dataFrame, t, "filtered", "expected") { + case Row(tokens: Seq[_], wantedTokens: Seq[_]) => + assert(tokens === wantedTokens) + } + } + test("StopWordsRemover default") { val remover = new StopWordsRemover() .setInputCol("raw") @@ -151,9 +143,10 @@ class StopWordsRemoverSuite .setOutputCol(outputCol) val dataSet = Seq((Seq("The", "the", "swift"), Seq("swift"))).toDF("raw", outputCol) - val thrown = intercept[IllegalArgumentException] { - testStopWordsRemover(remover, dataSet) - } - assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.") + testTransformerByInterceptingException[(Array[String], Array[String])]( + dataSet, + remover, + s"requirement failed: Column $outputCol already exists.", + "expected") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 775a04d3df050..df24367177011 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,17 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} -class StringIndexerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class StringIndexerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -46,19 +43,23 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") val indexerModel = indexer.fit(df) - MLTestingUtils.checkCopyAndUids(indexer, indexerModel) - - val transformed = indexerModel.transform(df) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("a", "c", "b")) - val output = transformed.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) - assert(output === expected) + val expected = Seq( + (0, 0.0), + (1, 2.0), + (2, 1.0), + (3, 0.0), + (4, 0.0), + (5, 1.0) + ).toDF("id", "labelIndex") + + testTransformerByGlobalCheckFunc[(Int, String)](df, indexerModel, "id", "labelIndex") { rows => + val attr = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("a", "c", "b")) + assert(rows.seq === expected.collect().toSeq) + } } test("StringIndexerUnseen") { @@ -70,36 +71,38 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) + // Verify we throw by default with unseen values - intercept[SparkException] { - indexer.transform(df2).collect() - } + testTransformerByInterceptingException[(Int, String)]( + df2, + indexer, + "Unseen label:", + "labelIndex") - indexer.setHandleInvalid("skip") // Verify that we skip the c record - val transformedSkip = indexer.transform(df2) - val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrSkip.values.get === Array("b", "a")) - val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 1, b -> 0 - val expectedSkip = Set((0, 1.0), (1, 0.0)) - assert(outputSkip === expectedSkip) + indexer.setHandleInvalid("skip") + + val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows => + val attrSkip = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + assert(rows.seq === expectedSkip.collect().toSeq) + } indexer.setHandleInvalid("keep") - // Verify that we keep the unseen records - val transformedKeep = indexer.transform(df2) - val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrKeep.values.get === Array("b", "a", "__unknown")) - val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet + // a -> 1, b -> 0, c -> 2, d -> 3 - val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)) - assert(outputKeep === expectedKeep) + val expectedKeep = Seq((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)).toDF() + + // Verify that we keep the unseen records + testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows => + val attrKeep = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + assert(rows === expectedKeep.collect().toSeq) + } } test("StringIndexer with a numeric input column") { @@ -109,16 +112,14 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) - val transformed = indexer.transform(df) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("100", "300", "200")) - val output = transformed.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // 100 -> 0, 200 -> 2, 300 -> 1 - val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) - assert(output === expected) + val expected = Seq((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "id", "labelIndex") { rows => + val attr = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("100", "300", "200")) + assert(rows === expected.collect().toSeq) + } } test("StringIndexer with NULLs") { @@ -133,37 +134,36 @@ class StringIndexerSuite withClue("StringIndexer should throw error when setHandleInvalid=error " + "when given NULL values") { - intercept[SparkException] { - indexer.setHandleInvalid("error") - indexer.fit(df).transform(df2).collect() - } + indexer.setHandleInvalid("error") + testTransformerByInterceptingException[(Int, String)]( + df2, + indexer.fit(df), + "StringIndexer encountered NULL value.", + "labelIndex") } indexer.setHandleInvalid("skip") - val transformedSkip = indexer.fit(df).transform(df2) - val attrSkip = Attribute - .fromStructField(transformedSkip.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrSkip.values.get === Array("b", "a")) - val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet + val modelSkip = indexer.fit(df) // a -> 1, b -> 0 - val expectedSkip = Set((0, 1.0), (1, 0.0)) - assert(outputSkip === expectedSkip) + val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df2, modelSkip, "id", "labelIndex") { rows => + val attrSkip = + Attribute.fromStructField(rows.head.schema("labelIndex")).asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + assert(rows === expectedSkip.collect().toSeq) + } indexer.setHandleInvalid("keep") - val transformedKeep = indexer.fit(df).transform(df2) - val attrKeep = Attribute - .fromStructField(transformedKeep.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrKeep.values.get === Array("b", "a", "__unknown")) - val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 1, b -> 0, null -> 2 - val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0)) - assert(outputKeep === expectedKeep) + val expectedKeep = Seq((0, 1.0), (1, 0.0), (3, 2.0)).toDF() + val modelKeep = indexer.fit(df) + testTransformerByGlobalCheckFunc[(Int, String)](df2, modelKeep, "id", "labelIndex") { rows => + val attrKeep = Attribute + .fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + assert(rows === expectedKeep.collect().toSeq) + } } test("StringIndexerModel should keep silent if the input column does not exist.") { @@ -171,7 +171,9 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") val df = spark.range(0L, 10L).toDF() - assert(indexerModel.transform(df).collect().toSet === df.collect().toSet) + testTransformerByGlobalCheckFunc[Long](df, indexerModel, "id") { rows => + assert(rows.toSet === df.collect().toSet) + } } test("StringIndexerModel can't overwrite output column") { @@ -188,9 +190,12 @@ class StringIndexerSuite .setOutputCol("indexedInput") .fit(df) - intercept[IllegalArgumentException] { - indexer.setOutputCol("output").transform(df) - } + testTransformerByInterceptingException[(Int, String)]( + df, + indexer.setOutputCol("output"), + "Output column output already exists.", + "labelIndex") + } test("StringIndexer read/write") { @@ -223,7 +228,8 @@ class StringIndexerSuite .setInputCol("index") .setOutputCol("actual") .setLabels(labels) - idxToStr0.transform(df0).select("actual", "expected").collect().foreach { + + testTransformer[(Int, String)](df0, idxToStr0, "actual", "expected") { case Row(actual, expected) => assert(actual === expected) } @@ -234,7 +240,8 @@ class StringIndexerSuite val idxToStr1 = new IndexToString() .setInputCol("indexWithAttr") .setOutputCol("actual") - idxToStr1.transform(df1).select("actual", "expected").collect().foreach { + + testTransformer[(Int, String)](df1, idxToStr1, "actual", "expected") { case Row(actual, expected) => assert(actual === expected) } @@ -252,9 +259,10 @@ class StringIndexerSuite .setInputCol("labelIndex") .setOutputCol("sameLabel") .setLabels(indexer.labels) - idx2str.transform(transformed).select("label", "sameLabel").collect().foreach { - case Row(a: String, b: String) => - assert(a === b) + + testTransformer[(Int, String, Double)](transformed, idx2str, "sameLabel", "label") { + case Row(sameLabel, label) => + assert(sameLabel === label) } } @@ -286,10 +294,11 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) - val transformed = indexer.transform(df) - val attrs = - NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true) - assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") + testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "labelIndex") { rows => + val attrs = + NominalAttribute.decodeStructField(rows.head.schema("labelIndex"), preserveName = true) + assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") + } } test("StringIndexer order types") { @@ -299,18 +308,17 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") - val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)), - Set((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)), - Set((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)), - Set((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0))) + val expected = Seq(Seq((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)), + Seq((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)), + Seq((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)), + Seq((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0))) var idx = 0 for (orderType <- StringIndexer.supportedStringOrderType) { - val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df) - val output = transformed.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet - assert(output === expected(idx)) + val model = indexer.setStringOrderType(orderType).fit(df) + testTransformerByGlobalCheckFunc[(Int, String)](df, model, "id", "labelIndex") { rows => + assert(rows === expected(idx).toDF().collect().toSeq) + } idx += 1 } } @@ -328,7 +336,11 @@ class StringIndexerSuite .setOutputCol("CITYIndexed") .fit(dfNoBristol) - val dfWithIndex = model.transform(dfNoBristol) - assert(dfWithIndex.filter($"CITYIndexed" === 1.0).count == 1) + testTransformerByGlobalCheckFunc[(String, String, String)]( + dfNoBristol, + model, + "CITYIndexed") { rows => + assert(rows.toList.count(_.getDouble(0) == 1.0) === 1) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index c895659a2d8be..be59b0af2c78e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) -class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class TokenizerSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Tokenizer) @@ -42,12 +40,17 @@ class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } } -class RegexTokenizerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RegexTokenizerSuite extends MLTest with DefaultReadWriteTest { - import org.apache.spark.ml.feature.RegexTokenizerSuite._ import testImplicits._ + def testRegexTokenizer(t: RegexTokenizer, dataframe: DataFrame): Unit = { + testTransformer[(String, Seq[String])](dataframe, t, "tokens", "wantedTokens") { + case Row(tokens, wantedTokens) => + assert(tokens === wantedTokens) + } + } + test("params") { ParamsSuite.checkParams(new RegexTokenizer) } @@ -105,14 +108,3 @@ class RegexTokenizerSuite } } -object RegexTokenizerSuite extends SparkFunSuite { - - def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("tokens", "wantedTokens") - .collect() - .foreach { case Row(tokens, wantedTokens) => - assert(tokens === wantedTokens) - } - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 69a7b75e32eb7..e5675e31bbecf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -19,18 +19,16 @@ package org.apache.spark.ml.feature import scala.beans.{BeanInfo, BeanProperty} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest with Logging { +class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { import testImplicits._ import VectorIndexerSuite.FeatureData @@ -128,18 +126,27 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext MLTestingUtils.checkCopyAndUids(vectorIndexer, model) - model.transform(densePoints1) // should work - model.transform(sparsePoints1) // should work + testTransformer[FeatureData](densePoints1, model, "indexed") { _ => } + testTransformer[FeatureData](sparsePoints1, model, "indexed") { _ => } + // If the data is local Dataset, it throws AssertionError directly. - intercept[AssertionError] { - model.transform(densePoints2).collect() - logInfo("Did not throw error when fit, transform were called on vectors of different lengths") + withClue("Did not throw error when fit, transform were called on " + + "vectors of different lengths") { + testTransformerByInterceptingException[FeatureData]( + densePoints2, + model, + "VectorIndexerModel expected vector of length 3 but found length 4", + "indexed") } // If the data is distributed Dataset, it throws SparkException // which is the wrapper of AssertionError. - intercept[SparkException] { - model.transform(densePoints2.repartition(2)).collect() - logInfo("Did not throw error when fit, transform were called on vectors of different lengths") + withClue("Did not throw error when fit, transform were called " + + "on vectors of different lengths") { + testTransformerByInterceptingException[FeatureData]( + densePoints2.repartition(2), + model, + "VectorIndexerModel expected vector of length 3 but found length 4", + "indexed") } intercept[SparkException] { vectorIndexer.fit(badPoints) @@ -178,46 +185,48 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val categoryMaps = model.categoryMaps // Chose correct categorical features assert(categoryMaps.keys.toSet === categoricalFeatures) - val transformed = model.transform(data).select("indexed") - val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0)) - val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) - assert(featureAttrs.name === "indexed") - assert(featureAttrs.attributes.get.length === model.numFeatures) - categoricalFeatures.foreach { feature: Int => - val origValueSet = collectedData.map(_(feature)).toSet - val targetValueIndexSet = Range(0, origValueSet.size).toSet - val catMap = categoryMaps(feature) - assert(catMap.keys.toSet === origValueSet) // Correct categories - assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices - if (origValueSet.contains(0.0)) { - assert(catMap(0.0) === 0) // value 0 gets index 0 - } - // Check transformed data - assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) - // Check metadata - val featureAttr = featureAttrs(feature) - assert(featureAttr.index.get === feature) - featureAttr match { - case attr: BinaryAttribute => - assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) - case attr: NominalAttribute => - assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) - assert(attr.isOrdinal.get === false) - case _ => - throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + - s" metadata check. Found feature attribute: $featureAttr.") + testTransformerByGlobalCheckFunc[FeatureData](data, model, "indexed") { rows => + val transformed = rows.map { r => Tuple1(r.getAs[Vector](0)) }.toDF("indexed") + val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0)) + val featureAttrs = AttributeGroup.fromStructField(rows.head.schema("indexed")) + assert(featureAttrs.name === "indexed") + assert(featureAttrs.attributes.get.length === model.numFeatures) + categoricalFeatures.foreach { feature: Int => + val origValueSet = collectedData.map(_(feature)).toSet + val targetValueIndexSet = Range(0, origValueSet.size).toSet + val catMap = categoryMaps(feature) + assert(catMap.keys.toSet === origValueSet) // Correct categories + assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices + if (origValueSet.contains(0.0)) { + assert(catMap(0.0) === 0) // value 0 gets index 0 + } + // Check transformed data + assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) + // Check metadata + val featureAttr = featureAttrs(feature) + assert(featureAttr.index.get === feature) + featureAttr match { + case attr: BinaryAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + case attr: NominalAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + assert(attr.isOrdinal.get === false) + case _ => + throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } } - } - // Check numerical feature metadata. - Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) - .foreach { feature: Int => - val featureAttr = featureAttrs(feature) - featureAttr match { - case attr: NumericAttribute => - assert(featureAttr.index.get === feature) - case _ => - throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + - s" metadata check. Found feature attribute: $featureAttr.") + // Check numerical feature metadata. + Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) + .foreach { feature: Int => + val featureAttr = featureAttrs(feature) + featureAttr match { + case attr: NumericAttribute => + assert(featureAttr.index.get === feature) + case _ => + throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } } } } catch { @@ -236,25 +245,32 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext (sparsePoints1, sparsePoints1TestInvalid))) { val vectorIndexer = getIndexer.setMaxCategories(4).setHandleInvalid("error") val model = vectorIndexer.fit(points) - intercept[SparkException] { - model.transform(pointsTestInvalid).collect() - } + testTransformerByInterceptingException[FeatureData]( + pointsTestInvalid, + model, + "VectorIndexer encountered invalid value", + "indexed") val vectorIndexer1 = getIndexer.setMaxCategories(4).setHandleInvalid("skip") val model1 = vectorIndexer1.fit(points) - val invalidTransformed1 = model1.transform(pointsTestInvalid).select("indexed") - .collect().map(_(0)) - val transformed1 = model1.transform(points).select("indexed").collect().map(_(0)) - assert(transformed1 === invalidTransformed1) - + val expected = Seq( + Vectors.dense(1.0, 2.0, 0.0), + Vectors.dense(0.0, 1.0, 2.0), + Vectors.dense(0.0, 0.0, 1.0), + Vectors.dense(1.0, 3.0, 2.0)) + testTransformerByGlobalCheckFunc[FeatureData](pointsTestInvalid, model1, "indexed") { rows => + assert(rows.map(_(0)) == expected) + } + testTransformerByGlobalCheckFunc[FeatureData](points, model1, "indexed") { rows => + assert(rows.map(_(0)) == expected) + } val vectorIndexer2 = getIndexer.setMaxCategories(4).setHandleInvalid("keep") val model2 = vectorIndexer2.fit(points) - val invalidTransformed2 = model2.transform(pointsTestInvalid).select("indexed") - .collect().map(_(0)) - assert(invalidTransformed2 === transformed1 ++ Array( - Vectors.dense(2.0, 2.0, 0.0), - Vectors.dense(0.0, 4.0, 2.0), - Vectors.dense(1.0, 3.0, 3.0)) - ) + testTransformerByGlobalCheckFunc[FeatureData](pointsTestInvalid, model2, "indexed") { rows => + assert(rows.map(_(0)) == expected ++ Array( + Vectors.dense(2.0, 2.0, 0.0), + Vectors dense(0.0, 4.0, 2.0), + Vectors.dense(1.0, 3.0, 3.0))) + } } } @@ -263,12 +279,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val points = data.collect().map(_.getAs[Vector](0)) val vectorIndexer = getIndexer.setMaxCategories(maxCategories) val model = vectorIndexer.fit(data) - val indexedPoints = - model.transform(data).select("indexed").rdd.map(_.getAs[Vector](0)).collect() - points.zip(indexedPoints).foreach { - case (orig: SparseVector, indexed: SparseVector) => - assert(orig.indices.length == indexed.indices.length) - case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + testTransformerByGlobalCheckFunc[FeatureData](data, model, "indexed") { rows => + points.zip(rows.map(_(0))).foreach { + case (orig: SparseVector, indexed: SparseVector) => + assert(orig.indices.length == indexed.indices.length) + case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + } } } checkSparsity(sparsePoints1, maxCategories = 2) @@ -286,17 +302,18 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val vectorIndexer = getIndexer.setMaxCategories(2) val model = vectorIndexer.fit(densePoints1WithMeta) // Check that ML metadata are preserved. - val indexedPoints = model.transform(densePoints1WithMeta) - val transAttributes: Array[Attribute] = - AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get - featureAttributes.zip(transAttributes).foreach { case (orig, trans) => - assert(orig.name === trans.name) - (orig, trans) match { - case (orig: NumericAttribute, trans: NumericAttribute) => - assert(orig.max.nonEmpty && orig.max === trans.max) - case _ => + testTransformerByGlobalCheckFunc[FeatureData](densePoints1WithMeta, model, "indexed") { rows => + val transAttributes: Array[Attribute] = + AttributeGroup.fromStructField(rows.head.schema("indexed")).attributes.get + featureAttributes.zip(transAttributes).foreach { case (orig, trans) => + assert(orig.name === trans.name) + (orig, trans) match { + case (orig: NumericAttribute, trans: NumericAttribute) => + assert(orig.max.nonEmpty && orig.max === trans.max) + case _ => // do nothing // TODO: Once input features marked as categorical are handled correctly, check that here. + } } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala index f6c9a76599fae..d89d10b320d84 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming.StreamTest class VectorSizeHintSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -40,16 +38,23 @@ class VectorSizeHintSuite val data = Seq((Vectors.dense(1, 2), 0)).toDF("vector", "intValue") val noSizeTransformer = new VectorSizeHint().setInputCol("vector") - intercept[NoSuchElementException] (noSizeTransformer.transform(data)) + testTransformerByInterceptingException[(Vector, Int)]( + data, + noSizeTransformer, + "Failed to find a default value for size", + "vector") intercept[NoSuchElementException] (noSizeTransformer.transformSchema(data.schema)) val noInputColTransformer = new VectorSizeHint().setSize(2) - intercept[NoSuchElementException] (noInputColTransformer.transform(data)) + testTransformerByInterceptingException[(Vector, Int)]( + data, + noInputColTransformer, + "Failed to find a default value for inputCol", + "vector") intercept[NoSuchElementException] (noInputColTransformer.transformSchema(data.schema)) } test("Adding size to column of vectors.") { - val size = 3 val vectorColName = "vector" val denseVector = Vectors.dense(1, 2, 3) @@ -66,12 +71,15 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - val withSize = transformer.transform(dataFrame) - assert( - AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size, - "Transformer did not add expected size data.") - val numRows = withSize.collect().length - assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.") + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataFrame, transformer, vectorColName) { + rows => { + assert( + AttributeGroup.fromStructField(rows.head.schema(vectorColName)).size == size, + "Transformer did not add expected size data.") + val numRows = rows.length + assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.") + } + } } } @@ -93,14 +101,16 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - val withSize = transformer.transform(dataFrameWithMetadata) - - val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName)) - assert(newGroup.size === size, "Column has incorrect size metadata.") - assert( - newGroup.attributes.get === group.attributes.get, - "VectorSizeHint did not preserve attributes.") - withSize.collect + testTransformerByGlobalCheckFunc[(Int, Int, Int, Vector)]( + dataFrameWithMetadata, + transformer, + vectorColName) { rows => + val newGroup = AttributeGroup.fromStructField(rows.head.schema(vectorColName)) + assert(newGroup.size === size, "Column has incorrect size metadata.") + assert( + newGroup.attributes.get === group.attributes.get, + "VectorSizeHint did not preserve attributes.") + } } } @@ -120,7 +130,11 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - intercept[IllegalArgumentException](transformer.transform(dataFrameWithMetadata)) + testTransformerByInterceptingException[(Int, Int, Int, Vector)]( + dataFrameWithMetadata, + transformer, + "Trying to set size of vectors in `vector` to 4 but size already set to 3.", + vectorColName) } } @@ -136,18 +150,36 @@ class VectorSizeHintSuite .setHandleInvalid("error") .setSize(3) - intercept[SparkException](sizeHint.transform(dataWithNull).collect()) - intercept[SparkException](sizeHint.transform(dataWithShort).collect()) + testTransformerByInterceptingException[Tuple1[Vector]]( + dataWithNull, + sizeHint, + "Got null vector in VectorSizeHint", + "vector") + + testTransformerByInterceptingException[Tuple1[Vector]]( + dataWithShort, + sizeHint, + "VectorSizeHint Expecting a vector of size 3 but got 1", + "vector") sizeHint.setHandleInvalid("skip") - assert(sizeHint.transform(dataWithNull).count() === 1) - assert(sizeHint.transform(dataWithShort).count() === 1) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithNull, sizeHint, "vector") { rows => + assert(rows.length === 1) + } + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithShort, sizeHint, "vector") { rows => + assert(rows.length === 1) + } sizeHint.setHandleInvalid("optimistic") - assert(sizeHint.transform(dataWithNull).count() === 2) - assert(sizeHint.transform(dataWithShort).count() === 2) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithNull, sizeHint, "vector") { rows => + assert(rows.length === 2) + } + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithShort, sizeHint, "vector") { rows => + assert(rows.length === 2) + } } + test("read/write") { val sizeHint = new VectorSizeHint() .setInputCol("myInputCol") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index 1746ce53107c4..3d90f9d9ac764 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.Row import org.apache.spark.sql.types.{StructField, StructType} -class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class VectorSlicerSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ test("params") { val slicer = new VectorSlicer().setInputCol("feature") @@ -84,12 +84,12 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") - def validateResults(df: DataFrame): Unit = { - df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) => + def validateResults(rows: Seq[Row]): Unit = { + rows.foreach { case Row(vec1: Vector, vec2: Vector) => assert(vec1 === vec2) } - val resultMetadata = AttributeGroup.fromStructField(df.schema("result")) - val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected")) + val resultMetadata = AttributeGroup.fromStructField(rows.head.schema("result")) + val expectedMetadata = AttributeGroup.fromStructField(rows.head.schema("expected")) assert(resultMetadata.numAttributes === expectedMetadata.numAttributes) resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) => assert(a === b) @@ -97,13 +97,16 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De } vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) vectorSlicer.setIndices(Array(1)).setNames(Array("f4")) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4")) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 10682ba176aca..b59c4e7967338 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -17,17 +17,17 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.util.Utils -class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class Word2VecSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ test("params") { ParamsSuite.checkParams(new Word2Vec) @@ -36,10 +36,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("Word2Vec") { - - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val numOfWords = sentence.split(" ").size val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -70,17 +66,13 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul // These expectations are just magic values, characterizing the current // behavior. The test needs to be updated to be more general, see SPARK-11502 val magicExp = Vectors.dense(0.30153007534417237, -0.6833061711354689, 0.5116530778733167) - model.transform(docDF).select("result", "expected").collect().foreach { + testTransformer[(Seq[String], Vector)](docDF, model, "result", "expected") { case Row(vector1: Vector, vector2: Vector) => assert(vector1 ~== magicExp absTol 1E-5, "Transformed vector is different with expected.") } } test("getVectors") { - - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -119,9 +111,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("findSynonyms") { - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val docDF = doc.zip(doc).toDF("text", "alsotext") @@ -154,9 +143,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("window size") { - val spark = this.spark - import spark.implicits._ - val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val docDF = doc.zip(doc).toDF("text", "alsotext") @@ -227,8 +213,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("Word2Vec works with input that is non-nullable (NGram)") { - val spark = this.spark - import spark.implicits._ val sentence = "a q s t q s t b b b s t m s t m q " val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text") @@ -243,7 +227,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .fit(ngramDF) // Just test that this transformation succeeds - model.transform(ngramDF).collect() + testTransformerByGlobalCheckFunc[(Seq[String], Seq[String])](ngramDF, model, "result") { _ => } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 17678aa611a48..795fd0e2ac0e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -22,9 +22,10 @@ import java.io.File import org.scalatest.Suite import org.apache.spark.SparkContext -import org.apache.spark.ml.{PipelineModel, Transformer} +import org.apache.spark.ml.Transformer import org.apache.spark.sql.{DataFrame, Encoder, Row} import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.col import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.test.TestSparkSession import org.apache.spark.util.Utils @@ -62,8 +63,10 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => val columnNames = dataframe.schema.fieldNames val stream = MemoryStream[A] - val streamDF = stream.toDS().toDF(columnNames: _*) - + val columnsWithMetadata = dataframe.schema.map { structField => + col(structField.name).as(structField.name, structField.metadata) + } + val streamDF = stream.toDS().toDF(columnNames: _*).select(columnsWithMetadata: _*) val data = dataframe.as[A].collect() val streamOutput = transformer.transform(streamDF) @@ -108,5 +111,29 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => otherResultCols: _*)(globalCheckFunction) testTransformerOnDF(dataframe, transformer, firstResultCol, otherResultCols: _*)(globalCheckFunction) + } + + def testTransformerByInterceptingException[A : Encoder]( + dataframe: DataFrame, + transformer: Transformer, + expectedMessagePart : String, + firstResultCol: String) { + + def hasExpectedMessage(exception: Throwable): Boolean = + exception.getMessage.contains(expectedMessagePart) || + (exception.getCause != null && exception.getCause.getMessage.contains(expectedMessagePart)) + + withClue(s"""Expected message part "${expectedMessagePart}" is not found in DF test.""") { + val exceptionOnDf = intercept[Throwable] { + testTransformerOnDF(dataframe, transformer, firstResultCol)(_ => Unit) + } + assert(hasExpectedMessage(exceptionOnDf)) + } + withClue(s"""Expected message part "${expectedMessagePart}" is not found in stream test.""") { + val exceptionOnStreamData = intercept[Throwable] { + testTransformerOnStreamData(dataframe, transformer, firstResultCol)(_ => Unit) + } + assert(hasExpectedMessage(exceptionOnStreamData)) + } } } From 4f5bad615b47d743b8932aea1071652293981604 Mon Sep 17 00:00:00 2001 From: smallory Date: Thu, 15 Mar 2018 11:58:54 +0900 Subject: [PATCH 471/774] [SPARK-23642][DOCS] AccumulatorV2 subclass isZero scaladoc fix Added/corrected scaladoc for isZero on the DoubleAccumulator, CollectionAccumulator, and LongAccumulator subclasses of AccumulatorV2, particularly noting where there are requirements in addition to having a value of zero in order to return true. ## What changes were proposed in this pull request? Three scaladoc comments are updated in AccumulatorV2.scala No changes outside of comment blocks were made. ## How was this patch tested? Running "sbt unidoc", fixing style errors found, and reviewing the resulting local scaladoc in firefox. Author: smallory Closes #20790 from smallory/patch-1. --- .../main/scala/org/apache/spark/util/AccumulatorV2.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index f4a736d6d439a..0f84ea9752cf5 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -290,7 +290,8 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { private var _count = 0L /** - * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * Returns false if this accumulator has had any values added to it or the sum is non-zero. + * * @since 2.0.0 */ override def isZero: Boolean = _sum == 0L && _count == 0 @@ -368,6 +369,9 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { private var _sum = 0.0 private var _count = 0L + /** + * Returns false if this accumulator has had any values added to it or the sum is non-zero. + */ override def isZero: Boolean = _sum == 0.0 && _count == 0 override def copy(): DoubleAccumulator = { @@ -441,6 +445,9 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]()) + /** + * Returns false if this accumulator instance has any values in it. + */ override def isZero: Boolean = _list.isEmpty override def copyAndReset(): CollectionAccumulator[T] = new CollectionAccumulator From 7c3e8995f18a1fb57c1f2c1b98a1d47590e28f38 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 15 Mar 2018 00:04:28 -0700 Subject: [PATCH 472/774] [SPARK-23533][SS] Add support for changing ContinuousDataReader's startOffset ## What changes were proposed in this pull request? As discussion in #20675, we need add a new interface `ContinuousDataReaderFactory` to support the requirements of setting start offset in Continuous Processing. ## How was this patch tested? Existing UT. Author: Yuanjian Li Closes #20689 from xuanyuanking/SPARK-23533. --- .../sql/kafka010/KafkaContinuousReader.scala | 11 +++++- .../reader/ContinuousDataReaderFactory.java | 35 +++++++++++++++++++ .../ContinuousRateStreamSource.scala | 15 +++++++- 3 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index ecd1170321f3f..6e56b0a72d671 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -164,7 +164,16 @@ case class KafkaContinuousDataReaderFactory( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousDataReaderFactory[UnsafeRow] { + + override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[UnsafeRow] = { + val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] + require(kafkaOffset.topicPartition == topicPartition, + s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") + new KafkaContinuousDataReader( + topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + } + override def createDataReader(): KafkaContinuousDataReader = { new KafkaContinuousDataReader( topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java new file mode 100644 index 0000000000000..a61697649c43e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset; + +/** + * A mix-in interface for {@link DataReaderFactory}. Continuous data reader factories can + * implement this interface to provide creating {@link DataReader} with particular offset. + */ +@InterfaceStability.Evolving +public interface ContinuousDataReaderFactory extends DataReaderFactory { + /** + * Create a DataReader with particular offset as its startOffset. + * + * @param offset offset want to set as the DataReader's startOffset. + */ + DataReader createDataReaderWithOffset(PartitionOffset offset); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index b63d8d3e20650..20d90069163a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -106,7 +106,20 @@ case class RateStreamContinuousDataReaderFactory( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends DataReaderFactory[Row] { + extends ContinuousDataReaderFactory[Row] { + + override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[Row] = { + val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] + require(rateStreamOffset.partition == partitionIndex, + s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") + new RateStreamContinuousDataReader( + rateStreamOffset.currentValue, + rateStreamOffset.currentTimeMs, + partitionIndex, + increment, + rowsPerSecond) + } + override def createDataReader(): DataReader[Row] = new RateStreamContinuousDataReader( startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) From 56e8f48a43eb51e8582db2461a585b13a771a00a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 15 Mar 2018 10:55:33 -0700 Subject: [PATCH 473/774] [SPARK-23695][PYTHON] Fix the error message for Kinesis streaming tests ## What changes were proposed in this pull request? This PR proposes to fix the error message for Kinesis in PySpark when its jar is missing but explicitly enabled. ```bash ENABLE_KINESIS_TESTS=1 SPARK_TESTING=1 bin/pyspark pyspark.streaming.tests ``` Before: ``` Skipped test_flume_stream (enable by setting environment variable ENABLE_FLUME_TESTS=1Skipped test_kafka_stream (enable by setting environment variable ENABLE_KAFKA_0_8_TESTS=1Traceback (most recent call last): File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 174, in _run_module_as_main "__main__", fname, loader, pkg_name) File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 72, in _run_code exec code in run_globals File "/.../spark/python/pyspark/streaming/tests.py", line 1572, in % kinesis_asl_assembly_dir) + NameError: name 'kinesis_asl_assembly_dir' is not defined ``` After: ``` Skipped test_flume_stream (enable by setting environment variable ENABLE_FLUME_TESTS=1Skipped test_kafka_stream (enable by setting environment variable ENABLE_KAFKA_0_8_TESTS=1Traceback (most recent call last): File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 174, in _run_module_as_main "__main__", fname, loader, pkg_name) File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 72, in _run_code exec code in run_globals File "/.../spark/python/pyspark/streaming/tests.py", line 1576, in "You need to build Spark with 'build/sbt -Pkinesis-asl " Exception: Failed to find Spark Streaming Kinesis assembly jar in /.../spark/external/kinesis-asl-assembly. You need to build Spark with 'build/sbt -Pkinesis-asl assembly/package streaming-kinesis-asl-assembly/assembly'or 'build/mvn -Pkinesis-asl package' before running this test. ``` ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #20834 from HyukjinKwon/minor-variable. --- python/pyspark/streaming/tests.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 71f8101e34c50..7dde7c0928c08 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1503,10 +1503,13 @@ def search_flume_assembly_jar(): return jars[0] -def search_kinesis_asl_assembly_jar(): +def _kinesis_asl_assembly_dir(): SPARK_HOME = os.environ["SPARK_HOME"] - kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") - jars = search_jar(kinesis_asl_assembly_dir, "spark-streaming-kinesis-asl-assembly") + return os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") + + +def search_kinesis_asl_assembly_jar(): + jars = search_jar(_kinesis_asl_assembly_dir(), "spark-streaming-kinesis-asl-assembly") if not jars: return None elif len(jars) > 1: @@ -1569,7 +1572,7 @@ def search_kinesis_asl_assembly_jar(): else: raise Exception( ("Failed to find Spark Streaming Kinesis assembly jar in %s. " - % kinesis_asl_assembly_dir) + + % _kinesis_asl_assembly_dir()) + "You need to build Spark with 'build/sbt -Pkinesis-asl " "assembly/package streaming-kinesis-asl-assembly/assembly'" "or 'build/mvn -Pkinesis-asl package' before running this test.") From 15c3c983008557165cc91713ddaf2dbd6d5a506c Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 15 Mar 2018 19:54:58 +0100 Subject: [PATCH 474/774] [HOT-FIX] Fix SparkOutOfMemoryError: Unable to acquire 262144 bytes of memory, got 224631 ## What changes were proposed in this pull request? https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88263/testReport https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88260/testReport https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88257/testReport https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88224/testReport These tests all failed: ``` org.apache.spark.memory.SparkOutOfMemoryError: Unable to acquire 262144 bytes of memory, got 224631 at org.apache.spark.memory.MemoryConsumer.throwOom(MemoryConsumer.java:157) at org.apache.spark.memory.MemoryConsumer.allocateArray(MemoryConsumer.java:98) at org.apache.spark.unsafe.map.BytesToBytesMap.allocate(BytesToBytesMap.java:787) at org.apache.spark.unsafe.map.BytesToBytesMap.(BytesToBytesMap.java:204) at org.apache.spark.unsafe.map.BytesToBytesMap.(BytesToBytesMap.java:219) ... ``` This PR ignore this test. ## How was this patch tested? N/A Author: Yuming Wang Closes #20835 from wangyum/SPARK-23598. --- .../org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 4b40e4ef7571c..9180a22c260f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -310,7 +310,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-23598: Codegen working for lots of aggregation operations without runtime errors") { + ignore("SPARK-23598: Codegen working for lots of aggregation operations without runtime errors") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { var df = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name") for (i <- 0 until 70) { From 7618896e855579f111dd92cd76794a5672a087e5 Mon Sep 17 00:00:00 2001 From: Sahil Takiar Date: Thu, 15 Mar 2018 17:04:39 -0700 Subject: [PATCH 475/774] [SPARK-23658][LAUNCHER] InProcessAppHandle uses the wrong class in getLogger ## What changes were proposed in this pull request? Changed `Logger` in `InProcessAppHandle` to use `InProcessAppHandle` instead of `ChildProcAppHandle` Author: Sahil Takiar Closes #20815 from sahilTakiar/master. --- .../main/java/org/apache/spark/launcher/InProcessAppHandle.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index 4b740d3fad20e..15fbca0facef2 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -25,7 +25,7 @@ class InProcessAppHandle extends AbstractAppHandle { private static final String THREAD_NAME_FMT = "spark-app-%d: '%s'"; - private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName()); + private static final Logger LOG = Logger.getLogger(InProcessAppHandle.class.getName()); private static final AtomicLong THREAD_IDS = new AtomicLong(); // Avoid really long thread names. From 18f8575e0166c6997569358d45bdae2cf45bf624 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 15 Mar 2018 17:12:01 -0700 Subject: [PATCH 476/774] [SPARK-23671][CORE] Fix condition to enable the SHS thread pool. Author: Marcelo Vanzin Closes #20814 from vanzin/SPARK-23671. --- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index f9d0b5ee4e23e..ace6d9e00c838 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -173,7 +173,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Fixed size thread pool to fetch and parse log files. */ private val replayExecutor: ExecutorService = { - if (Utils.isTesting) { + if (!Utils.isTesting) { ThreadUtils.newDaemonFixedThreadPool(NUM_PROCESSING_THREADS, "log-replay-executor") } else { MoreExecutors.sameThreadExecutor() From 3675af7247e841e9a689666dc20891ba55c612b3 Mon Sep 17 00:00:00 2001 From: Ye Zhou Date: Thu, 15 Mar 2018 17:15:53 -0700 Subject: [PATCH 477/774] [SPARK-23608][CORE][WEBUI] Add synchronization in SHS between attachSparkUI and detachSparkUI functions to avoid concurrent modification issue to Jetty Handlers Jetty handlers are dynamically attached/detached while SHS is running. But the attach and detach operations might be taking place at the same time due to the async in load/clear in Guava Cache. ## What changes were proposed in this pull request? Add synchronization between attachSparkUI and detachSparkUI in SHS. ## How was this patch tested? With this patch, the jetty handlers missing issue never happens again in our production cluster SHS. Author: Ye Zhou Closes #20744 from zhouyejoe/SPARK-23608. --- .../apache/spark/deploy/history/HistoryServer.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 0ec4afad0308c..611fa563a7cd9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -150,14 +150,18 @@ class HistoryServer( ui: SparkUI, completed: Boolean) { assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs") - ui.getHandlers.foreach(attachHandler) - addFilters(ui.getHandlers, conf) + handlers.synchronized { + ui.getHandlers.foreach(attachHandler) + addFilters(ui.getHandlers, conf) + } } /** Detach a reconstructed UI from this server. Only valid after bind(). */ override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs") - ui.getHandlers.foreach(detachHandler) + handlers.synchronized { + ui.getHandlers.foreach(detachHandler) + } provider.onUIDetached(appId, attemptId, ui) } From c2632edebd978716dbfa7874a2fc0a8f5a4a9951 Mon Sep 17 00:00:00 2001 From: myroslavlisniak Date: Thu, 15 Mar 2018 17:20:17 -0700 Subject: [PATCH 478/774] [SPARK-23670][SQL] Fix memory leak on SparkPlanGraphWrapper Clean up SparkPlanGraphWrapper objects from InMemoryStore together with cleaning up SQLExecutionUIData existing unit test was extended to check also SparkPlanGraphWrapper object count vanzin Author: myroslavlisniak Closes #20813 from myroslavlisniak/master. --- .../apache/spark/sql/execution/ui/SQLAppStatusListener.scala | 5 ++++- .../apache/spark/sql/execution/ui/SQLAppStatusStore.scala | 4 ++++ .../spark/sql/execution/ui/SQLAppStatusListenerSuite.scala | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 53fb9a0cc21cf..71e9f93c4566e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -334,7 +334,10 @@ class SQLAppStatusListener( val view = kvstore.view(classOf[SQLExecutionUIData]).index("completionTime").first(0L) val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt)(_.completionTime.isDefined) - toDelete.foreach { e => kvstore.delete(e.getClass(), e.executionId) } + toDelete.foreach { e => + kvstore.delete(e.getClass(), e.executionId) + kvstore.delete(classOf[SparkPlanGraphWrapper], e.executionId) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala index 9a76584717f42..241001a857c8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -54,6 +54,10 @@ class SQLAppStatusStore( store.count(classOf[SQLExecutionUIData]) } + def planGraphCount(): Long = { + store.count(classOf[SparkPlanGraphWrapper]) + } + def executionMetrics(executionId: Long): Map[Long, String] = { def metricsFromStore(): Option[Map[Long, String]] = { val exec = store.read(classOf[SQLExecutionUIData], executionId) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 85face3994fd4..f3f08839c1d3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -611,6 +611,7 @@ class SQLAppStatusListenerMemoryLeakSuite extends SparkFunSuite { sc.listenerBus.waitUntilEmpty(10000) val statusStore = spark.sharedState.statusStore assert(statusStore.executionsCount() <= 50) + assert(statusStore.planGraphCount() <= 50) // No live data should be left behind after all executions end. assert(statusStore.listener.get.noLiveData()) } From ca83526de55f0f8784df58cc8b7c0a7cb0c96e23 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 16 Mar 2018 15:12:26 +0800 Subject: [PATCH 479/774] [SPARK-23644][CORE][UI] Use absolute path for REST call in SHS ## What changes were proposed in this pull request? SHS is using a relative path for the REST API call to get the list of the application is a relative path call. In case of the SHS being consumed through a proxy, it can be an issue if the path doesn't end with a "/". Therefore, we should use an absolute path for the REST call as it is done for all the other resources. ## How was this patch tested? manual tests Before the change: ![screen shot 2018-03-10 at 4 22 02 pm](https://user-images.githubusercontent.com/8821783/37244190-8ccf9d40-2485-11e8-8fa9-345bc81472fc.png) After the change: ![screen shot 2018-03-10 at 4 36 34 pm 1](https://user-images.githubusercontent.com/8821783/37244201-a1922810-2485-11e8-8856-eeab2bf5e180.png) Author: Marco Gaido Closes #20794 from mgaido91/SPARK-23644. --- .../main/resources/org/apache/spark/ui/static/historypage.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index f0b2a5a833a99..abc2ec0fa6531 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -113,7 +113,7 @@ $(document).ready(function() { status: (requestedIncomplete ? "running" : "completed") }; - $.getJSON("api/v1/applications", appParams, function(response,status,jqXHR) { + $.getJSON(uiRoot + "/api/v1/applications", appParams, function(response,status,jqXHR) { var array = []; var hasMultipleAttempts = false; for (i in response) { @@ -151,7 +151,7 @@ $(document).ready(function() { "showCompletedColumns": !requestedIncomplete, } - $.get("static/historypage-template.html", function(template) { + $.get(uiRoot + "/static/historypage-template.html", function(template) { var sibling = historySummary.prev(); historySummary.detach(); var apps = $(Mustache.render($(template).filter("#history-summary-template").html(),data)); From c952000487ee003200221b3c4e25dcb06e359f0a Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 16 Mar 2018 16:22:03 +0800 Subject: [PATCH 480/774] [SPARK-23635][YARN] AM env variable should not overwrite same name env variable set through spark.executorEnv. ## What changes were proposed in this pull request? In the current Spark on YARN code, AM always will copy and overwrite its env variables to executors, so we cannot set different values for executors. To reproduce issue, user could start spark-shell like: ``` ./bin/spark-shell --master yarn-client --conf spark.executorEnv.SPARK_ABC=executor_val --conf spark.yarn.appMasterEnv.SPARK_ABC=am_val ``` Then check executor env variables by ``` sc.parallelize(1 to 1).flatMap \{ i => sys.env.toSeq }.collect.foreach(println) ``` We will always get `am_val` instead of `executor_val`. So we should not let AM to overwrite specifically set executor env variables. ## How was this patch tested? Added UT and tested in local cluster. Author: jerryshao Closes #20799 from jerryshao/SPARK-23635. --- .../spark/deploy/yarn/ExecutorRunnable.scala | 22 +++++++----- .../spark/deploy/yarn/YarnClusterSuite.scala | 36 +++++++++++++++++++ 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 3f4d236571ffd..ab08698035c98 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -220,12 +220,6 @@ private[yarn] class ExecutorRunnable( val env = new HashMap[String, String]() Client.populateClasspath(null, conf, sparkConf, env, sparkConf.get(EXECUTOR_CLASS_PATH)) - sparkConf.getExecutorEnv.foreach { case (key, value) => - // This assumes each executor environment variable set here is a path - // This is kept for backward compatibility and consistency with hadoop - YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) - } - // lookup appropriate http scheme for container log urls val yarnHttpPolicy = conf.get( YarnConfiguration.YARN_HTTP_POLICY_KEY, @@ -233,6 +227,20 @@ private[yarn] class ExecutorRunnable( ) val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + System.getenv().asScala.filterKeys(_.startsWith("SPARK")) + .foreach { case (k, v) => env(k) = v } + + sparkConf.getExecutorEnv.foreach { case (key, value) => + if (key == Environment.CLASSPATH.name()) { + // If the key of env variable is CLASSPATH, we assume it is a path and append it. + // This is kept for backward compatibility and consistency with hadoop + YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) + } else { + // For other env variables, simply overwrite the value. + env(key) = value + } + } + // Add log urls container.foreach { c => sys.env.get("SPARK_USER").foreach { user => @@ -245,8 +253,6 @@ private[yarn] class ExecutorRunnable( } } - System.getenv().asScala.filterKeys(_.startsWith("SPARK")) - .foreach { case (k, v) => env(k) = v } env } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 33d400a5b1b2e..a129be7c06b53 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -225,6 +225,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite { finalState should be (SparkAppHandle.State.FAILED) } + test("executor env overwrite AM env in client mode") { + testExecutorEnv(true) + } + + test("executor env overwrite AM env in cluster mode") { + testExecutorEnv(false) + } + private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = { val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), @@ -305,6 +313,17 @@ class YarnClusterSuite extends BaseYarnClusterSuite { checkResult(finalState, executorResult, "OVERRIDDEN") } + private def testExecutorEnv(clientMode: Boolean): Unit = { + val result = File.createTempFile("result", null, tempDir) + val finalState = runSpark(clientMode, mainClassName(ExecutorEnvTestApp.getClass), + appArgs = Seq(result.getAbsolutePath), + extraConf = Map( + "spark.yarn.appMasterEnv.TEST_ENV" -> "am_val", + "spark.executorEnv.TEST_ENV" -> "executor_val" + ) + ) + checkResult(finalState, result, "true") + } } private[spark] class SaveExecutorInfo extends SparkListener { @@ -526,3 +545,20 @@ private object SparkContextTimeoutApp { } } + +private object ExecutorEnvTestApp { + + def main(args: Array[String]): Unit = { + val status = args(0) + val sparkConf = new SparkConf() + val sc = new SparkContext(sparkConf) + val executorEnvs = sc.parallelize(Seq(1)).flatMap { _ => sys.env }.collect().toMap + val result = sparkConf.getExecutorEnv.forall { case (k, v) => + executorEnvs.get(k).contains(v) + } + + Files.write(result.toString, new File(status), StandardCharsets.UTF_8) + sc.stop() + } + +} From 5414abca4fec6a68174c34d22d071c20027e959d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 16 Mar 2018 09:36:30 -0700 Subject: [PATCH 481/774] [SPARK-23553][TESTS] Tests should not assume the default value of `spark.sql.sources.default` ## What changes were proposed in this pull request? Currently, some tests have an assumption that `spark.sql.sources.default=parquet`. In fact, that is a correct assumption, but that assumption makes it difficult to test new data source format. This PR aims to - Improve test suites more robust and makes it easy to test new data sources in the future. - Test new native ORC data source with the full existing Apache Spark test coverage. As an example, the PR uses `spark.sql.sources.default=orc` during reviews. The value should be `parquet` when this PR is accepted. ## How was this patch tested? Pass the Jenkins with updated tests. Author: Dongjoon Hyun Closes #20705 from dongjoon-hyun/SPARK-23553. --- python/pyspark/sql/readwriter.py | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 9 +-- .../columnar/InMemoryColumnarQuerySuite.scala | 5 +- .../sql/execution/command/DDLSuite.scala | 11 ++- .../ParquetPartitionDiscoverySuite.scala | 10 +++ .../sql/test/DataFrameReaderWriterSuite.scala | 3 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 72 +++++++++---------- .../PartitionProviderCompatibilitySuite.scala | 6 +- .../hive/PartitionedTablePerfStatsSuite.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 11 +-- .../sql/hive/execution/SQLQuerySuite.scala | 19 ++--- 11 files changed, 81 insertions(+), 71 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 803f561ece67b..facc16bc53108 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -147,8 +147,8 @@ def load(self, path=None, format=None, schema=None, **options): or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param options: all other string options - >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True, - ... opt2=1, opt3='str') + >>> df = spark.read.format("parquet").load('python/test_support/sql/parquet_partitioned', + ... opt1=True, opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8f14575c3325f..640affc10ee58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2150,7 +2150,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("data source table created in InMemoryCatalog should be able to read/write") { withTable("tbl") { - sql("CREATE TABLE tbl(i INT, j STRING) USING parquet") + val provider = spark.sessionState.conf.defaultDataSourceName + sql(s"CREATE TABLE tbl(i INT, j STRING) USING $provider") checkAnswer(sql("SELECT i, j FROM tbl"), Nil) Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto("tbl") @@ -2474,9 +2475,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-16975: Column-partition path starting '_' should be handled correctly") { withTempDir { dir => - val parquetDir = new File(dir, "parquet").getCanonicalPath - spark.range(10).withColumn("_col", $"id").write.partitionBy("_col").save(parquetDir) - spark.read.parquet(parquetDir) + val dataDir = new File(dir, "data").getCanonicalPath + spark.range(10).withColumn("_col", $"id").write.partitionBy("_col").save(dataDir) + spark.read.load(dataDir) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index dc1766fb9a785..26b63e8e8490f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -487,7 +487,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-22673: InMemoryRelation should utilize existing stats of the plan to be cached") { - withSQLConf("spark.sql.cbo.enabled" -> "true") { + // This test case depends on the size of parquet in statistics. + withSQLConf( + SQLConf.CBO_ENABLED.key -> "true", + SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "parquet") { withTempPath { workDir => withTable("table1") { val workDirPath = workDir.getAbsolutePath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4041176262426..4df8fbfe1c0db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -154,10 +154,15 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo Seq(4 -> "d").toDF("i", "j").write.saveAsTable("t1") val e = intercept[AnalysisException] { - Seq(5 -> "e").toDF("i", "j").write.mode("append").format("json").saveAsTable("t1") + val format = if (spark.sessionState.conf.defaultDataSourceName.equalsIgnoreCase("json")) { + "orc" + } else { + "json" + } + Seq(5 -> "e").toDF("i", "j").write.mode("append").format(format).saveAsTable("t1") } - assert(e.message.contains("The format of the existing table default.t1 is " + - "`ParquetFileFormat`. It doesn't match the specified format `JsonFileFormat`.")) + assert(e.message.contains("The format of the existing table default.t1 is ")) + assert(e.message.contains("It doesn't match the specified format")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index edb3da904d10d..e887c9734a8b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -57,6 +57,16 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val timeZone = TimeZone.getDefault() val timeZoneId = timeZone.getID + protected override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.DEFAULT_DATA_SOURCE_NAME.key, "parquet") + } + + protected override def afterAll(): Unit = { + spark.conf.unset(SQLConf.DEFAULT_DATA_SOURCE_NAME.key) + super.afterAll() + } + test("column type inference") { def check(raw: String, literal: Literal, timeZone: TimeZone = timeZone): Unit = { assert(inferPartitionColumnValue(raw, true, timeZone) === literal) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index a707a88dfa670..14b1feb2adc20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -562,7 +562,8 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be "and a same-name temp view exist") { withTable("same_name") { withTempView("same_name") { - sql("CREATE TABLE same_name(id LONG) USING parquet") + val format = spark.sessionState.conf.defaultDataSourceName + sql(s"CREATE TABLE same_name(id LONG) USING $format") spark.range(10).createTempView("same_name") spark.range(20).write.mode(SaveMode.Append).saveAsTable("same_name") checkAnswer(spark.table("same_name"), spark.range(10).toDF()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 859099a321bf7..d93215fefb810 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -591,7 +591,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("Pre insert nullability check (ArrayType)") { - withTable("arrayInParquet") { + withTable("array") { { val df = (Tuple1(Seq(Int.box(1), null: Integer)) :: Nil).toDF("a") val expectedSchema = @@ -604,9 +604,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Overwrite) - .saveAsTable("arrayInParquet") + .saveAsTable("array") } { @@ -621,25 +620,24 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Append) - .insertInto("arrayInParquet") + .insertInto("array") } (Tuple1(Seq(4, 5)) :: Nil).toDF("a") .write .mode(SaveMode.Append) - .saveAsTable("arrayInParquet") // This one internally calls df2.insertInto. + .saveAsTable("array") // This one internally calls df2.insertInto. (Tuple1(Seq(Int.box(6), null: Integer)) :: Nil).toDF("a") .write .mode(SaveMode.Append) - .saveAsTable("arrayInParquet") + .saveAsTable("array") - sparkSession.catalog.refreshTable("arrayInParquet") + sparkSession.catalog.refreshTable("array") checkAnswer( - sql("SELECT a FROM arrayInParquet"), + sql("SELECT a FROM array"), Row(ArrayBuffer(1, null)) :: Row(ArrayBuffer(2, 3)) :: Row(ArrayBuffer(4, 5)) :: @@ -648,7 +646,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("Pre insert nullability check (MapType)") { - withTable("mapInParquet") { + withTable("map") { { val df = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("a") val expectedSchema = @@ -661,9 +659,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Overwrite) - .saveAsTable("mapInParquet") + .saveAsTable("map") } { @@ -678,27 +675,24 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Append) - .insertInto("mapInParquet") + .insertInto("map") } (Tuple1(Map(4 -> 5)) :: Nil).toDF("a") .write - .format("parquet") .mode(SaveMode.Append) - .saveAsTable("mapInParquet") // This one internally calls df2.insertInto. + .saveAsTable("map") // This one internally calls df2.insertInto. (Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") .write - .format("parquet") .mode(SaveMode.Append) - .saveAsTable("mapInParquet") + .saveAsTable("map") - sparkSession.catalog.refreshTable("mapInParquet") + sparkSession.catalog.refreshTable("map") checkAnswer( - sql("SELECT a FROM mapInParquet"), + sql("SELECT a FROM map"), Row(Map(1 -> null)) :: Row(Map(2 -> 3)) :: Row(Map(4 -> 5)) :: @@ -852,52 +846,52 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv (from to to).map(i => i -> s"str$i").toDF("c1", "c2") } - withTable("insertParquet") { - createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") + withTable("t") { + createDF(0, 9).write.saveAsTable("t") checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + sql("SELECT p.c1, p.c2 FROM t p WHERE p.c1 > 5"), (6 to 9).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { - createDF(10, 19).write.format("parquet").saveAsTable("insertParquet") + createDF(10, 19).write.saveAsTable("t") } - createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + createDF(10, 19).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + sql("SELECT p.c1, p.c2 FROM t p WHERE p.c1 > 5"), (6 to 19).map(i => Row(i, s"str$i"))) - createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + createDF(20, 29).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 25"), (6 to 24).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { - createDF(30, 39).write.saveAsTable("insertParquet") + createDF(30, 39).write.saveAsTable("t") } - createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") + createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 35"), (6 to 34).map(i => Row(i, s"str$i"))) - createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") + createDF(40, 49).write.mode(SaveMode.Append).insertInto("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 45"), (6 to 44).map(i => Row(i, s"str$i"))) - createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") + createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 51 AND p.c1 < 55"), (52 to 54).map(i => Row(i, s"str$i"))) - createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") + createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), + sql("SELECT p.c1, c2 FROM t p"), (50 to 59).map(i => Row(i, s"str$i"))) - createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") + createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), + sql("SELECT p.c1, c2 FROM t p"), (70 to 79).map(i => Row(i, s"str$i"))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index 9440a17677ebf..80afc9d8f44bc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -37,11 +37,11 @@ class PartitionProviderCompatibilitySuite spark.range(5).selectExpr("id as fieldOne", "id as partCol").write .partitionBy("partCol") .mode("overwrite") - .parquet(dir.getAbsolutePath) + .save(dir.getAbsolutePath) spark.sql(s""" |create table $tableName (fieldOne long, partCol int) - |using parquet + |using ${spark.sessionState.conf.defaultDataSourceName} |options (path "${dir.toURI}") |partitioned by (partCol)""".stripMargin) } @@ -358,7 +358,7 @@ class PartitionProviderCompatibilitySuite try { spark.sql(s""" |create table test (id long, P1 int, P2 int) - |using parquet + |using ${spark.sessionState.conf.defaultDataSourceName} |options (path "${base.toURI}") |partitioned by (P1, P2)""".stripMargin) spark.sql(s"alter table test add partition (P1=0, P2=0) location '${a.toURI}'") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index 54d3962a46b4d..1a86c604d5da3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -417,7 +417,7 @@ class PartitionedTablePerfStatsSuite import spark.implicits._ Seq(1).toDF("a").write.mode("overwrite").save(dir.getAbsolutePath) HiveCatalogMetrics.reset() - spark.read.parquet(dir.getAbsolutePath) + spark.read.load(dir.getAbsolutePath) assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1) assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 65be244418670..db76ec9d084cb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1658,8 +1658,8 @@ class HiveDDLSuite Seq(5 -> "e").toDF("i", "j") .write.format("hive").mode("append").saveAsTable("t1") } - assert(e.message.contains("The format of the existing table default.t1 is " + - "`ParquetFileFormat`. It doesn't match the specified format `HiveFileFormat`.")) + assert(e.message.contains("The format of the existing table default.t1 is ")) + assert(e.message.contains("It doesn't match the specified format `HiveFileFormat`.")) } } @@ -1709,11 +1709,12 @@ class HiveDDLSuite spark.sessionState.catalog.getTableMetadata(TableIdentifier(tblName)).schema.map(_.name) } + val provider = spark.sessionState.conf.defaultDataSourceName withTable("t", "t1", "t2", "t3", "t4", "t5", "t6") { - sql("CREATE TABLE t(a int, b int, c int, d int) USING parquet PARTITIONED BY (d, b)") + sql(s"CREATE TABLE t(a int, b int, c int, d int) USING $provider PARTITIONED BY (d, b)") assert(getTableColumns("t") == Seq("a", "c", "d", "b")) - sql("CREATE TABLE t1 USING parquet PARTITIONED BY (d, b) AS SELECT 1 a, 1 b, 1 c, 1 d") + sql(s"CREATE TABLE t1 USING $provider PARTITIONED BY (d, b) AS SELECT 1 a, 1 b, 1 c, 1 d") assert(getTableColumns("t1") == Seq("a", "c", "d", "b")) Seq((1, 1, 1, 1)).toDF("a", "b", "c", "d").write.partitionBy("d", "b").saveAsTable("t2") @@ -1723,7 +1724,7 @@ class HiveDDLSuite val dataPath = new File(new File(path, "d=1"), "b=1").getCanonicalPath Seq(1 -> 1).toDF("a", "c").write.save(dataPath) - sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}'") + sql(s"CREATE TABLE t3 USING $provider LOCATION '${path.toURI}'") assert(getTableColumns("t3") == Seq("a", "c", "d", "b")) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index baabc4a3bca2c..73f83d593bbfb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -516,24 +516,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("CTAS with default fileformat") { val table = "ctas1" val ctas = s"CREATE TABLE IF NOT EXISTS $table SELECT key k, value FROM src" - withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { - withSQLConf("hive.default.fileformat" -> "textfile") { + Seq("orc", "parquet").foreach { dataSourceFormat => + withSQLConf( + SQLConf.CONVERT_CTAS.key -> "true", + SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> dataSourceFormat, + "hive.default.fileformat" -> "textfile") { withTable(table) { sql(ctas) - // We should use parquet here as that is the default datasource fileformat. The default - // datasource file format is controlled by `spark.sql.sources.default` configuration. + // The default datasource file format is controlled by `spark.sql.sources.default`. // This testcase verifies that setting `hive.default.fileformat` has no impact on // the target table's fileformat in case of CTAS. - assert(sessionState.conf.defaultDataSourceName === "parquet") - checkRelation(tableName = table, isDataSourceTable = true, format = "parquet") + checkRelation(tableName = table, isDataSourceTable = true, format = dataSourceFormat) } } - withSQLConf("spark.sql.sources.default" -> "orc") { - withTable(table) { - sql(ctas) - checkRelation(tableName = table, isDataSourceTable = true, format = "orc") - } - } } } From dffeac3691daa620446ae949c5b147518d128e08 Mon Sep 17 00:00:00 2001 From: Sebastian Arzt Date: Fri, 16 Mar 2018 12:25:58 -0500 Subject: [PATCH 482/774] [SPARK-18371][STREAMING] Spark Streaming backpressure generates batch with large number of records MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Omit rounding of backpressure rate. Effects: - no batch with large number of records is created when rate from PID estimator is one - the number of records per batch and partition is more fine-grained improving backpressure accuracy ## How was this patch tested? This was tested by running: - `mvn test -pl external/kafka-0-8` - `mvn test -pl external/kafka-0-10` - a streaming application which was suffering from the issue JasonMWhite The contribution is my original work and I license the work to the project under the project’s open source license Author: Sebastian Arzt Closes #17774 from arzt/kafka-back-pressure. --- .../kafka010/DirectKafkaInputDStream.scala | 6 +-- .../kafka010/DirectKafkaStreamSuite.scala | 48 +++++++++++++++++ .../kafka/DirectKafkaInputDStream.scala | 6 +-- .../kafka/DirectKafkaStreamSuite.scala | 51 +++++++++++++++++++ 4 files changed, 105 insertions(+), 6 deletions(-) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 0fa3287f36db8..9cb2448fea0f4 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -138,17 +138,17 @@ private[spark] class DirectKafkaInputDStream[K, V]( lagPerPartition.map { case (tp, lag) => val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp) - val backpressureRate = Math.round(lag / totalLag.toFloat * rate) + val backpressureRate = lag / totalLag.toDouble * rate tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) } + case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp).toDouble } } if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 Some(effectiveRateLimitPerPartition.map { - case (tp, limit) => tp -> (secsPerBatch * limit).toLong + case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L) }) } else { None diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 453b5e5ab20d3..8524743ee2846 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -617,6 +617,54 @@ class DirectKafkaStreamSuite ssc.stop() } + test("maxMessagesPerPartition with zero offset and rate equal to one") { + val topic = "backpressure" + val kafkaParams = getKafkaParams() + val batchIntervalMilliseconds = 60000 + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + val estimateRate = 1L + val fromOffsets = Map( + new TopicPartition(topic, 0) -> 0L, + new TopicPartition(topic, 1) -> 0L, + new TopicPartition(topic, 2) -> 0L, + new TopicPartition(topic, 3) -> 0L + ) + val kafkaStream = withClue("Error creating direct stream") { + new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) { + currentOffsets = fromOffsets + override val rateController = Some(new ConstantRateController(id, null, estimateRate)) + } + } + + val offsets = Map[TopicPartition, Long]( + new TopicPartition(topic, 0) -> 0, + new TopicPartition(topic, 1) -> 100L, + new TopicPartition(topic, 2) -> 200L, + new TopicPartition(topic, 3) -> 300L + ) + val result = kafkaStream.maxMessagesPerPartition(offsets) + val expected = Map( + new TopicPartition(topic, 0) -> 1L, + new TopicPartition(topic, 1) -> 10L, + new TopicPartition(topic, 2) -> 20L, + new TopicPartition(topic, 3) -> 30L + ) + assert(result.contains(expected), s"Number of messages per partition must be at least 1") + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[ConsumerRecord[K, V]]): Seq[(Time, Array[OffsetRange])] = { diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index d52c230eb7849..d6dd0744441e4 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -104,17 +104,17 @@ class DirectKafkaInputDStream[ val totalLag = lagPerPartition.values.sum lagPerPartition.map { case (tp, lag) => - val backpressureRate = Math.round(lag / totalLag.toFloat * rate) + val backpressureRate = lag / totalLag.toDouble * rate tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition } + case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition.toDouble } } if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 Some(effectiveRateLimitPerPartition.map { - case (tp, limit) => tp -> (secsPerBatch * limit).toLong + case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L) }) } else { None diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 06ef5bc3f8bd0..3fea6cfd910bf 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -456,6 +456,57 @@ class DirectKafkaStreamSuite ssc.stop() } + test("maxMessagesPerPartition with zero offset and rate equal to one") { + val topic = "backpressure" + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val batchIntervalMilliseconds = 60000 + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + val estimatedRate = 1L + val kafkaStream = withClue("Error creating direct stream") { + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val fromOffsets = Map( + TopicAndPartition(topic, 0) -> 0L, + TopicAndPartition(topic, 1) -> 0L, + TopicAndPartition(topic, 2) -> 0L, + TopicAndPartition(topic, 3) -> 0L + ) + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, fromOffsets, messageHandler) { + override protected[streaming] val rateController = + Some(new DirectKafkaRateController(id, null) { + override def getLatestRate() = estimatedRate + }) + } + } + + val offsets = Map( + TopicAndPartition(topic, 0) -> 0L, + TopicAndPartition(topic, 1) -> 100L, + TopicAndPartition(topic, 2) -> 200L, + TopicAndPartition(topic, 3) -> 300L + ) + val result = kafkaStream.maxMessagesPerPartition(offsets) + val expected = Map( + TopicAndPartition(topic, 0) -> 1L, + TopicAndPartition(topic, 1) -> 10L, + TopicAndPartition(topic, 2) -> 20L, + TopicAndPartition(topic, 3) -> 30L + ) + assert(result.contains(expected), s"Number of messages per partition must be at least 1") + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = { From 88d8de9260edf6e9d5449ff7ef6e35d16051fc9f Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 16 Mar 2018 18:28:16 +0100 Subject: [PATCH 483/774] [SPARK-23581][SQL] Add interpreted unsafe projection ## What changes were proposed in this pull request? We currently can only create unsafe rows using code generation. This is a problem for situations in which code generation fails. There is no fallback, and as a result we cannot execute the query. This PR adds an interpreted version of `UnsafeProjection`. The implementation is modeled after `InterpretedMutableProjection`. It stores the expression results in a `GenericInternalRow`, and it then uses a conversion function to convert the `GenericInternalRow` into an `UnsafeRow`. This PR does not implement the actual code generated to interpreted fallback logic. This will be done in a follow-up. ## How was this patch tested? I am piggybacking on exiting `UnsafeProjection` tests, and I have added an interpreted version for each of these. Author: Herman van Hovell Closes #20750 from hvanhovell/SPARK-23581. --- .../codegen/UnsafeArrayWriter.java | 32 +- .../expressions/codegen/UnsafeRowWriter.java | 30 +- .../expressions/codegen/UnsafeWriter.java | 43 ++ .../sql/catalyst/expressions/Expression.scala | 26 ++ .../InterpretedUnsafeProjection.scala | 366 ++++++++++++++++++ .../MonotonicallyIncreasingID.scala | 4 +- .../sql/catalyst/expressions/Projection.scala | 19 +- .../codegen/GenerateUnsafeProjection.scala | 2 +- .../expressions/randomExpressions.scala | 6 +- .../expressions/ComplexTypeSuite.scala | 2 +- .../expressions/ExpressionEvalHelper.scala | 20 +- .../expressions/ObjectExpressionsSuite.scala | 21 +- .../catalyst/expressions/ScalaUDFSuite.scala | 2 +- .../expressions/UnsafeRowConverterSuite.scala | 56 +-- 14 files changed, 555 insertions(+), 74 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 791e8d80e6cba..82cd1b24607e1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -30,7 +30,7 @@ * A helper class to write data into global row buffer using `UnsafeArrayData` format, * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. */ -public class UnsafeArrayWriter { +public final class UnsafeArrayWriter extends UnsafeWriter { private BufferHolder holder; @@ -83,7 +83,7 @@ private long getElementOffset(int ordinal, int elementSize) { return startingOffset + headerInBytes + ordinal * elementSize; } - public void setOffsetAndSize(int ordinal, long currentCursor, int size) { + public void setOffsetAndSize(int ordinal, int currentCursor, int size) { assertIndexIsValid(ordinal); final long relativeOffset = currentCursor - startingOffset; final long offsetAndSize = (relativeOffset << 32) | (long)size; @@ -96,49 +96,31 @@ private void setNullBit(int ordinal) { BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal); } - public void setNullBoolean(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), false); - } - - public void setNullByte(int ordinal) { + public void setNull1Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0); } - public void setNullShort(int ordinal) { + public void setNull2Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0); } - public void setNullInt(int ordinal) { + public void setNull4Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0); } - public void setNullLong(int ordinal) { + public void setNull8Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0); } - public void setNullFloat(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), (float)0); - } - - public void setNullDouble(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), (double)0); - } - - public void setNull(int ordinal) { setNullLong(ordinal); } + public void setNull(int ordinal) { setNull8Bytes(ordinal); } public void write(int ordinal, boolean value) { assertIndexIsValid(ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 5d9515c0725da..2620bbcfb87a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -38,7 +38,7 @@ * beginning of the global row buffer, we don't need to update `startingOffset` and can just call * `zeroOutNullBytes` before writing new data. */ -public class UnsafeRowWriter { +public final class UnsafeRowWriter extends UnsafeWriter { private final BufferHolder holder; // The offset of the global buffer where we start to write this row. @@ -93,18 +93,38 @@ public void setNullAt(int ordinal) { Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L); } + @Override + public void setNull1Bytes(int ordinal) { + setNullAt(ordinal); + } + + @Override + public void setNull2Bytes(int ordinal) { + setNullAt(ordinal); + } + + @Override + public void setNull4Bytes(int ordinal) { + setNullAt(ordinal); + } + + @Override + public void setNull8Bytes(int ordinal) { + setNullAt(ordinal); + } + public long getFieldOffset(int ordinal) { return startingOffset + nullBitsSize + 8 * ordinal; } - public void setOffsetAndSize(int ordinal, long size) { + public void setOffsetAndSize(int ordinal, int size) { setOffsetAndSize(ordinal, holder.cursor, size); } - public void setOffsetAndSize(int ordinal, long currentCursor, long size) { + public void setOffsetAndSize(int ordinal, int currentCursor, int size) { final long relativeOffset = currentCursor - startingOffset; final long fieldOffset = getFieldOffset(ordinal); - final long offsetAndSize = (relativeOffset << 32) | size; + final long offsetAndSize = (relativeOffset << 32) | (long) size; Platform.putLong(holder.buffer, fieldOffset, offsetAndSize); } @@ -174,7 +194,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { if (input == null || !input.changePrecision(precision, scale)) { BitSetMethods.set(holder.buffer, startingOffset, ordinal); // keep the offset for future update - setOffsetAndSize(ordinal, 0L); + setOffsetAndSize(ordinal, 0); } else { final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); assert bytes.length <= 16; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java new file mode 100644 index 0000000000000..c94b5c7a367ef --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Base class for writing Unsafe* structures. + */ +public abstract class UnsafeWriter { + public abstract void setNull1Bytes(int ordinal); + public abstract void setNull2Bytes(int ordinal); + public abstract void setNull4Bytes(int ordinal); + public abstract void setNull8Bytes(int ordinal); + public abstract void write(int ordinal, boolean value); + public abstract void write(int ordinal, byte value); + public abstract void write(int ordinal, short value); + public abstract void write(int ordinal, int value); + public abstract void write(int ordinal, long value); + public abstract void write(int ordinal, float value); + public abstract void write(int ordinal, double value); + public abstract void write(int ordinal, Decimal input, int precision, int scale); + public abstract void write(int ordinal, UTF8String input); + public abstract void write(int ordinal, byte[] input); + public abstract void write(int ordinal, CalendarInterval input); + public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index ed90b185865a0..d7f9e38915dd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -328,6 +328,32 @@ trait Nondeterministic extends Expression { protected def evalInternal(input: InternalRow): Any } +/** + * An expression that contains mutable state. A stateful expression is always non-deterministic + * because the results it produces during evaluation are not only dependent on the given input + * but also on its internal state. + * + * The state of the expressions is generally not exposed in the parameter list and this makes + * comparing stateful expressions problematic because similar stateful expressions (with the same + * parameter list) but with different internal state will be considered equal. This is especially + * problematic during tree transformations. In order to counter this the `fastEquals` method for + * stateful expressions only returns `true` for the same reference. + * + * A stateful expression should never be evaluated multiple times for a single row. This should + * only be a problem for interpreted execution. This can be prevented by creating fresh copies + * of the stateful expression before execution, these can be made using the `freshCopy` function. + */ +trait Stateful extends Nondeterministic { + /** + * Return a fresh uninitialized copy of the stateful expression. + */ + def freshCopy(): Stateful + + /** + * Only the same reference is considered equal. + */ + override def fastEquals(other: TreeNode[_]): Boolean = this eq other +} /** * A leaf expression, i.e. one without any child expressions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala new file mode 100644 index 0000000000000..0da5ece7e47fe --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types.{UserDefinedType, _} +import org.apache.spark.unsafe.Platform + +/** + * An interpreted unsafe projection. This class reuses the [[UnsafeRow]] it produces, a consumer + * should copy the row if it is being buffered. This class is not thread safe. + * + * @param expressions that produces the resulting fields. These expressions must be bound + * to a schema. + */ +class InterpretedUnsafeProjection(expressions: Array[Expression]) extends UnsafeProjection { + import InterpretedUnsafeProjection._ + + /** Number of (top level) fields in the resulting row. */ + private[this] val numFields = expressions.length + + /** Array that expression results. */ + private[this] val values = new Array[Any](numFields) + + /** The row representing the expression results. */ + private[this] val intermediate = new GenericInternalRow(values) + + /** The row returned by the projection. */ + private[this] val result = new UnsafeRow(numFields) + + /** The buffer which holds the resulting row's backing data. */ + private[this] val holder = new BufferHolder(result, numFields * 32) + + /** The writer that writes the intermediate result to the result row. */ + private[this] val writer: InternalRow => Unit = { + val rowWriter = new UnsafeRowWriter(holder, numFields) + val baseWriter = generateStructWriter( + holder, + rowWriter, + expressions.map(e => StructField("", e.dataType, e.nullable))) + if (!expressions.exists(_.nullable)) { + // No nullable fields. The top-level null bit mask will always be zeroed out. + baseWriter + } else { + // Zero out the null bit mask before we write the row. + row => { + rowWriter.zeroOutNullBytes() + baseWriter(row) + } + } + } + + override def initialize(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + }) + } + + override def apply(row: InternalRow): UnsafeRow = { + // Put the expression results in the intermediate row. + var i = 0 + while (i < numFields) { + values(i) = expressions(i).eval(row) + i += 1 + } + + // Write the intermediate row to an unsafe row. + holder.reset() + writer(intermediate) + result.setTotalSize(holder.totalSize()) + result + } +} + +/** + * Helper functions for creating an [[InterpretedUnsafeProjection]]. + */ +object InterpretedUnsafeProjection extends UnsafeProjectionCreator { + + /** + * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. + */ + override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + // We need to make sure that we do not reuse stateful expressions. + val cleanedExpressions = exprs.map(_.transform { + case s: Stateful => s.freshCopy() + }) + new InterpretedUnsafeProjection(cleanedExpressions.toArray) + } + + /** + * Generate a struct writer function. The generated function writes an [[InternalRow]] to the + * given buffer using the given [[UnsafeRowWriter]]. + */ + private def generateStructWriter( + bufferHolder: BufferHolder, + rowWriter: UnsafeRowWriter, + fields: Array[StructField]): InternalRow => Unit = { + val numFields = fields.length + + // Create field writers. + val fieldWriters = fields.map { field => + generateFieldWriter(bufferHolder, rowWriter, field.dataType, field.nullable) + } + // Create basic writer. + row => { + var i = 0 + while (i < numFields) { + fieldWriters(i).apply(row, i) + i += 1 + } + } + } + + /** + * Generate a writer function for a struct field, array element, map key or map value. The + * generated function writes the element at an index in a [[SpecializedGetters]] object (row + * or array) to the given buffer using the given [[UnsafeWriter]]. + */ + private def generateFieldWriter( + bufferHolder: BufferHolder, + writer: UnsafeWriter, + dt: DataType, + nullable: Boolean): (SpecializedGetters, Int) => Unit = { + + // Create the the basic writer. + val unsafeWriter: (SpecializedGetters, Int) => Unit = dt match { + case BooleanType => + (v, i) => writer.write(i, v.getBoolean(i)) + + case ByteType => + (v, i) => writer.write(i, v.getByte(i)) + + case ShortType => + (v, i) => writer.write(i, v.getShort(i)) + + case IntegerType | DateType => + (v, i) => writer.write(i, v.getInt(i)) + + case LongType | TimestampType => + (v, i) => writer.write(i, v.getLong(i)) + + case FloatType => + (v, i) => writer.write(i, v.getFloat(i)) + + case DoubleType => + (v, i) => writer.write(i, v.getDouble(i)) + + case DecimalType.Fixed(precision, scale) => + (v, i) => writer.write(i, v.getDecimal(i, precision, scale), precision, scale) + + case CalendarIntervalType => + (v, i) => writer.write(i, v.getInterval(i)) + + case BinaryType => + (v, i) => writer.write(i, v.getBinary(i)) + + case StringType => + (v, i) => writer.write(i, v.getUTF8String(i)) + + case StructType(fields) => + val numFields = fields.length + val rowWriter = new UnsafeRowWriter(bufferHolder, numFields) + val structWriter = generateStructWriter(bufferHolder, rowWriter, fields) + (v, i) => { + val tmpCursor = bufferHolder.cursor + v.getStruct(i, fields.length) match { + case row: UnsafeRow => + writeUnsafeData( + bufferHolder, + row.getBaseObject, + row.getBaseOffset, + row.getSizeInBytes) + case row => + // Nested struct. We don't know where this will start because a row can be + // variable length, so we need to update the offsets and zero out the bit mask. + rowWriter.reset() + structWriter.apply(row) + } + writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + } + + case ArrayType(elementType, containsNull) => + val arrayWriter = new UnsafeArrayWriter + val elementSize = getElementSize(elementType) + val elementWriter = generateFieldWriter( + bufferHolder, + arrayWriter, + elementType, + containsNull) + (v, i) => { + val tmpCursor = bufferHolder.cursor + writeArray(bufferHolder, arrayWriter, elementWriter, v.getArray(i), elementSize) + writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + } + + case MapType(keyType, valueType, valueContainsNull) => + val keyArrayWriter = new UnsafeArrayWriter + val keySize = getElementSize(keyType) + val keyWriter = generateFieldWriter( + bufferHolder, + keyArrayWriter, + keyType, + nullable = false) + val valueArrayWriter = new UnsafeArrayWriter + val valueSize = getElementSize(valueType) + val valueWriter = generateFieldWriter( + bufferHolder, + valueArrayWriter, + valueType, + valueContainsNull) + (v, i) => { + val tmpCursor = bufferHolder.cursor + v.getMap(i) match { + case map: UnsafeMapData => + writeUnsafeData( + bufferHolder, + map.getBaseObject, + map.getBaseOffset, + map.getSizeInBytes) + case map => + // preserve 8 bytes to write the key array numBytes later. + bufferHolder.grow(8) + bufferHolder.cursor += 8 + + // Write the keys and write the numBytes of key array into the first 8 bytes. + writeArray(bufferHolder, keyArrayWriter, keyWriter, map.keyArray(), keySize) + Platform.putLong(bufferHolder.buffer, tmpCursor, bufferHolder.cursor - tmpCursor - 8) + + // Write the values. + writeArray(bufferHolder, valueArrayWriter, valueWriter, map.valueArray(), valueSize) + } + writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + } + + case udt: UserDefinedType[_] => + generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable) + + case NullType => + (_, _) => {} + + case _ => + throw new SparkException(s"Unsupported data type $dt") + } + + // Always wrap the writer with a null safe version. + dt match { + case _: UserDefinedType[_] => + // The null wrapper depends on the sql type and not on the UDT. + unsafeWriter + case DecimalType.Fixed(precision, _) if precision > Decimal.MAX_LONG_DIGITS => + // We can't call setNullAt() for DecimalType with precision larger than 18, we call write + // directly. We can use the unwrapped writer directly. + unsafeWriter + case BooleanType | ByteType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull1Bytes(i) + } + } + case ShortType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull2Bytes(i) + } + } + case IntegerType | DateType | FloatType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull4Bytes(i) + } + } + case _ => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull8Bytes(i) + } + } + } + } + + /** + * Get the number of bytes elements of a data type will occupy in the fixed part of an + * [[UnsafeArrayData]] object. Reference types are stored as an 8 byte combination of an + * offset (upper 4 bytes) and a length (lower 4 bytes), these point to the variable length + * portion of the array object. Primitives take up to 8 bytes, depending on the size of the + * underlying data type. + */ + private def getElementSize(dataType: DataType): Int = dataType match { + case NullType | StringType | BinaryType | CalendarIntervalType | + _: DecimalType | _: StructType | _: ArrayType | _: MapType => 8 + case _ => dataType.defaultSize + } + + /** + * Write an array to the buffer. If the array is already in serialized form (an instance of + * [[UnsafeArrayData]]) then we copy the bytes directly, otherwise we do an element-by-element + * copy. + */ + private def writeArray( + bufferHolder: BufferHolder, + arrayWriter: UnsafeArrayWriter, + elementWriter: (SpecializedGetters, Int) => Unit, + array: ArrayData, + elementSize: Int): Unit = array match { + case unsafe: UnsafeArrayData => + writeUnsafeData( + bufferHolder, + unsafe.getBaseObject, + unsafe.getBaseOffset, + unsafe.getSizeInBytes) + case _ => + val numElements = array.numElements() + arrayWriter.initialize(bufferHolder, numElements, elementSize) + var i = 0 + while (i < numElements) { + elementWriter.apply(array, i) + i += 1 + } + } + + /** + * Write an opaque block of data to the buffer. This is used to copy + * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. + */ + private def writeUnsafeData( + bufferHolder: BufferHolder, + baseObject: AnyRef, + baseOffset: Long, + sizeInBytes: Int) : Unit = { + bufferHolder.grow(sizeInBytes) + Platform.copyMemory( + baseObject, + baseOffset, + bufferHolder.buffer, + bufferHolder.cursor, + sizeInBytes) + bufferHolder.cursor += sizeInBytes + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 4523079060896..dd523d312e3b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{DataType, LongType} within each partition. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. """) -case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic { +case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { /** * Record ID within each partition. By being transient, count's value is reset to 0 every time @@ -79,4 +79,6 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis override def prettyName: String = "monotonically_increasing_id" override def sql: String = s"$prettyName()" + + override def freshCopy(): MonotonicallyIncreasingID = MonotonicallyIncreasingID() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 64b94f0a2c103..3cd73682188bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -108,8 +108,7 @@ abstract class UnsafeProjection extends Projection { override def apply(row: InternalRow): UnsafeRow } -object UnsafeProjection { - +trait UnsafeProjectionCreator { /** * Returns an UnsafeProjection for given StructType. * @@ -127,13 +126,13 @@ object UnsafeProjection { } /** - * Returns an UnsafeProjection for given sequence of Expressions (bounded). + * Returns an UnsafeProjection for given sequence of bound Expressions. */ def create(exprs: Seq[Expression]): UnsafeProjection = { val unsafeExprs = exprs.map(_ transform { case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) - GenerateUnsafeProjection.generate(unsafeExprs) + createProjection(unsafeExprs) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) @@ -146,6 +145,18 @@ object UnsafeProjection { create(exprs.map(BindReferences.bindReference(_, inputSchema))) } + /** + * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. + */ + protected def createProjection(exprs: Seq[Expression]): UnsafeProjection +} + +object UnsafeProjection extends UnsafeProjectionCreator { + + override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + GenerateUnsafeProjection.generate(exprs) + } + /** * Same as other create()'s but allowing enabling/disabling subexpression elimination. * TODO: refactor the plumbing and clean this up. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 22717f5954a45..6682ba55b18b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -247,7 +247,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro for (int $index = 0; $index < $numElements; $index++) { if ($tmpInput.isNullAt($index)) { - $arrayWriter.setNull$primitiveTypeName($index); + $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); } else { $writeElement } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 6c9937dacc70b..f36633867316e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom * * Since this expression is stateful, it cannot be a case object. */ -abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic { +abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful { /** * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize and initialize it. @@ -85,6 +85,8 @@ case class Rand(child: Expression) extends RDG { final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") } + + override def freshCopy(): Rand = Rand(child) } object Rand { @@ -120,6 +122,8 @@ case class Randn(child: Expression) extends RDG { final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") } + + override def freshCopy(): Randn = Randn(child) } object Randn { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 84190f0bd5f7d..b4138ce366b3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -180,7 +180,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { null, null) } intercept[RuntimeException] { - checkEvalutionWithUnsafeProjection( + checkEvaluationWithUnsafeProjection( CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), null, null) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 58d0c07622eb9..c6343b1cbf600 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -60,7 +60,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) if (GenerateUnsafeProjection.canSupport(expr.dataType)) { - checkEvalutionWithUnsafeProjection(expr, catalystValue, inputRow) + checkEvaluationWithUnsafeProjection(expr, catalystValue, inputRow) } checkEvaluationWithOptimization(expr, catalystValue, inputRow) } @@ -187,11 +187,20 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { plan(inputRow).get(0, expression.dataType) } - protected def checkEvalutionWithUnsafeProjection( + protected def checkEvaluationWithUnsafeProjection( expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) + checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection) + checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection) + } + + protected def checkEvaluationWithUnsafeProjection( + expression: Expression, + expected: Any, + inputRow: InternalRow, + factory: UnsafeProjectionCreator): Unit = { + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory) val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" if (expected == null) { @@ -203,7 +212,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } else { val lit = InternalRow(expected, expected) val expectedRow = - UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) + factory.create(Array(expression.dataType, expression.dataType)).apply(lit) if (unsafeRow != expectedRow) { fail("Incorrect evaluation in unsafe mode: " + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") @@ -213,7 +222,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { private def evaluateWithUnsafeProjection( expression: Expression, - inputRow: InternalRow = EmptyRow): InternalRow = { + inputRow: InternalRow = EmptyRow, + factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { // SPARK-16489 Explicitly doing code generation twice so code gen will fail if // some expression is reusing variable names across different instances. // This behavior is tested in ExpressionEvalHelperSuite. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index ffeec2a38c532..1f6964dfef598 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -45,16 +45,22 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4)))) val structExpected = new GenericArrayData( Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4)))) - checkEvalutionWithUnsafeProjection( - structEncoder.serializer.head, structExpected, structInputRow) + checkEvaluationWithUnsafeProjection( + structEncoder.serializer.head, + structExpected, + structInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed // test UnsafeArray-backed data val arrayEncoder = ExpressionEncoder[Array[Array[Int]]] val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4)))) val arrayExpected = new GenericArrayData( Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4)))) - checkEvalutionWithUnsafeProjection( - arrayEncoder.serializer.head, arrayExpected, arrayInputRow) + checkEvaluationWithUnsafeProjection( + arrayEncoder.serializer.head, + arrayExpected, + arrayInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed // test UnsafeMap-backed data val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]] @@ -67,8 +73,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new ArrayBasedMapData( new GenericArrayData(Array(3, 4)), new GenericArrayData(Array(300, 400))))) - checkEvalutionWithUnsafeProjection( - mapEncoder.serializer.head, mapExpected, mapInputRow) + checkEvaluationWithUnsafeProjection( + mapEncoder.serializer.head, + mapExpected, + mapInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed } test("SPARK-23585: UnwrapOption should support interpreted execution") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 10e3ffd0dff97..e083ae0089244 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -43,7 +43,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { assert(e1.getMessage.contains("Failed to execute user defined function")) val e2 = intercept[SparkException] { - checkEvalutionWithUnsafeProjection(udf, null) + checkEvaluationWithUnsafeProjection(udf, null) } assert(e2.getMessage.contains("Failed to execute user defined function")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index cf3cbe270753e..c07da122cd7b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String @@ -33,10 +33,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) - test("basic conversion with only primitive types") { - val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) - val converter = UnsafeProjection.create(fieldTypes) + private def testWithFactory( + name: String)( + f: UnsafeProjectionCreator => Unit): Unit = { + test(name) { + f(UnsafeProjection) + f(InterpretedUnsafeProjection) + } + } + testWithFactory("basic conversion with only primitive types") { factory => + val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) row.setLong(1, 1) @@ -71,9 +79,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow2.getInt(2) === 2) } - test("basic conversion with primitive, string and binary types") { + testWithFactory("basic conversion with primitive, string and binary types") { factory => val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) @@ -90,9 +98,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8)) } - test("basic conversion with primitive, string, date and timestamp types") { + testWithFactory("basic conversion with primitive, string, date and timestamp types") { factory => val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) @@ -119,7 +127,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { (Timestamp.valueOf("2015-06-22 08:10:25")) } - test("null handling") { + testWithFactory("null handling") { factory => val fieldTypes: Array[DataType] = Array( NullType, BooleanType, @@ -135,7 +143,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { DecimalType.SYSTEM_DEFAULT // ArrayType(IntegerType) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val rowWithAllNullColumns: InternalRow = { val r = new SpecificInternalRow(fieldTypes) @@ -240,7 +248,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } - test("NaN canonicalization") { + testWithFactory("NaN canonicalization") { factory => val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) val row1 = new SpecificInternalRow(fieldTypes) @@ -251,17 +259,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff)) row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } - test("basic conversion with struct type") { + testWithFactory("basic conversion with struct type") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("i", IntegerType), new StructType().add("nest", new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(1)) @@ -317,12 +325,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) } - test("basic conversion with array type") { + testWithFactory("basic conversion with array type") { factory => val fieldTypes: Array[DataType] = Array( ArrayType(IntegerType), ArrayType(ArrayType(IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, createArray(1, 2)) @@ -347,12 +355,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) } - test("basic conversion with map type") { + testWithFactory("basic conversion with map type") { factory => val fieldTypes: Array[DataType] = Array( MapType(IntegerType, IntegerType), MapType(IntegerType, MapType(IntegerType, IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val map1 = createMap(1, 2)(3, 4) @@ -393,12 +401,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size) } - test("basic conversion with struct and array") { + testWithFactory("basic conversion with struct and array") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("arr", ArrayType(IntegerType)), ArrayType(new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(createArray(1))) @@ -432,12 +440,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - test("basic conversion with struct and map") { + testWithFactory("basic conversion with struct and map") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("map", MapType(IntegerType, IntegerType)), MapType(IntegerType, new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(createMap(1)(2))) @@ -478,12 +486,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - test("basic conversion with array and map") { + testWithFactory("basic conversion with array and map") { factory => val fieldTypes: Array[DataType] = Array( ArrayType(MapType(IntegerType, IntegerType)), MapType(IntegerType, ArrayType(IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, createArray(createMap(1)(2))) From 9945b0227efcd952c8e835453b2831a8c6d5d607 Mon Sep 17 00:00:00 2001 From: Ricardo Martinelli de Oliveira Date: Fri, 16 Mar 2018 10:37:11 -0700 Subject: [PATCH 484/774] [SPARK-23680] Fix entrypoint.sh to properly support Arbitrary UIDs ## What changes were proposed in this pull request? As described in SPARK-23680, entrypoint.sh returns an error code because of a command pipeline execution where it is expected in case of Openshift environments, where arbitrary UIDs are used to run containers ## How was this patch tested? This patch was manually tested by using docker-image-toll.sh script to generate a Spark driver image and running an example against an OpenShift cluster. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Ricardo Martinelli de Oliveira Closes #20822 from rimolive/rmartine-spark-23680. --- .../kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 3d67b0a702dd4..d0cf284f035ea 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -22,7 +22,10 @@ set -ex # Check whether there is a passwd entry for the container UID myuid=$(id -u) mygid=$(id -g) +# turn off -e for getent because it will return error code in anonymous uid case +set +e uidentry=$(getent passwd $myuid) +set -e # If there is no passwd entry for the container UID, attempt to create one if [ -z "$uidentry" ] ; then From bd201bf61e8e1713deb91b962f670c76c9e3492b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 16 Mar 2018 11:11:07 -0700 Subject: [PATCH 485/774] [SPARK-23623][SS] Avoid concurrent use of cached consumers in CachedKafkaConsumer ## What changes were proposed in this pull request? CacheKafkaConsumer in the project `kafka-0-10-sql` is designed to maintain a pool of KafkaConsumers that can be reused. However, it was built with the assumption there will be only one task using trying to read the same Kafka TopicPartition at the same time. Hence, the cache was keyed by the TopicPartition a consumer is supposed to read. And any cases where this assumption may not be true, we have SparkPlan flag to disable the use of a cache. So it was up to the planner to correctly identify when it was not safe to use the cache and set the flag accordingly. Fundamentally, this is the wrong way to approach the problem. It is HARD for a high-level planner to reason about the low-level execution model, whether there will be multiple tasks in the same query trying to read the same partition. Case in point, 2.3.0 introduced stream-stream joins, and you can build a streaming self-join query on Kafka. It's pretty non-trivial to figure out how this leads to two tasks reading the same partition twice, possibly concurrently. And due to the non-triviality, it is hard to figure this out in the planner and set the flag to avoid the cache / consumer pool. And this can inadvertently lead to ConcurrentModificationException ,or worse, silent reading of incorrect data. Here is a better way to design this. The planner shouldnt have to understand these low-level optimizations. Rather the consumer pool should be smart enough avoid concurrent use of a cached consumer. Currently, it tries to do so but incorrectly (the flag inuse is not checked when returning a cached consumer, see [this](https://github.com/apache/spark/blob/master/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala#L403)). If there is another request for the same partition as a currently in-use consumer, the pool should automatically return a fresh consumer that should be closed when the task is done. Then the planner does not have to have a flag to avoid reuses. This PR is a step towards that goal. It does the following. - There are effectively two kinds of consumer that may be generated - Cached consumer - this should be returned to the pool at task end - Non-cached consumer - this should be closed at task end - A trait called KafkaConsumer is introduced to hide this difference from the users of the consumer so that the client code does not have to reason about whether to stop and release. They simply called `val consumer = KafkaConsumer.acquire` and then `consumer.release()`. - If there is request for a consumer that is in-use, then a new consumer is generated. - If there is a concurrent attempt of the same task, then a new consumer is generated, and the existing cached consumer is marked for close upon release. - In addition, I renamed the classes because CachedKafkaConsumer is a misnomer given that what it returns may or may not be cached. This PR does not remove the planner flag to avoid reuse to make this patch safe enough for merging in branch-2.3. This can be done later in master-only. ## How was this patch tested? A new stress test that verifies it is safe to concurrently get consumers for the same partition from the consumer pool. Author: Tathagata Das Closes #20767 from tdas/SPARK-23623. --- .../sql/kafka010/KafkaContinuousReader.scala | 5 +- ...Consumer.scala => KafkaDataConsumer.scala} | 242 ++++++++++++------ .../sql/kafka010/KafkaMicroBatchReader.scala | 22 +- .../spark/sql/kafka010/KafkaSourceRDD.scala | 23 +- .../kafka010/CachedKafkaConsumerSuite.scala | 34 --- .../sql/kafka010/KafkaDataConsumerSuite.scala | 124 +++++++++ 6 files changed, 295 insertions(+), 155 deletions(-) rename external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/{CachedKafkaConsumer.scala => KafkaDataConsumer.scala} (66%) delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 6e56b0a72d671..e7e27876088f3 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -196,8 +196,7 @@ class KafkaContinuousDataReader( kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { - private val consumer = - CachedKafkaConsumer.createUncached(topicPartition.topic, topicPartition.partition, kafkaParams) + private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter private var nextKafkaOffset = startOffset @@ -245,6 +244,6 @@ class KafkaContinuousDataReader( } override def close(): Unit = { - consumer.close() + consumer.release() } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala similarity index 66% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index e97881cb0a163..48508d057a540 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -27,30 +27,73 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging +import org.apache.spark.sql.kafka010.KafkaDataConsumer.AvailableOffsetRange import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.util.UninterruptibleThread +private[kafka010] sealed trait KafkaDataConsumer { + /** + * Get the record for the given offset if available. Otherwise it will either throw error + * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), + * or null. + * + * @param offset the offset to fetch. + * @param untilOffset the max offset to fetch. Exclusive. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at + * offset if available, or throw exception.when `failOnDataLoss` is `false`, + * this method will either return record at offset if available, or return + * the next earliest available record less than untilOffset, or null. It + * will not throw any exception. + */ + def get( + offset: Long, + untilOffset: Long, + pollTimeoutMs: Long, + failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { + internalConsumer.get(offset, untilOffset, pollTimeoutMs, failOnDataLoss) + } + + /** + * Return the available offset range of the current partition. It's a pair of the earliest offset + * and the latest offset. + */ + def getAvailableOffsetRange(): AvailableOffsetRange = internalConsumer.getAvailableOffsetRange() + + /** + * Release this consumer from being further used. Depending on its implementation, + * this consumer will be either finalized, or reset for reuse later. + */ + def release(): Unit + + /** Reference to the internal implementation that this wrapper delegates to */ + protected def internalConsumer: InternalKafkaConsumer +} + /** - * Consumer of single topicpartition, intended for cached reuse. - * Underlying consumer is not threadsafe, so neither is this, - * but processing the same topicpartition and group id in multiple threads is usually bad anyway. + * A wrapper around Kafka's KafkaConsumer that throws error when data loss is detected. + * This is not for direct use outside this file. */ -private[kafka010] case class CachedKafkaConsumer private( +private[kafka010] case class InternalKafkaConsumer( topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) extends Logging { - import CachedKafkaConsumer._ + import InternalKafkaConsumer._ private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - private var consumer = createConsumer + @volatile private var consumer = createConsumer /** indicates whether this consumer is in use or not */ - private var inuse = true + @volatile var inUse = true + + /** indicate whether this consumer is going to be stopped in the next release */ + @volatile var markedForClose = false /** Iterator to the already fetch data */ - private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] - private var nextOffsetInFetchedData = UNKNOWN_OFFSET + @volatile private var fetchedData = + ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + @volatile private var nextOffsetInFetchedData = UNKNOWN_OFFSET /** Create a KafkaConsumer to fetch records for `topicPartition` */ private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = { @@ -61,8 +104,6 @@ private[kafka010] case class CachedKafkaConsumer private( c } - case class AvailableOffsetRange(earliest: Long, latest: Long) - private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match { case ut: UninterruptibleThread => ut.runUninterruptibly(body) @@ -313,21 +354,51 @@ private[kafka010] case class CachedKafkaConsumer private( } } -private[kafka010] object CachedKafkaConsumer extends Logging { - private val UNKNOWN_OFFSET = -2L +private[kafka010] object KafkaDataConsumer extends Logging { + + case class AvailableOffsetRange(earliest: Long, latest: Long) + + private case class CachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer) + extends KafkaDataConsumer { + assert(internalConsumer.inUse) // make sure this has been set to true + override def release(): Unit = { KafkaDataConsumer.release(internalConsumer) } + } + + private case class NonCachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer) + extends KafkaDataConsumer { + override def release(): Unit = { internalConsumer.close() } + } - private case class CacheKey(groupId: String, topicPartition: TopicPartition) + private case class CacheKey(groupId: String, topicPartition: TopicPartition) { + def this(topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) = + this(kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String], topicPartition) + } + // This cache has the following important properties. + // - We make a best-effort attempt to maintain the max size of the cache as configured capacity. + // The capacity is not guaranteed to be maintained, especially when there are more active + // tasks simultaneously using consumers than the capacity. private lazy val cache = { val conf = SparkEnv.get.conf val capacity = conf.getInt("spark.sql.kafkaConsumerCache.capacity", 64) - new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer](capacity, 0.75f, true) { + new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer](capacity, 0.75f, true) { override def removeEldestEntry( - entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer]): Boolean = { - if (entry.getValue.inuse == false && this.size > capacity) { - logWarning(s"KafkaConsumer cache hitting max capacity of $capacity, " + - s"removing consumer for ${entry.getKey}") + entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer]): Boolean = { + + // Try to remove the least-used entry if its currently not in use. + // + // If you cannot remove it, then the cache will keep growing. In the worst case, + // the cache will grow to the max number of concurrent tasks that can run in the executor, + // (that is, number of tasks slots) after which it will never reduce. This is unlikely to + // be a serious problem because an executor with more than 64 (default) tasks slots is + // likely running on a beefy machine that can handle a large number of simultaneously + // active consumers. + + if (entry.getValue.inUse == false && this.size > capacity) { + logWarning( + s"KafkaConsumer cache hitting max capacity of $capacity, " + + s"removing consumer for ${entry.getKey}") try { entry.getValue.close() } catch { @@ -342,80 +413,87 @@ private[kafka010] object CachedKafkaConsumer extends Logging { } } - def releaseKafkaConsumer( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): Unit = { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) - - synchronized { - val consumer = cache.get(key) - if (consumer != null) { - consumer.inuse = false - } else { - logWarning(s"Attempting to release consumer that does not exist") - } - } - } - /** - * Removes (and closes) the Kafka Consumer for the given topic, partition and group id. + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]]. + * + * Note: This method guarantees that the consumer returned is not currently in use by any one + * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by + * caching them and tracking when they are in use. */ - def removeKafkaConsumer( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): Unit = { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) + def acquire( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object], + useCache: Boolean): KafkaDataConsumer = synchronized { + val key = new CacheKey(topicPartition, kafkaParams) + val existingInternalConsumer = cache.get(key) - synchronized { - val removedConsumer = cache.remove(key) - if (removedConsumer != null) { - removedConsumer.close() + lazy val newInternalConsumer = new InternalKafkaConsumer(topicPartition, kafkaParams) + + if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { + // If this is reattempt at running the task, then invalidate cached consumer if any and + // start with a new one. + if (existingInternalConsumer != null) { + // Consumer exists in cache. If its in use, mark it for closing later, or close it now. + if (existingInternalConsumer.inUse) { + existingInternalConsumer.markedForClose = true + } else { + existingInternalConsumer.close() + } } + cache.remove(key) // Invalidate the cache in any case + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else if (!useCache) { + // If planner asks to not reuse consumers, then do not use it, return a new consumer + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else if (existingInternalConsumer == null) { + // If consumer is not already cached, then put a new in the cache and return it + cache.put(key, newInternalConsumer) + newInternalConsumer.inUse = true + CachedKafkaDataConsumer(newInternalConsumer) + + } else if (existingInternalConsumer.inUse) { + // If consumer is already cached but is currently in use, then return a new consumer + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else { + // If consumer is already cached and is currently not in use, then return that consumer + existingInternalConsumer.inUse = true + CachedKafkaDataConsumer(existingInternalConsumer) } } - /** - * Get a cached consumer for groupId, assigned to topic and partition. - * If matching consumer doesn't already exist, will be created using kafkaParams. - */ - def getOrCreate( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = synchronized { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) - - // If this is reattempt at running the task, then invalidate cache and start with - // a new consumer - if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { - removeKafkaConsumer(topic, partition, kafkaParams) - val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams) - consumer.inuse = true - cache.put(key, consumer) - consumer - } else { - if (!cache.containsKey(key)) { - cache.put(key, new CachedKafkaConsumer(topicPartition, kafkaParams)) + private def release(intConsumer: InternalKafkaConsumer): Unit = { + synchronized { + + // Clear the consumer from the cache if this is indeed the consumer present in the cache + val key = new CacheKey(intConsumer.topicPartition, intConsumer.kafkaParams) + val cachedIntConsumer = cache.get(key) + if (intConsumer.eq(cachedIntConsumer)) { + // The released consumer is the same object as the cached one. + if (intConsumer.markedForClose) { + intConsumer.close() + cache.remove(key) + } else { + intConsumer.inUse = false + } + } else { + // The released consumer is either not the same one as in the cache, or not in the cache + // at all. This may happen if the cache was invalidate while this consumer was being used. + // Just close this consumer. + intConsumer.close() + logInfo(s"Released a supposedly cached consumer that was not found in the cache") } - val consumer = cache.get(key) - consumer.inuse = true - consumer } } +} - /** Create an [[CachedKafkaConsumer]] but don't put it into cache. */ - def createUncached( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = { - new CachedKafkaConsumer(new TopicPartition(topic, partition), kafkaParams) - } +private[kafka010] object InternalKafkaConsumer extends Logging { + + private val UNKNOWN_OFFSET = -2L private def reportDataLoss0( failOnDataLoss: Boolean, diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 8a5f3a249b11c..2ed49ba3f5495 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -321,17 +321,8 @@ private[kafka010] case class KafkaMicroBatchDataReader( failOnDataLoss: Boolean, reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging { - private val consumer = { - if (!reuseKafkaConsumer) { - // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. We - // uses `assign` here, hence we don't need to worry about the "group.id" conflicts. - CachedKafkaConsumer.createUncached( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) - } else { - CachedKafkaConsumer.getOrCreate( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) - } - } + private val consumer = KafkaDataConsumer.acquire( + offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) private val rangeToRead = resolveRange(offsetRange) private val converter = new KafkaRecordToUnsafeRowConverter @@ -360,14 +351,7 @@ private[kafka010] case class KafkaMicroBatchDataReader( } override def close(): Unit = { - if (!reuseKafkaConsumer) { - // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! - consumer.close() - } else { - // Indicate that we're no longer using this consumer - CachedKafkaConsumer.releaseKafkaConsumer( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) - } + consumer.release() } private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 66b3409c0cd04..498e344ea39f4 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -52,7 +52,7 @@ private[kafka010] case class KafkaSourceRDDPartition( * An RDD that reads data from Kafka based on offset ranges across multiple partitions. * Additionally, it allows preferred locations to be set for each topic + partition, so that * the [[KafkaSource]] can ensure the same executor always reads the same topic + partition - * and cached KafkaConsumers (see [[CachedKafkaConsumer]] can be used read data efficiently. + * and cached KafkaConsumers (see [[KafkaDataConsumer]] can be used read data efficiently. * * @param sc the [[SparkContext]] * @param executorKafkaParams Kafka configuration for creating KafkaConsumer on the executors @@ -126,14 +126,9 @@ private[kafka010] class KafkaSourceRDD( val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition] val topic = sourcePartition.offsetRange.topic val kafkaPartition = sourcePartition.offsetRange.partition - val consumer = - if (!reuseKafkaConsumer) { - // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. As here we - // uses `assign`, we don't need to worry about the "group.id" conflicts. - CachedKafkaConsumer.createUncached(topic, kafkaPartition, executorKafkaParams) - } else { - CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams) - } + val consumer = KafkaDataConsumer.acquire( + sourcePartition.offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) + val range = resolveRange(consumer, sourcePartition.offsetRange) assert( range.fromOffset <= range.untilOffset, @@ -167,13 +162,7 @@ private[kafka010] class KafkaSourceRDD( } override protected def close(): Unit = { - if (!reuseKafkaConsumer) { - // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! - consumer.close() - } else { - // Indicate that we're no longer using this consumer - CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams) - } + consumer.release() } } // Release consumer, either by removing it or indicating we're no longer using it @@ -184,7 +173,7 @@ private[kafka010] class KafkaSourceRDD( } } - private def resolveRange(consumer: CachedKafkaConsumer, range: KafkaSourceRDDOffsetRange) = { + private def resolveRange(consumer: KafkaDataConsumer, range: KafkaSourceRDDOffsetRange) = { if (range.fromOffset < 0 || range.untilOffset < 0) { // Late bind the offset range val availableOffsetRange = consumer.getAvailableOffsetRange() diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala deleted file mode 100644 index 7aa7dd096c07b..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import org.scalatest.PrivateMethodTester - -import org.apache.spark.sql.test.SharedSQLContext - -class CachedKafkaConsumerSuite extends SharedSQLContext with PrivateMethodTester { - - test("SPARK-19886: Report error cause correctly in reportDataLoss") { - val cause = new Exception("D'oh!") - val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0) - val e = intercept[IllegalStateException] { - CachedKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause)) - } - assert(e.getCause === cause) - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala new file mode 100644 index 0000000000000..0d0fb9c3ab5af --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.{Executors, TimeUnit} + +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.Duration +import scala.util.Random + +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.scalatest.PrivateMethodTester + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.ThreadUtils + +class KafkaDataConsumerSuite extends SharedSQLContext with PrivateMethodTester { + + protected var testUtils: KafkaTestUtils = _ + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils(Map[String, Object]()) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + test("SPARK-19886: Report error cause correctly in reportDataLoss") { + val cause = new Exception("D'oh!") + val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0) + val e = intercept[IllegalStateException] { + InternalKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause)) + } + assert(e.getCause === cause) + } + + test("SPARK-23623: concurrent use of KafkaDataConsumer") { + val topic = "topic" + Random.nextInt() + val data = (1 to 1000).map(_.toString) + testUtils.createTopic(topic, 1) + testUtils.sendMessages(topic, data.toArray) + val topicPartition = new TopicPartition(topic, 0) + + import ConsumerConfig._ + val kafkaParams = Map[String, Object]( + GROUP_ID_CONFIG -> "groupId", + BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress, + KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + AUTO_OFFSET_RESET_CONFIG -> "earliest", + ENABLE_AUTO_COMMIT_CONFIG -> "false" + ) + + val numThreads = 100 + val numConsumerUsages = 500 + + @volatile var error: Throwable = null + + def consume(i: Int): Unit = { + val useCache = Random.nextBoolean + val taskContext = if (Random.nextBoolean) { + new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null) + } else { + null + } + TaskContext.setTaskContext(taskContext) + val consumer = KafkaDataConsumer.acquire( + topicPartition, kafkaParams.asJava, useCache) + try { + val range = consumer.getAvailableOffsetRange() + val rcvd = range.earliest until range.latest map { offset => + val bytes = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false).value() + new String(bytes) + } + assert(rcvd == data) + } catch { + case e: Throwable => + error = e + throw e + } finally { + consumer.release() + } + } + + val threadpool = Executors.newFixedThreadPool(numThreads) + try { + val futures = (1 to numConsumerUsages).map { i => + threadpool.submit(new Runnable { + override def run(): Unit = { consume(i) } + }) + } + futures.foreach(_.get(1, TimeUnit.MINUTES)) + assert(error == null) + } finally { + threadpool.shutdown() + } + } +} From 8a72734f33f6a0abbd3207b0d661633c8b25d9ad Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 16 Mar 2018 11:42:57 -0700 Subject: [PATCH 486/774] [SPARK-15009][PYTHON][ML] Construct a CountVectorizerModel from a vocabulary list ## What changes were proposed in this pull request? Added a class method to construct CountVectorizerModel from a list of vocabulary strings, equivalent to the Scala version. Introduced a common param base class `_CountVectorizerParams` to allow the Python model to also own the parameters. This now matches the Scala class hierarchy. ## How was this patch tested? Added to CountVectorizer doctests to do a transform on a model constructed from vocab, and unit test to verify params and vocab are constructed correctly. Author: Bryan Cutler Closes #16770 from BryanCutler/pyspark-CountVectorizerModel-vocab_ctor-SPARK-15009. --- python/pyspark/ml/feature.py | 168 +++++++++++++++++++++++------------ python/pyspark/ml/tests.py | 32 ++++++- 2 files changed, 142 insertions(+), 58 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index f2e357f0bede5..a1ceb7f02da8b 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -19,12 +19,12 @@ if sys.version > '3': basestring = str -from pyspark import since, keyword_only +from pyspark import since, keyword_only, SparkContext from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.linalg import _convert_to_vector from pyspark.ml.param.shared import * from pyspark.ml.util import JavaMLReadable, JavaMLWritable -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm from pyspark.ml.common import inherit_doc __all__ = ['Binarizer', @@ -403,8 +403,69 @@ def getSplits(self): return self.getOrDefault(self.splits) +class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): + """ + Params for :py:attr:`CountVectorizer` and :py:attr:`CountVectorizerModel`. + """ + + minTF = Param( + Params._dummy(), "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then this " + + "specifies a fraction (out of the document's token count). Note that the parameter is " + + "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", + typeConverter=TypeConverters.toFloat) + minDF = Param( + Params._dummy(), "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + + " Default 1.0", typeConverter=TypeConverters.toFloat) + vocabSize = Param( + Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", + typeConverter=TypeConverters.toInt) + binary = Param( + Params._dummy(), "binary", "Binary toggle to control the output vector values." + + " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + + " for discrete probabilistic models that model binary events rather than integer counts." + + " Default False", typeConverter=TypeConverters.toBoolean) + + def __init__(self, *args): + super(_CountVectorizerParams, self).__init__(*args) + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) + + @since("1.6.0") + def getMinTF(self): + """ + Gets the value of minTF or its default value. + """ + return self.getOrDefault(self.minTF) + + @since("1.6.0") + def getMinDF(self): + """ + Gets the value of minDF or its default value. + """ + return self.getOrDefault(self.minDF) + + @since("1.6.0") + def getVocabSize(self): + """ + Gets the value of vocabSize or its default value. + """ + return self.getOrDefault(self.vocabSize) + + @since("2.0.0") + def getBinary(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.binary) + + @inherit_doc -class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): +class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, JavaMLWritable): """ Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`. @@ -437,33 +498,20 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, >>> loadedModel = CountVectorizerModel.load(modelPath) >>> loadedModel.vocabulary == model.vocabulary True + >>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"], + ... inputCol="raw", outputCol="vectors") + >>> fromVocabModel.transform(df).show(truncate=False) + +-----+---------------+-------------------------+ + |label|raw |vectors | + +-----+---------------+-------------------------+ + |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])| + |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| + +-----+---------------+-------------------------+ + ... .. versionadded:: 1.6.0 """ - minTF = Param( - Params._dummy(), "minTF", "Filter to ignore rare words in" + - " a document. For each document, terms with frequency/count less than the given" + - " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + - " times the term must appear in the document); if this is a double in [0,1), then this " + - "specifies a fraction (out of the document's token count). Note that the parameter is " + - "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", - typeConverter=TypeConverters.toFloat) - minDF = Param( - Params._dummy(), "minDF", "Specifies the minimum number of" + - " different documents a term must appear in to be included in the vocabulary." + - " If this is an integer >= 1, this specifies the number of documents the term must" + - " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + - " Default 1.0", typeConverter=TypeConverters.toFloat) - vocabSize = Param( - Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", - typeConverter=TypeConverters.toInt) - binary = Param( - Params._dummy(), "binary", "Binary toggle to control the output vector values." + - " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + - " for discrete probabilistic models that model binary events rather than integer counts." + - " Default False", typeConverter=TypeConverters.toBoolean) - @keyword_only def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, outputCol=None): @@ -474,7 +522,6 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) - self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -497,13 +544,6 @@ def setMinTF(self, value): """ return self._set(minTF=value) - @since("1.6.0") - def getMinTF(self): - """ - Gets the value of minTF or its default value. - """ - return self.getOrDefault(self.minTF) - @since("1.6.0") def setMinDF(self, value): """ @@ -511,13 +551,6 @@ def setMinDF(self, value): """ return self._set(minDF=value) - @since("1.6.0") - def getMinDF(self): - """ - Gets the value of minDF or its default value. - """ - return self.getOrDefault(self.minDF) - @since("1.6.0") def setVocabSize(self, value): """ @@ -525,13 +558,6 @@ def setVocabSize(self, value): """ return self._set(vocabSize=value) - @since("1.6.0") - def getVocabSize(self): - """ - Gets the value of vocabSize or its default value. - """ - return self.getOrDefault(self.vocabSize) - @since("2.0.0") def setBinary(self, value): """ @@ -539,24 +565,40 @@ def setBinary(self, value): """ return self._set(binary=value) - @since("2.0.0") - def getBinary(self): - """ - Gets the value of binary or its default value. - """ - return self.getOrDefault(self.binary) - def _create_model(self, java_model): return CountVectorizerModel(java_model) -class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): +@inherit_doc +class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, JavaMLWritable): """ Model fitted by :py:class:`CountVectorizer`. .. versionadded:: 1.6.0 """ + @classmethod + @since("2.4.0") + def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None): + """ + Construct the model directly from a vocabulary list of strings, + requires an active SparkContext. + """ + sc = SparkContext._active_spark_context + java_class = sc._gateway.jvm.java.lang.String + jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class) + model = CountVectorizerModel._create_from_java_class( + "org.apache.spark.ml.feature.CountVectorizerModel", jvocab) + model.setInputCol(inputCol) + if outputCol is not None: + model.setOutputCol(outputCol) + if minTF is not None: + model.setMinTF(minTF) + if binary is not None: + model.setBinary(binary) + model._set(vocabSize=len(vocabulary)) + return model + @property @since("1.6.0") def vocabulary(self): @@ -565,6 +607,20 @@ def vocabulary(self): """ return self._call_java("vocabulary") + @since("2.4.0") + def setMinTF(self, value): + """ + Sets the value of :py:attr:`minTF`. + """ + return self._set(minTF=value) + + @since("2.4.0") + def setBinary(self, value): + """ + Sets the value of :py:attr:`binary`. + """ + return self._set(binary=value) + @inherit_doc class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6dee6938d8916..fd45fd00b270b 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -679,6 +679,34 @@ def test_count_vectorizer_with_binary(self): feature, expected = r self.assertEqual(feature, expected) + def test_count_vectorizer_from_vocab(self): + model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", + outputCol="features", minTF=2) + self.assertEqual(model.vocabulary, ["a", "b", "c"]) + self.assertEqual(model.getMinTF(), 2) + + dataset = self.spark.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 2.0}),), + (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + + transformed_list = model.transform(dataset).select("features", "expected").collect() + + for r in transformed_list: + feature, expected = r + self.assertEqual(feature, expected) + + # Test an empty vocabulary + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"): + CountVectorizerModel.from_vocabulary([], inputCol="words") + + # Test model with default settings can transform + model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words") + transformed_list = model_default.transform(dataset)\ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 3) + def test_rformula_force_index_label(self): df = self.spark.createDataFrame([ (1.0, 1.0, "a"), @@ -2019,8 +2047,8 @@ def test_java_params(self): pyspark.ml.regression] for module in modules: for name, cls in inspect.getmembers(module, inspect.isclass): - if not name.endswith('Model') and issubclass(cls, JavaParams)\ - and not inspect.isabstract(cls): + if not name.endswith('Model') and not name.endswith('Params')\ + and issubclass(cls, JavaParams) and not inspect.isabstract(cls): # NOTE: disable check_params_exist until there is parity with Scala API ParamTests.check_params(self, cls(), check_params_exist=False) From 8a1efe3076f29259151f1fba2ff894487efb6c4e Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Fri, 16 Mar 2018 15:40:21 -0700 Subject: [PATCH 487/774] [SPARK-23683][SQL] FileCommitProtocol.instantiate() hardening ## What changes were proposed in this pull request? With SPARK-20236, `FileCommitProtocol.instantiate()` looks for a three argument constructor, passing in the `dynamicPartitionOverwrite` parameter. If there is no such constructor, it falls back to the classic two-arg one. When `InsertIntoHadoopFsRelationCommand` passes down that `dynamicPartitionOverwrite` flag `to FileCommitProtocol.instantiate(`), it assumes that the instantiated protocol supports the specific requirements of dynamic partition overwrite. It does not notice when this does not hold, and so the output generated may be incorrect. This patch changes `FileCommitProtocol.instantiate()` so when `dynamicPartitionOverwrite == true`, it requires the protocol implementation to have a 3-arg constructor. Classic two arg constructors are supported when it is false. Also it adds some debug level logging for anyone trying to understand what's going on. ## How was this patch tested? Unit tests verify that * classes with only 2-arg constructor cannot be used with dynamic overwrite * classes with only 2-arg constructor can be used without dynamic overwrite * classes with 3 arg constructors can be used with both. * the fallback to any two arg ctor takes place after the attempt to load the 3-arg ctor, * passing in invalid class types fail as expected (regression tests on expected behavior) Author: Steve Loughran Closes #20824 from steveloughran/stevel/SPARK-23683-protocol-instantiate. --- .../internal/io/FileCommitProtocol.scala | 11 +- ...FileCommitProtocolInstantiationSuite.scala | 148 ++++++++++++++++++ 2 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 6d0059b6a0272..e6e9c9e328853 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -20,6 +20,7 @@ package org.apache.spark.internal.io import org.apache.hadoop.fs._ import org.apache.hadoop.mapreduce._ +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -132,7 +133,7 @@ abstract class FileCommitProtocol { } -object FileCommitProtocol { +object FileCommitProtocol extends Logging { class TaskCommitMessage(val obj: Any) extends Serializable object EmptyTaskCommitMessage extends TaskCommitMessage(null) @@ -145,15 +146,23 @@ object FileCommitProtocol { jobId: String, outputPath: String, dynamicPartitionOverwrite: Boolean = false): FileCommitProtocol = { + + logDebug(s"Creating committer $className; job $jobId; output=$outputPath;" + + s" dynamic=$dynamicPartitionOverwrite") val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] // First try the constructor with arguments (jobId: String, outputPath: String, // dynamicPartitionOverwrite: Boolean). // If that doesn't exist, try the one with (jobId: string, outputPath: String). try { val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + logDebug("Using (String, String, Boolean) constructor") ctor.newInstance(jobId, outputPath, dynamicPartitionOverwrite.asInstanceOf[java.lang.Boolean]) } catch { case _: NoSuchMethodException => + logDebug("Falling back to (String, String) constructor") + require(!dynamicPartitionOverwrite, + "Dynamic Partition Overwrite is enabled but" + + s" the committer ${className} does not have the appropriate constructor") val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) ctor.newInstance(jobId, outputPath) } diff --git a/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala b/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala new file mode 100644 index 0000000000000..2bd32fc927e21 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.spark.SparkFunSuite + +/** + * Unit tests for instantiation of FileCommitProtocol implementations. + */ +class FileCommitProtocolInstantiationSuite extends SparkFunSuite { + + test("Dynamic partitions require appropriate constructor") { + + // you cannot instantiate a two-arg client with dynamic partitions + // enabled. + val ex = intercept[IllegalArgumentException] { + instantiateClassic(true) + } + // check the contents of the message and rethrow if unexpected. + // this preserves the stack trace of the unexpected + // exception. + if (!ex.toString.contains("Dynamic Partition Overwrite")) { + fail(s"Wrong text in caught exception $ex", ex) + } + } + + test("Standard partitions work with classic constructor") { + instantiateClassic(false) + } + + test("Three arg constructors have priority") { + assert(3 == instantiateNew(false).argCount, + "Wrong constructor argument count") + } + + test("Three arg constructors have priority when dynamic") { + assert(3 == instantiateNew(true).argCount, + "Wrong constructor argument count") + } + + test("The protocol must be of the correct class") { + intercept[ClassCastException] { + FileCommitProtocol.instantiate( + classOf[Other].getCanonicalName, + "job", + "path", + false) + } + } + + test("If there is no matching constructor, class hierarchy is irrelevant") { + intercept[NoSuchMethodException] { + FileCommitProtocol.instantiate( + classOf[NoMatchingArgs].getCanonicalName, + "job", + "path", + false) + } + } + + /** + * Create a classic two-arg protocol instance. + * @param dynamic dyanmic partitioning mode + * @return the instance + */ + private def instantiateClassic(dynamic: Boolean): ClassicConstructorCommitProtocol = { + FileCommitProtocol.instantiate( + classOf[ClassicConstructorCommitProtocol].getCanonicalName, + "job", + "path", + dynamic).asInstanceOf[ClassicConstructorCommitProtocol] + } + + /** + * Create a three-arg protocol instance. + * @param dynamic dyanmic partitioning mode + * @return the instance + */ + private def instantiateNew( + dynamic: Boolean): FullConstructorCommitProtocol = { + FileCommitProtocol.instantiate( + classOf[FullConstructorCommitProtocol].getCanonicalName, + "job", + "path", + dynamic).asInstanceOf[FullConstructorCommitProtocol] + } + +} + +/** + * This protocol implementation does not have the new three-arg + * constructor. + */ +private class ClassicConstructorCommitProtocol(arg1: String, arg2: String) + extends HadoopMapReduceCommitProtocol(arg1, arg2) { +} + +/** + * This protocol implementation does have the new three-arg constructor + * alongside the original, and a 4 arg one for completeness. + * The final value of the real constructor is the number of arguments + * used in the 2- and 3- constructor, for test assertions. + */ +private class FullConstructorCommitProtocol( + arg1: String, + arg2: String, + b: Boolean, + val argCount: Int) + extends HadoopMapReduceCommitProtocol(arg1, arg2, b) { + + def this(arg1: String, arg2: String) = { + this(arg1, arg2, false, 2) + } + + def this(arg1: String, arg2: String, b: Boolean) = { + this(arg1, arg2, false, 3) + } +} + +/** + * This has the 2-arity constructor, but isn't the right class. + */ +private class Other(arg1: String, arg2: String) { + +} + +/** + * This has no matching arguments as well as being the wrong class. + */ +private class NoMatchingArgs() { + +} + From 61487b308b0169e3108c2ad31674a0c80b8ac5f3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 18 Mar 2018 20:24:14 +0900 Subject: [PATCH 488/774] [SPARK-23706][PYTHON] spark.conf.get(value, default=None) should produce None in PySpark ## What changes were proposed in this pull request? Scala: ``` scala> spark.conf.get("hey", null) res1: String = null ``` ``` scala> spark.conf.get("spark.sql.sources.partitionOverwriteMode", null) res2: String = null ``` Python: **Before** ``` >>> spark.conf.get("hey", None) ... py4j.protocol.Py4JJavaError: An error occurred while calling o30.get. : java.util.NoSuchElementException: hey ... ``` ``` >>> spark.conf.get("spark.sql.sources.partitionOverwriteMode", None) u'STATIC' ``` **After** ``` >>> spark.conf.get("hey", None) is None True ``` ``` >>> spark.conf.get("spark.sql.sources.partitionOverwriteMode", None) is None True ``` *Note that this PR preserves the case below: ``` >>> spark.conf.get("spark.sql.sources.partitionOverwriteMode") u'STATIC' ``` ## How was this patch tested? Manually tested and unit tests were added. Author: hyukjinkwon Closes #20841 from HyukjinKwon/spark-conf-get. --- python/pyspark/sql/conf.py | 9 +++++---- python/pyspark/sql/context.py | 8 ++++---- python/pyspark/sql/tests.py | 11 +++++++++++ 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index d929834aeeaa5..b82224b6194ed 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -17,7 +17,7 @@ import sys -from pyspark import since +from pyspark import since, _NoValue from pyspark.rdd import ignore_unicode_prefix @@ -39,15 +39,16 @@ def set(self, key, value): @ignore_unicode_prefix @since(2.0) - def get(self, key, default=None): + def get(self, key, default=_NoValue): """Returns the value of Spark runtime configuration property for the given key, assuming it is set. """ self._checkType(key, "key") - if default is None: + if default is _NoValue: return self._jconf.get(key) else: - self._checkType(default, "default") + if default is not None: + self._checkType(default, "default") return self._jconf.get(key, default) @ignore_unicode_prefix diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 6cb90399dd616..e9ec7ba866761 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -22,7 +22,7 @@ if sys.version >= '3': basestring = unicode = str -from pyspark import since +from pyspark import since, _NoValue from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.session import _monkey_patch_RDD, SparkSession from pyspark.sql.dataframe import DataFrame @@ -124,11 +124,11 @@ def setConf(self, key, value): @ignore_unicode_prefix @since(1.3) - def getConf(self, key, defaultValue=None): + def getConf(self, key, defaultValue=_NoValue): """Returns the value of Spark SQL configuration property for the given key. - If the key is not set and defaultValue is not None, return - defaultValue. If the key is not set and defaultValue is None, return + If the key is not set and defaultValue is set, return + defaultValue. If the key is not set and defaultValue is not set, return the system default value. >>> sqlContext.getConf("spark.sql.shuffle.partitions") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 480815d27333f..a0d547ad620e5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2504,6 +2504,17 @@ def test_conf(self): spark.conf.unset("bogo") self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia") + self.assertEqual(spark.conf.get("hyukjin", None), None) + + # This returns 'STATIC' because it's the default value of + # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in + # `spark.conf.get` is unset. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC") + + # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but + # `defaultValue` in `spark.conf.get` is set to None. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None) + def test_current_database(self): spark = self.spark spark.catalog._reset() From 745c8c0901ac522ba92c1356ca74bd0dd7701496 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Mon, 19 Mar 2018 13:31:21 +0800 Subject: [PATCH 489/774] [SPARK-23708][CORE] Correct comment for function addShutDownHook in ShutdownHookManager ## What changes were proposed in this pull request? Minor modification.Comment below is not right. ``` /** * Adds a shutdown hook with the given priority. Hooks with lower priority values run * first. * * param hook The code to run during shutdown. * return A handle that can be used to unregister the shutdown hook. */ def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { shutdownHooks.add(priority, hook) } ``` ## How was this patch tested? UT Author: zhoukang Closes #20845 from caneGuy/zhoukang/fix-shutdowncomment. --- .../main/scala/org/apache/spark/util/ShutdownHookManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 4001fac3c3d5a..b702838fa257f 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -143,7 +143,7 @@ private[spark] object ShutdownHookManager extends Logging { } /** - * Adds a shutdown hook with the given priority. Hooks with lower priority values run + * Adds a shutdown hook with the given priority. Hooks with higher priority values run * first. * * @param hook The code to run during shutdown. From 4de638c1976dea74761bbe5c30da808178ee885d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 19 Mar 2018 09:41:43 +0100 Subject: [PATCH 490/774] [SPARK-23599][SQL] Add a UUID generator from Pseudo-Random Numbers ## What changes were proposed in this pull request? This patch adds a UUID generator from Pseudo-Random Numbers. We can use it later to have deterministic `UUID()` expression. ## How was this patch tested? Added unit tests. Author: Liang-Chi Hsieh Closes #20817 from viirya/SPARK-23599. --- .../catalyst/util/RandomUUIDGenerator.scala | 43 ++++++++++++++ .../util/RandomUUIDGeneratorSuite.scala | 57 +++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala new file mode 100644 index 0000000000000..4fe07a071c1ca --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.util.UUID + +import org.apache.commons.math3.random.MersenneTwister + +import org.apache.spark.unsafe.types.UTF8String + +/** + * This class is used to generate a UUID from Pseudo-Random Numbers. + * + * For the algorithm, see RFC 4122: A Universally Unique IDentifier (UUID) URN Namespace, + * section 4.4 "Algorithms for Creating a UUID from Truly Random or Pseudo-Random Numbers". + */ +case class RandomUUIDGenerator(randomSeed: Long) { + private val random = new MersenneTwister(randomSeed) + + def getNextUUID(): UUID = { + val mostSigBits = (random.nextLong() & 0xFFFFFFFFFFFF0FFFL) | 0x0000000000004000L + val leastSigBits = (random.nextLong() | 0x8000000000000000L) & 0xBFFFFFFFFFFFFFFFL + + new UUID(mostSigBits, leastSigBits) + } + + def getNextUUIDUTF8String(): UTF8String = UTF8String.fromString(getNextUUID().toString()) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala new file mode 100644 index 0000000000000..b75739e5a3a65 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.util.Random + +import org.apache.spark.SparkFunSuite + +class RandomUUIDGeneratorSuite extends SparkFunSuite { + test("RandomUUIDGenerator should generate version 4, variant 2 UUIDs") { + val generator = RandomUUIDGenerator(new Random().nextLong()) + for (_ <- 0 to 100) { + val uuid = generator.getNextUUID() + assert(uuid.version() == 4) + assert(uuid.variant() == 2) + } + } + + test("UUID from RandomUUIDGenerator should be deterministic") { + val r1 = new Random(100) + val generator1 = RandomUUIDGenerator(r1.nextLong()) + val r2 = new Random(100) + val generator2 = RandomUUIDGenerator(r2.nextLong()) + val r3 = new Random(101) + val generator3 = RandomUUIDGenerator(r3.nextLong()) + + for (_ <- 0 to 100) { + val uuid1 = generator1.getNextUUID() + val uuid2 = generator2.getNextUUID() + val uuid3 = generator3.getNextUUID() + assert(uuid1 == uuid2) + assert(uuid1 != uuid3) + } + } + + test("Get UTF8String UUID") { + val generator = RandomUUIDGenerator(new Random().nextLong()) + val utf8StringUUID = generator.getNextUUIDUTF8String() + val uuid = java.util.UUID.fromString(utf8StringUUID.toString) + assert(uuid.version() == 4 && uuid.variant() == 2 && utf8StringUUID.toString == uuid.toString) + } +} From f15906da153f139b698e192ec6f82f078f896f1e Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Mon, 19 Mar 2018 11:29:56 -0700 Subject: [PATCH 491/774] [SPARK-22839][K8S] Remove the use of init-container for downloading remote dependencies ## What changes were proposed in this pull request? Removal of the init-container for downloading remote dependencies. Built off of the work done by vanzin in an attempt to refactor driver/executor configuration elaborated in [this](https://issues.apache.org/jira/browse/SPARK-22839) ticket. ## How was this patch tested? This patch was tested with unit and integration tests. Author: Ilan Filonenko Closes #20669 from ifilonenko/remove-init-container. --- bin/docker-image-tool.sh | 9 +- .../org/apache/spark/deploy/SparkSubmit.scala | 2 - docs/running-on-kubernetes.md | 71 +------- .../spark/examples/SparkRemoteFileTest.scala | 48 ++++++ .../org/apache/spark/deploy/k8s/Config.scala | 73 +------- .../apache/spark/deploy/k8s/Constants.scala | 21 +-- .../deploy/k8s/InitContainerBootstrap.scala | 120 ------------- .../spark/deploy/k8s/KubernetesUtils.scala | 63 +------ .../k8s/PodWithDetachedInitContainer.scala | 31 ---- .../deploy/k8s/SparkPodInitContainer.scala | 116 ------------- .../k8s/submit/DriverConfigOrchestrator.scala | 45 +---- .../submit/KubernetesClientApplication.scala | 84 +++++---- .../steps/BasicDriverConfigurationStep.scala | 32 ++-- .../steps/DependencyResolutionStep.scala | 18 +- .../DriverInitContainerBootstrapStep.scala | 95 ----------- .../DriverKubernetesCredentialsStep.scala | 2 +- .../BasicInitContainerConfigurationStep.scala | 67 -------- .../InitContainerConfigOrchestrator.scala | 79 --------- .../InitContainerConfigurationStep.scala | 25 --- .../InitContainerMountSecretsStep.scala | 36 ---- .../initcontainer/InitContainerSpec.scala | 37 ---- .../cluster/k8s/ExecutorPodFactory.scala | 43 +---- .../k8s/KubernetesClusterManager.scala | 65 +------ .../k8s/SparkPodInitContainerSuite.scala | 86 ---------- .../spark/deploy/k8s/submit/ClientSuite.scala | 82 ++++----- .../DriverConfigOrchestratorSuite.scala | 41 +---- .../BasicDriverConfigurationStepSuite.scala | 8 +- .../steps/DependencyResolutionStepSuite.scala | 32 ++-- ...riverInitContainerBootstrapStepSuite.scala | 160 ------------------ ...cInitContainerConfigurationStepSuite.scala | 95 ----------- ...InitContainerConfigOrchestratorSuite.scala | 80 --------- .../InitContainerMountSecretsStepSuite.scala | 52 ------ .../cluster/k8s/ExecutorPodFactorySuite.scala | 67 +------- .../src/main/dockerfiles/spark/Dockerfile | 1 - .../src/main/dockerfiles/spark/entrypoint.sh | 20 +-- 35 files changed, 241 insertions(+), 1665 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index 0d0f564bb8b9b..f090240065bf1 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -64,9 +64,11 @@ function build { error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." fi + local DOCKERFILE=${DOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} + docker build "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ - -f "$IMG_PATH/spark/Dockerfile" . + -f "$DOCKERFILE" . } function push { @@ -84,6 +86,7 @@ Commands: push Push a pre-built image to a registry. Requires a repository address to be provided. Options: + -f file Dockerfile to build. By default builds the Dockerfile shipped with Spark. -r repo Repository address. -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. @@ -113,10 +116,12 @@ fi REPO= TAG= -while getopts mr:t: option +DOCKERFILE= +while getopts f:mr:t: option do case "${option}" in + f) DOCKERFILE=${OPTARG};; r) REPO=${OPTARG};; t) TAG=${OPTARG};; m) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 1e381965c52ba..329bde08718fe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -320,8 +320,6 @@ object SparkSubmit extends CommandLineUtils with Logging { printErrorAndExit("Python applications are currently not supported for Kubernetes.") case (KUBERNETES, _) if args.isR => printErrorAndExit("R applications are currently not supported for Kubernetes.") - case (KUBERNETES, CLIENT) => - printErrorAndExit("Client mode is currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => printErrorAndExit("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 3c7586e8544ba..975b28de47e20 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -126,29 +126,6 @@ Those dependencies can be added to the classpath by referencing them with `local dependencies in custom-built Docker images in `spark-submit`. Note that using application dependencies from the submission client's local file system is currently not yet supported. - -### Using Remote Dependencies -When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods -need a Kubernetes [init-container](https://kubernetes.io/docs/concepts/workloads/pods/init-containers/) for downloading -the dependencies so the driver and executor containers can use them locally. - -The init-container handles remote dependencies specified in `spark.jars` (or the `--jars` option of `spark-submit`) and -`spark.files` (or the `--files` option of `spark-submit`). It also handles remotely hosted main application resources, e.g., -the main application jar. The following shows an example of using remote dependencies with the `spark-submit` command: - -```bash -$ bin/spark-submit \ - --master k8s://https://: \ - --deploy-mode cluster \ - --name spark-pi \ - --class org.apache.spark.examples.SparkPi \ - --jars https://path/to/dependency1.jar,https://path/to/dependency2.jar - --files hdfs://host:port/path/to/file1,hdfs://host:port/path/to/file2 - --conf spark.executor.instances=5 \ - --conf spark.kubernetes.container.image= \ - https://path/to/examples.jar -``` - ## Secret Management Kubernetes [Secrets](https://kubernetes.io/docs/concepts/configuration/secret/) can be used to provide credentials for a Spark application to access secured services. To mount a user-specified secret into the driver container, users can use @@ -163,10 +140,6 @@ namespace as that of the driver and executor pods. For example, to mount a secre --conf spark.kubernetes.executor.secrets.spark-secret=/etc/secrets ``` -Note that if an init-container is used, any secret mounted into the driver container will also be mounted into the -init-container of the driver. Similarly, any secret mounted into an executor container will also be mounted into the -init-container of the executor. - ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -604,51 +577,12 @@ specific to Spark on Kubernetes. the Driver process. The user can specify multiple of these to set multiple environment variables. - - spark.kubernetes.mountDependencies.jarsDownloadDir - /var/spark-data/spark-jars - - Location to download jars to in the driver and executors. - This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. - - - - spark.kubernetes.mountDependencies.filesDownloadDir - /var/spark-data/spark-files - - Location to download jars to in the driver and executors. - This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. - - - - spark.kubernetes.mountDependencies.timeout - 300s - - Timeout in seconds before aborting the attempt to download and unpack dependencies from remote locations into - the driver and executor pods. - - - - spark.kubernetes.mountDependencies.maxSimultaneousDownloads - 5 - - Maximum number of remote dependencies to download simultaneously in a driver or executor pod. - - - - spark.kubernetes.initContainer.image - (value of spark.kubernetes.container.image) - - Custom container image for the init container of both driver and executors. - - spark.kubernetes.driver.secrets.[SecretName] (none) Add the Kubernetes Secret named SecretName to the driver pod on the path specified in the value. For example, - spark.kubernetes.driver.secrets.spark-secret=/etc/secrets. Note that if an init-container is used, - the secret will also be added to the init-container in the driver pod. + spark.kubernetes.driver.secrets.spark-secret=/etc/secrets. @@ -656,8 +590,7 @@ specific to Spark on Kubernetes. (none) Add the Kubernetes Secret named SecretName to the executor pod on the path specified in the value. For example, - spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. Note that if an init-container is used, - the secret will also be added to the init-container in the executor pod. + spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. \ No newline at end of file diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala new file mode 100644 index 0000000000000..64076f2deb706 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples + +import java.io.File + +import org.apache.spark.SparkFiles +import org.apache.spark.sql.SparkSession + +/** Usage: SparkRemoteFileTest [file] */ +object SparkRemoteFileTest { + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: SparkRemoteFileTest ") + System.exit(1) + } + val spark = SparkSession + .builder() + .appName("SparkRemoteFileTest") + .getOrCreate() + val sc = spark.sparkContext + val rdd = sc.parallelize(Seq(1)).map(_ => { + val localLocation = SparkFiles.get(args(0)) + println(s"${args(0)} is stored at: $localLocation") + new File(localLocation).isFile + }) + val truthCheck = rdd.collect().head + println(s"Mounting of ${args(0)} was $truthCheck") + spark.stop() + } +} +// scalastyle:on println diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 471196ac0e3f6..da34a7e06238a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -79,6 +79,12 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_DRIVER_SUBMIT_CHECK = + ConfigBuilder("spark.kubernetes.submitInDriver") + .internal() + .booleanConf + .createOptional + val KUBERNETES_EXECUTOR_LIMIT_CORES = ConfigBuilder("spark.kubernetes.executor.limit.cores") .doc("Specify the hard cpu limit for each executor pod") @@ -135,73 +141,6 @@ private[spark] object Config extends Logging { .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") .createWithDefaultString("1s") - val JARS_DOWNLOAD_LOCATION = - ConfigBuilder("spark.kubernetes.mountDependencies.jarsDownloadDir") - .doc("Location to download jars to in the driver and executors. When using " + - "spark-submit, this directory must be empty and will be mounted as an empty directory " + - "volume on the driver and executor pod.") - .stringConf - .createWithDefault("/var/spark-data/spark-jars") - - val FILES_DOWNLOAD_LOCATION = - ConfigBuilder("spark.kubernetes.mountDependencies.filesDownloadDir") - .doc("Location to download files to in the driver and executors. When using " + - "spark-submit, this directory must be empty and will be mounted as an empty directory " + - "volume on the driver and executor pods.") - .stringConf - .createWithDefault("/var/spark-data/spark-files") - - val INIT_CONTAINER_IMAGE = - ConfigBuilder("spark.kubernetes.initContainer.image") - .doc("Image for the driver and executor's init-container for downloading dependencies.") - .fallbackConf(CONTAINER_IMAGE) - - val INIT_CONTAINER_MOUNT_TIMEOUT = - ConfigBuilder("spark.kubernetes.mountDependencies.timeout") - .doc("Timeout before aborting the attempt to download and unpack dependencies from remote " + - "locations into the driver and executor pods.") - .timeConf(TimeUnit.SECONDS) - .createWithDefault(300) - - val INIT_CONTAINER_MAX_THREAD_POOL_SIZE = - ConfigBuilder("spark.kubernetes.mountDependencies.maxSimultaneousDownloads") - .doc("Maximum number of remote dependencies to download simultaneously in a driver or " + - "executor pod.") - .intConf - .createWithDefault(5) - - val INIT_CONTAINER_REMOTE_JARS = - ConfigBuilder("spark.kubernetes.initContainer.remoteJars") - .doc("Comma-separated list of jar URIs to download in the init-container. This is " + - "calculated from spark.jars.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_REMOTE_FILES = - ConfigBuilder("spark.kubernetes.initContainer.remoteFiles") - .doc("Comma-separated list of file URIs to download in the init-container. This is " + - "calculated from spark.files.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_CONFIG_MAP_NAME = - ConfigBuilder("spark.kubernetes.initContainer.configMapName") - .doc("Name of the config map to use in the init-container that retrieves submitted files " + - "for the executor.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_CONFIG_MAP_KEY_CONF = - ConfigBuilder("spark.kubernetes.initContainer.configMapKey") - .doc("Key for the entry in the init container config map for submitted files that " + - "corresponds to the properties for this init-container.") - .internal() - .stringConf - .createOptional - val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 9411956996843..8da5f24044aad 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -63,22 +63,13 @@ private[spark] object Constants { val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" val ENV_CLASSPATH = "SPARK_CLASSPATH" - val ENV_DRIVER_MAIN_CLASS = "SPARK_DRIVER_CLASS" - val ENV_DRIVER_ARGS = "SPARK_DRIVER_ARGS" - val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" - val ENV_DRIVER_MEMORY = "SPARK_DRIVER_MEMORY" - val ENV_MOUNTED_FILES_DIR = "SPARK_MOUNTED_FILES_DIR" - - // Bootstrapping dependencies with the init-container - val INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME = "download-jars-volume" - val INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME = "download-files-volume" - val INIT_CONTAINER_PROPERTIES_FILE_VOLUME = "spark-init-properties" - val INIT_CONTAINER_PROPERTIES_FILE_DIR = "/etc/spark-init" - val INIT_CONTAINER_PROPERTIES_FILE_NAME = "spark-init.properties" - val INIT_CONTAINER_PROPERTIES_FILE_PATH = - s"$INIT_CONTAINER_PROPERTIES_FILE_DIR/$INIT_CONTAINER_PROPERTIES_FILE_NAME" - val INIT_CONTAINER_SECRET_VOLUME_NAME = "spark-init-secret" + val ENV_SPARK_CONF_DIR = "SPARK_CONF_DIR" + // Spark app configs for containers + val SPARK_CONF_VOLUME = "spark-conf-volume" + val SPARK_CONF_DIR_INTERNAL = "/opt/spark/conf" + val SPARK_CONF_FILE_NAME = "spark.properties" + val SPARK_CONF_PATH = s"$SPARK_CONF_DIR_INTERNAL/$SPARK_CONF_FILE_NAME" // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala deleted file mode 100644 index f6a57dfe00171..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EmptyDirVolumeSource, EnvVarBuilder, PodBuilder, VolumeMount, VolumeMountBuilder} - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -/** - * Bootstraps an init-container for downloading remote dependencies. This is separated out from - * the init-container steps API because this component can be used to bootstrap init-containers - * for both the driver and executors. - */ -private[spark] class InitContainerBootstrap( - initContainerImage: String, - imagePullPolicy: String, - jarsDownloadPath: String, - filesDownloadPath: String, - configMapName: String, - configMapKey: String, - sparkRole: String, - sparkConf: SparkConf) { - - /** - * Bootstraps an init-container that downloads dependencies to be used by a main container. - */ - def bootstrapInitContainer( - original: PodWithDetachedInitContainer): PodWithDetachedInitContainer = { - val sharedVolumeMounts = Seq[VolumeMount]( - new VolumeMountBuilder() - .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) - .withMountPath(jarsDownloadPath) - .build(), - new VolumeMountBuilder() - .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) - .withMountPath(filesDownloadPath) - .build()) - - val customEnvVarKeyPrefix = sparkRole match { - case SPARK_POD_DRIVER_ROLE => KUBERNETES_DRIVER_ENV_KEY - case SPARK_POD_EXECUTOR_ROLE => "spark.executorEnv." - case _ => throw new SparkException(s"$sparkRole is not a valid Spark pod role") - } - val customEnvVars = sparkConf.getAllWithPrefix(customEnvVarKeyPrefix).toSeq.map { - case (key, value) => - new EnvVarBuilder() - .withName(key) - .withValue(value) - .build() - } - - val initContainer = new ContainerBuilder(original.initContainer) - .withName("spark-init") - .withImage(initContainerImage) - .withImagePullPolicy(imagePullPolicy) - .addAllToEnv(customEnvVars.asJava) - .addNewVolumeMount() - .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) - .withMountPath(INIT_CONTAINER_PROPERTIES_FILE_DIR) - .endVolumeMount() - .addToVolumeMounts(sharedVolumeMounts: _*) - .addToArgs("init") - .addToArgs(INIT_CONTAINER_PROPERTIES_FILE_PATH) - .build() - - val podWithBasicVolumes = new PodBuilder(original.pod) - .editSpec() - .addNewVolume() - .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) - .withNewConfigMap() - .withName(configMapName) - .addNewItem() - .withKey(configMapKey) - .withPath(INIT_CONTAINER_PROPERTIES_FILE_NAME) - .endItem() - .endConfigMap() - .endVolume() - .addNewVolume() - .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) - .withEmptyDir(new EmptyDirVolumeSource()) - .endVolume() - .addNewVolume() - .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) - .withEmptyDir(new EmptyDirVolumeSource()) - .endVolume() - .endSpec() - .build() - - val mainContainer = new ContainerBuilder(original.mainContainer) - .addToVolumeMounts(sharedVolumeMounts: _*) - .addNewEnv() - .withName(ENV_MOUNTED_FILES_DIR) - .withValue(filesDownloadPath) - .endEnv() - .build() - - PodWithDetachedInitContainer( - podWithBasicVolumes, - initContainer, - mainContainer) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 37331d8bbf9b7..5bc070147d3a8 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,10 +16,6 @@ */ package org.apache.spark.deploy.k8s -import java.io.File - -import io.fabric8.kubernetes.api.model.{Container, Pod, PodBuilder} - import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -43,72 +39,23 @@ private[spark] object KubernetesUtils { opt1.foreach { _ => require(opt2.isEmpty, errMessage) } } - /** - * Append the given init-container to a pod's list of init-containers. - * - * @param originalPodSpec original specification of the pod - * @param initContainer the init-container to add to the pod - * @return the pod with the init-container added to the list of InitContainers - */ - def appendInitContainer(originalPodSpec: Pod, initContainer: Container): Pod = { - new PodBuilder(originalPodSpec) - .editOrNewSpec() - .addToInitContainers(initContainer) - .endSpec() - .build() - } - /** * For the given collection of file URIs, resolves them as follows: - * - File URIs with scheme file:// are resolved to the given download path. * - File URIs with scheme local:// resolve to just the path of the URI. * - Otherwise, the URIs are returned as-is. */ - def resolveFileUris( - fileUris: Iterable[String], - fileDownloadPath: String): Iterable[String] = { - fileUris.map { uri => - resolveFileUri(uri, fileDownloadPath, false) - } - } - - /** - * If any file uri has any scheme other than local:// it is mapped as if the file - * was downloaded to the file download path. Otherwise, it is mapped to the path - * part of the URI. - */ - def resolveFilePaths(fileUris: Iterable[String], fileDownloadPath: String): Iterable[String] = { + def resolveFileUrisAndPath(fileUris: Iterable[String]): Iterable[String] = { fileUris.map { uri => - resolveFileUri(uri, fileDownloadPath, true) - } - } - - /** - * Get from a given collection of file URIs the ones that represent remote files. - */ - def getOnlyRemoteFiles(uris: Iterable[String]): Iterable[String] = { - uris.filter { uri => - val scheme = Utils.resolveURI(uri).getScheme - scheme != "file" && scheme != "local" + resolveFileUri(uri) } } - private def resolveFileUri( - uri: String, - fileDownloadPath: String, - assumesDownloaded: Boolean): String = { + private def resolveFileUri(uri: String): String = { val fileUri = Utils.resolveURI(uri) val fileScheme = Option(fileUri.getScheme).getOrElse("file") fileScheme match { - case "local" => - fileUri.getPath - case _ => - if (assumesDownloaded || fileScheme == "file") { - val fileName = new File(fileUri.getPath).getName - s"$fileDownloadPath/$fileName" - } else { - uri - } + case "local" => fileUri.getPath + case _ => uri } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala deleted file mode 100644 index 0b79f8b12e806..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import io.fabric8.kubernetes.api.model.{Container, Pod} - -/** - * Represents a pod with a detached init-container (not yet added to the pod). - * - * @param pod the pod - * @param initContainer the init-container in the pod - * @param mainContainer the main container in the pod - */ -private[spark] case class PodWithDetachedInitContainer( - pod: Pod, - initContainer: Container, - mainContainer: Container) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala deleted file mode 100644 index c0f08786b76a1..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import java.io.File -import java.util.concurrent.TimeUnit - -import scala.concurrent.{ExecutionContext, Future} - -import org.apache.spark.{SecurityManager => SparkSecurityManager, SparkConf} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.internal.Logging -import org.apache.spark.util.{ThreadUtils, Utils} - -/** - * Process that fetches files from a resource staging server and/or arbitrary remote locations. - * - * The init-container can handle fetching files from any of those sources, but not all of the - * sources need to be specified. This allows for composing multiple instances of this container - * with different configurations for different download sources, or using the same container to - * download everything at once. - */ -private[spark] class SparkPodInitContainer( - sparkConf: SparkConf, - fileFetcher: FileFetcher) extends Logging { - - private val maxThreadPoolSize = sparkConf.get(INIT_CONTAINER_MAX_THREAD_POOL_SIZE) - private implicit val downloadExecutor = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("download-executor", maxThreadPoolSize)) - - private val jarsDownloadDir = new File(sparkConf.get(JARS_DOWNLOAD_LOCATION)) - private val filesDownloadDir = new File(sparkConf.get(FILES_DOWNLOAD_LOCATION)) - - private val remoteJars = sparkConf.get(INIT_CONTAINER_REMOTE_JARS) - private val remoteFiles = sparkConf.get(INIT_CONTAINER_REMOTE_FILES) - - private val downloadTimeoutMinutes = sparkConf.get(INIT_CONTAINER_MOUNT_TIMEOUT) - - def run(): Unit = { - logInfo(s"Downloading remote jars: $remoteJars") - downloadFiles( - remoteJars, - jarsDownloadDir, - s"Remote jars download directory specified at $jarsDownloadDir does not exist " + - "or is not a directory.") - - logInfo(s"Downloading remote files: $remoteFiles") - downloadFiles( - remoteFiles, - filesDownloadDir, - s"Remote files download directory specified at $filesDownloadDir does not exist " + - "or is not a directory.") - - downloadExecutor.shutdown() - downloadExecutor.awaitTermination(downloadTimeoutMinutes, TimeUnit.MINUTES) - } - - private def downloadFiles( - filesCommaSeparated: Option[String], - downloadDir: File, - errMessage: String): Unit = { - filesCommaSeparated.foreach { files => - require(downloadDir.isDirectory, errMessage) - Utils.stringToSeq(files).foreach { file => - Future[Unit] { - fileFetcher.fetchFile(file, downloadDir) - } - } - } - } -} - -private class FileFetcher(sparkConf: SparkConf, securityManager: SparkSecurityManager) { - - def fetchFile(uri: String, targetDir: File): Unit = { - Utils.fetchFile( - url = uri, - targetDir = targetDir, - conf = sparkConf, - securityMgr = securityManager, - hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf), - timestamp = System.currentTimeMillis(), - useCache = false) - } -} - -object SparkPodInitContainer extends Logging { - - def main(args: Array[String]): Unit = { - logInfo("Starting init-container to download Spark application dependencies.") - val sparkConf = new SparkConf(true) - if (args.nonEmpty) { - Utils.loadDefaultSparkProperties(sparkConf, args(0)) - } - - val securityManager = new SparkSecurityManager(sparkConf) - val fileFetcher = new FileFetcher(sparkConf, securityManager) - new SparkPodInitContainer(sparkConf, fileFetcher).run() - logInfo("Finished downloading application dependencies.") - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index ae70904621184..b4d3f04a1bc32 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -16,16 +16,11 @@ */ package org.apache.spark.deploy.k8s.submit -import java.util.UUID - -import com.google.common.primitives.Longs - import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.steps._ -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.InitContainerConfigOrchestrator import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util.SystemClock import org.apache.spark.util.Utils @@ -34,13 +29,11 @@ import org.apache.spark.util.Utils * Figures out and returns the complete ordered list of needed DriverConfigurationSteps to * configure the Spark driver pod. The returned steps will be applied one by one in the given * order to produce a final KubernetesDriverSpec that is used in KubernetesClientApplication - * to construct and create the driver pod. It uses the InitContainerConfigOrchestrator to - * configure the driver init-container if one is needed, i.e., when there are remote dependencies - * to localize. + * to construct and create the driver pod. */ private[spark] class DriverConfigOrchestrator( kubernetesAppId: String, - launchTime: Long, + kubernetesResourceNamePrefix: String, mainAppResource: Option[MainAppResource], appName: String, mainClass: String, @@ -50,15 +43,8 @@ private[spark] class DriverConfigOrchestrator( // The resource name prefix is derived from the Spark application name, making it easy to connect // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the // application the user submitted. - private val kubernetesResourceNamePrefix = { - val uuid = UUID.nameUUIDFromBytes(Longs.toByteArray(launchTime)).toString.replaceAll("-", "") - s"$appName-$uuid".toLowerCase.replaceAll("\\.", "-") - } private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - private val initContainerConfigMapName = s"$kubernetesResourceNamePrefix-init-config" - private val jarsDownloadPath = sparkConf.get(JARS_DOWNLOAD_LOCATION) - private val filesDownloadPath = sparkConf.get(FILES_DOWNLOAD_LOCATION) def getAllConfigurationSteps: Seq[DriverConfigurationStep] = { val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( @@ -126,9 +112,7 @@ private[spark] class DriverConfigOrchestrator( val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { Seq(new DependencyResolutionStep( sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath)) + sparkFiles)) } else { Nil } @@ -139,33 +123,12 @@ private[spark] class DriverConfigOrchestrator( Nil } - val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) { - val orchestrator = new InitContainerConfigOrchestrator( - sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath, - imagePullPolicy, - initContainerConfigMapName, - INIT_CONTAINER_PROPERTIES_FILE_NAME, - sparkConf) - val bootstrapStep = new DriverInitContainerBootstrapStep( - orchestrator.getAllConfigurationSteps, - initContainerConfigMapName, - INIT_CONTAINER_PROPERTIES_FILE_NAME) - - Seq(bootstrapStep) - } else { - Nil - } - Seq( initialSubmissionStep, serviceBootstrapStep, kubernetesCredentialsStep) ++ dependencyResolutionStep ++ - mountSecretsStep ++ - initContainerBootstrapStep + mountSecretsStep } private def existSubmissionLocalFiles(files: Seq[String]): Boolean = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 5884348cb3e41..e16d1add600b2 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -16,14 +16,14 @@ */ package org.apache.spark.deploy.k8s.submit +import java.io.StringWriter import java.util.{Collections, UUID} - -import scala.collection.JavaConverters._ -import scala.collection.mutable -import scala.util.control.NonFatal +import java.util.Properties import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkApplication @@ -32,6 +32,7 @@ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.util.Utils /** @@ -93,10 +94,8 @@ private[spark] class Client( kubernetesClient: KubernetesClient, waitForAppCompletion: Boolean, appName: String, - watcher: LoggingPodStatusWatcher) extends Logging { - - private val driverJavaOptions = sparkConf.get( - org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) + watcher: LoggingPodStatusWatcher, + kubernetesResourceNamePrefix: String) extends Logging { /** * Run command that initializes a DriverSpec that will be updated after each @@ -110,33 +109,31 @@ private[spark] class Client( for (nextStep <- submissionSteps) { currentDriverSpec = nextStep.configureDriver(currentDriverSpec) } - - val resolvedDriverJavaOpts = currentDriverSpec - .driverSparkConf - // Remove this as the options are instead extracted and set individually below using - // environment variables with prefix SPARK_JAVA_OPT_. - .remove(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) - .getAll - .map { - case (confKey, confValue) => s"-D$confKey=$confValue" - } ++ driverJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) - val driverJavaOptsEnvs: Seq[EnvVar] = resolvedDriverJavaOpts.zipWithIndex.map { - case (option, index) => - new EnvVarBuilder() - .withName(s"$ENV_JAVA_OPT_PREFIX$index") - .withValue(option) - .build() - } - + val configMapName = s"$kubernetesResourceNamePrefix-driver-conf-map" + val configMap = buildConfigMap(configMapName, currentDriverSpec.driverSparkConf) + // The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the + // Spark command builder to pickup on the Java Options present in the ConfigMap val resolvedDriverContainer = new ContainerBuilder(currentDriverSpec.driverContainer) - .addAllToEnv(driverJavaOptsEnvs.asJava) + .addNewEnv() + .withName(ENV_SPARK_CONF_DIR) + .withValue(SPARK_CONF_DIR_INTERNAL) + .endEnv() + .addNewVolumeMount() + .withName(SPARK_CONF_VOLUME) + .withMountPath(SPARK_CONF_DIR_INTERNAL) + .endVolumeMount() .build() val resolvedDriverPod = new PodBuilder(currentDriverSpec.driverPod) .editSpec() .addToContainers(resolvedDriverContainer) + .addNewVolume() + .withName(SPARK_CONF_VOLUME) + .withNewConfigMap() + .withName(configMapName) + .endConfigMap() + .endVolume() .endSpec() .build() - Utils.tryWithResource( kubernetesClient .pods() @@ -145,7 +142,8 @@ private[spark] class Client( val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod) try { if (currentDriverSpec.otherKubernetesResources.nonEmpty) { - val otherKubernetesResources = currentDriverSpec.otherKubernetesResources + val otherKubernetesResources = + currentDriverSpec.otherKubernetesResources ++ Seq(configMap) addDriverOwnerReference(createdDriverPod, otherKubernetesResources) kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() } @@ -180,6 +178,26 @@ private[spark] class Client( originalMetadata.setOwnerReferences(Collections.singletonList(driverPodOwnerReference)) } } + + // Build a Config Map that will house spark conf properties in a single file for spark-submit + private def buildConfigMap(configMapName: String, conf: SparkConf): ConfigMap = { + val properties = new Properties() + conf.getAll.foreach { case (k, v) => + properties.setProperty(k, v) + } + val propertiesWriter = new StringWriter() + properties.store(propertiesWriter, + s"Java properties built from Kubernetes config map with name: $configMapName") + + val namespace = conf.get(KUBERNETES_NAMESPACE) + new ConfigMapBuilder() + .withNewMetadata() + .withName(configMapName) + .withNamespace(namespace) + .endMetadata() + .addToData(SPARK_CONF_FILE_NAME, propertiesWriter.toString) + .build() + } } /** @@ -202,6 +220,9 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val launchTime = System.currentTimeMillis() val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") + val kubernetesResourceNamePrefix = { + s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") + } // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. val master = sparkConf.get("spark.master").substring("k8s://".length) @@ -211,7 +232,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val orchestrator = new DriverConfigOrchestrator( kubernetesAppId, - launchTime, + kubernetesResourceNamePrefix, clientArguments.mainAppResource, appName, clientArguments.mainClass, @@ -231,7 +252,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { kubernetesClient, waitForAppCompletion, appName, - watcher) + watcher, + kubernetesResourceNamePrefix) client.run() } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index 164e2e5594778..347c4d2d66826 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -26,6 +26,7 @@ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.KubernetesUtils import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD} +import org.apache.spark.launcher.SparkLauncher /** * Performs basic configuration for the driver pod. @@ -56,8 +57,6 @@ private[spark] class BasicDriverConfigurationStep( // Memory settings private val driverMemoryMiB = sparkConf.get(DRIVER_MEMORY) - private val driverMemoryString = sparkConf.get( - DRIVER_MEMORY.key, DRIVER_MEMORY.defaultValueString) private val memoryOverheadMiB = sparkConf .get(DRIVER_MEMORY_OVERHEAD) .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) @@ -103,24 +102,12 @@ private[spark] class BasicDriverConfigurationStep( ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) } - val driverContainer = new ContainerBuilder(driverSpec.driverContainer) + val driverContainerWithoutArgs = new ContainerBuilder(driverSpec.driverContainer) .withName(DRIVER_CONTAINER_NAME) .withImage(driverContainerImage) .withImagePullPolicy(imagePullPolicy) .addAllToEnv(driverCustomEnvs.asJava) .addToEnv(driverExtraClasspathEnv.toSeq: _*) - .addNewEnv() - .withName(ENV_DRIVER_MEMORY) - .withValue(driverMemoryString) - .endEnv() - .addNewEnv() - .withName(ENV_DRIVER_MAIN_CLASS) - .withValue(mainClass) - .endEnv() - .addNewEnv() - .withName(ENV_DRIVER_ARGS) - .withValue(appArgs.mkString(" ")) - .endEnv() .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) .withValueFrom(new EnvVarSourceBuilder() @@ -134,7 +121,16 @@ private[spark] class BasicDriverConfigurationStep( .addToLimits(maybeCpuLimitQuantity.toMap.asJava) .endResources() .addToArgs("driver") - .build() + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", mainClass) + // The user application jar is merged into the spark.jars list and managed through that + // property, so there is no need to reference it explicitly here. + .addToArgs(SparkLauncher.NO_RESOURCE) + + val driverContainer = appArgs.toList match { + case "" :: Nil | Nil => driverContainerWithoutArgs.build() + case _ => driverContainerWithoutArgs.addToArgs(appArgs: _*).build() + } val baseDriverPod = new PodBuilder(driverSpec.driverPod) .editOrNewMetadata() @@ -152,10 +148,14 @@ private[spark] class BasicDriverConfigurationStep( .setIfMissing(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set("spark.app.id", kubernetesAppId) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, resourceNamePrefix) + // to set the config variables to allow client-mode spark-submit from driver + .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) driverSpec.copy( driverPod = baseDriverPod, driverSparkConf = resolvedSparkConf, driverContainer = driverContainer) } + } + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala index d4b83235b4e3b..43de329f239ad 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala @@ -30,13 +30,11 @@ import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec */ private[spark] class DependencyResolutionStep( sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String) extends DriverConfigurationStep { + sparkFiles: Seq[String]) extends DriverConfigurationStep { override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val resolvedSparkJars = KubernetesUtils.resolveFileUris(sparkJars, jarsDownloadPath) - val resolvedSparkFiles = KubernetesUtils.resolveFileUris(sparkFiles, filesDownloadPath) + val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath(sparkJars) + val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath(sparkFiles) val sparkConf = driverSpec.driverSparkConf.clone() if (resolvedSparkJars.nonEmpty) { @@ -45,14 +43,12 @@ private[spark] class DependencyResolutionStep( if (resolvedSparkFiles.nonEmpty) { sparkConf.set("spark.files", resolvedSparkFiles.mkString(",")) } - - val resolvedClasspath = KubernetesUtils.resolveFilePaths(sparkJars, jarsDownloadPath) - val resolvedDriverContainer = if (resolvedClasspath.nonEmpty) { + val resolvedDriverContainer = if (resolvedSparkJars.nonEmpty) { new ContainerBuilder(driverSpec.driverContainer) .addNewEnv() - .withName(ENV_MOUNTED_CLASSPATH) - .withValue(resolvedClasspath.mkString(File.pathSeparator)) - .endEnv() + .withName(ENV_MOUNTED_CLASSPATH) + .withValue(resolvedSparkJars.mkString(File.pathSeparator)) + .endEnv() .build() } else { driverSpec.driverContainer diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala deleted file mode 100644 index 9fb3dafdda540..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.StringWriter -import java.util.Properties - -import io.fabric8.kubernetes.api.model.{ConfigMap, ConfigMapBuilder, ContainerBuilder, HasMetadata} - -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} - -/** - * Configures the driver init-container that localizes remote dependencies into the driver pod. - * It applies the given InitContainerConfigurationSteps in the given order to produce a final - * InitContainerSpec that is then used to configure the driver pod with the init-container attached. - * It also builds a ConfigMap that will be mounted into the init-container. The ConfigMap carries - * configuration properties for the init-container. - */ -private[spark] class DriverInitContainerBootstrapStep( - steps: Seq[InitContainerConfigurationStep], - configMapName: String, - configMapKey: String) - extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - var initContainerSpec = InitContainerSpec( - properties = Map.empty[String, String], - driverSparkConf = Map.empty[String, String], - initContainer = new ContainerBuilder().build(), - driverContainer = driverSpec.driverContainer, - driverPod = driverSpec.driverPod, - dependentResources = Seq.empty[HasMetadata]) - for (nextStep <- steps) { - initContainerSpec = nextStep.configureInitContainer(initContainerSpec) - } - - val configMap = buildConfigMap( - configMapName, - configMapKey, - initContainerSpec.properties) - val resolvedDriverSparkConf = driverSpec.driverSparkConf - .clone() - .set(INIT_CONTAINER_CONFIG_MAP_NAME, configMapName) - .set(INIT_CONTAINER_CONFIG_MAP_KEY_CONF, configMapKey) - .setAll(initContainerSpec.driverSparkConf) - val resolvedDriverPod = KubernetesUtils.appendInitContainer( - initContainerSpec.driverPod, initContainerSpec.initContainer) - - driverSpec.copy( - driverPod = resolvedDriverPod, - driverContainer = initContainerSpec.driverContainer, - driverSparkConf = resolvedDriverSparkConf, - otherKubernetesResources = - driverSpec.otherKubernetesResources ++ - initContainerSpec.dependentResources ++ - Seq(configMap)) - } - - private def buildConfigMap( - configMapName: String, - configMapKey: String, - config: Map[String, String]): ConfigMap = { - val properties = new Properties() - config.foreach { entry => - properties.setProperty(entry._1, entry._2) - } - val propertiesWriter = new StringWriter() - properties.store(propertiesWriter, - s"Java properties built from Kubernetes config map with name: $configMapName " + - s"and config map key: $configMapKey") - new ConfigMapBuilder() - .withNewMetadata() - .withName(configMapName) - .endMetadata() - .addToData(configMapKey, propertiesWriter.toString) - .build() - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala index ccc18908658f1..2424e63999a82 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala @@ -99,7 +99,7 @@ private[spark] class DriverKubernetesCredentialsStep( }.getOrElse(driverSpec.driverPod) ) - val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { secret => + val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { _ => new ContainerBuilder(driverSpec.driverContainer) .addNewVolumeMount() .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala deleted file mode 100644 index 01469853dacc2..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.KubernetesUtils - -/** - * Performs basic configuration for the driver init-container with most of the work delegated to - * the given InitContainerBootstrap. - */ -private[spark] class BasicInitContainerConfigurationStep( - sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String, - bootstrap: InitContainerBootstrap) - extends InitContainerConfigurationStep { - - override def configureInitContainer(spec: InitContainerSpec): InitContainerSpec = { - val remoteJarsToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkJars) - val remoteFilesToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkFiles) - val remoteJarsConf = if (remoteJarsToDownload.nonEmpty) { - Map(INIT_CONTAINER_REMOTE_JARS.key -> remoteJarsToDownload.mkString(",")) - } else { - Map() - } - val remoteFilesConf = if (remoteFilesToDownload.nonEmpty) { - Map(INIT_CONTAINER_REMOTE_FILES.key -> remoteFilesToDownload.mkString(",")) - } else { - Map() - } - - val baseInitContainerConfig = Map( - JARS_DOWNLOAD_LOCATION.key -> jarsDownloadPath, - FILES_DOWNLOAD_LOCATION.key -> filesDownloadPath) ++ - remoteJarsConf ++ - remoteFilesConf - - val bootstrapped = bootstrap.bootstrapInitContainer( - PodWithDetachedInitContainer( - spec.driverPod, - spec.initContainer, - spec.driverContainer)) - - spec.copy( - initContainer = bootstrapped.initContainer, - driverContainer = bootstrapped.mainContainer, - driverPod = bootstrapped.pod, - properties = spec.properties ++ baseInitContainerConfig) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala deleted file mode 100644 index f2c29c7ce1076..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -/** - * Figures out and returns the complete ordered list of InitContainerConfigurationSteps required to - * configure the driver init-container. The returned steps will be applied in the given order to - * produce a final InitContainerSpec that is used to construct the driver init-container in - * DriverInitContainerBootstrapStep. This class is only used when an init-container is needed, i.e., - * when there are remote application dependencies to localize. - */ -private[spark] class InitContainerConfigOrchestrator( - sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String, - imagePullPolicy: String, - configMapName: String, - configMapKey: String, - sparkConf: SparkConf) { - - private val initContainerImage = sparkConf - .get(INIT_CONTAINER_IMAGE) - .getOrElse(throw new SparkException( - "Must specify the init-container image when there are remote dependencies")) - - def getAllConfigurationSteps: Seq[InitContainerConfigurationStep] = { - val initContainerBootstrap = new InitContainerBootstrap( - initContainerImage, - imagePullPolicy, - jarsDownloadPath, - filesDownloadPath, - configMapName, - configMapKey, - SPARK_POD_DRIVER_ROLE, - sparkConf) - val baseStep = new BasicInitContainerConfigurationStep( - sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath, - initContainerBootstrap) - - val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_SECRETS_PREFIX) - // Mount user-specified driver secrets also into the driver's init-container. The - // init-container may need credentials in the secrets to be able to download remote - // dependencies. The driver's main container and its init-container share the secrets - // because the init-container is sort of an implementation details and this sharing - // avoids introducing a dedicated configuration property just for the init-container. - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new InitContainerMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - - Seq(baseStep) ++ mountSecretsStep - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala deleted file mode 100644 index 0372ad5270951..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -/** - * Represents a step in configuring the driver init-container. - */ -private[spark] trait InitContainerConfigurationStep { - - def configureInitContainer(spec: InitContainerSpec): InitContainerSpec -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala deleted file mode 100644 index 0daa7b95e8aae..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.deploy.k8s.MountSecretsBootstrap - -/** - * An init-container configuration step for mounting user-specified secrets onto user-specified - * paths. - * - * @param bootstrap a utility actually handling mounting of the secrets - */ -private[spark] class InitContainerMountSecretsStep( - bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep { - - override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = { - // Mount the secret volumes given that the volumes have already been added to the driver pod - // when mounting the secrets into the main driver container. - val initContainer = bootstrap.mountSecrets(spec.initContainer) - spec.copy(initContainer = initContainer) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala deleted file mode 100644 index b52c343f0c0ed..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import io.fabric8.kubernetes.api.model.{Container, HasMetadata, Pod} - -/** - * Represents a specification of the init-container for the driver pod. - * - * @param properties properties that should be set on the init-container - * @param driverSparkConf Spark configuration properties that will be carried back to the driver - * @param initContainer the init-container object - * @param driverContainer the driver container object - * @param driverPod the driver pod object - * @param dependentResources resources the init-container depends on to work - */ -private[spark] case class InitContainerSpec( - properties: Map[String, String], - driverSparkConf: Map[String, String], - initContainer: Container, - driverContainer: Container, - driverPod: Pod, - dependentResources: Seq[HasMetadata]) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 141bd2827e7c5..98cbd5607da00 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} @@ -34,18 +34,10 @@ import org.apache.spark.util.Utils * @param sparkConf Spark configuration * @param mountSecretsBootstrap an optional component for mounting user-specified secrets onto * user-specified paths into the executor container - * @param initContainerBootstrap an optional component for bootstrapping the executor init-container - * if one is needed, i.e., when there are remote dependencies to - * localize - * @param initContainerMountSecretsBootstrap an optional component for mounting user-specified - * secrets onto user-specified paths into the executor - * init-container */ private[spark] class ExecutorPodFactory( sparkConf: SparkConf, - mountSecretsBootstrap: Option[MountSecretsBootstrap], - initContainerBootstrap: Option[InitContainerBootstrap], - initContainerMountSecretsBootstrap: Option[MountSecretsBootstrap]) { + mountSecretsBootstrap: Option[MountSecretsBootstrap]) { private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH) @@ -94,8 +86,6 @@ private[spark] class ExecutorPodFactory( private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) - private val executorJarsDownloadDir = sparkConf.get(JARS_DOWNLOAD_LOCATION) - /** * Configure and construct an executor pod with the given parameters. */ @@ -147,8 +137,9 @@ private[spark] class ExecutorPodFactory( (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), (ENV_EXECUTOR_MEMORY, executorMemoryString), (ENV_APPLICATION_ID, applicationId), - (ENV_EXECUTOR_ID, executorId), - (ENV_MOUNTED_CLASSPATH, s"$executorJarsDownloadDir/*")) ++ executorEnvs) + // This is to set the SPARK_CONF_DIR to be /opt/spark/conf + (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), + (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) .map(env => new EnvVarBuilder() .withName(env._1) .withValue(env._2) @@ -221,30 +212,10 @@ private[spark] class ExecutorPodFactory( (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) }.getOrElse((executorPod, containerWithLimitCores)) - val (bootstrappedPod, bootstrappedContainer) = - initContainerBootstrap.map { bootstrap => - val podWithInitContainer = bootstrap.bootstrapInitContainer( - PodWithDetachedInitContainer( - maybeSecretsMountedPod, - new ContainerBuilder().build(), - maybeSecretsMountedContainer)) - - val (pod, mayBeSecretsMountedInitContainer) = - initContainerMountSecretsBootstrap.map { bootstrap => - // Mount the secret volumes given that the volumes have already been added to the - // executor pod when mounting the secrets into the main executor container. - (podWithInitContainer.pod, bootstrap.mountSecrets(podWithInitContainer.initContainer)) - }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer)) - - val bootstrappedPod = KubernetesUtils.appendInitContainer( - pod, mayBeSecretsMountedInitContainer) - - (bootstrappedPod, podWithInitContainer.mainContainer) - }.getOrElse((maybeSecretsMountedPod, maybeSecretsMountedContainer)) - new PodBuilder(bootstrappedPod) + new PodBuilder(maybeSecretsMountedPod) .editSpec() - .addToContainers(bootstrappedContainer) + .addToContainers(maybeSecretsMountedContainer) .endSpec() .build() } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index a942db6ae02db..ff5f6801da2a3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -21,7 +21,7 @@ import java.io.File import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} +import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging @@ -33,7 +33,9 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s") override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { - if (masterURL.startsWith("k8s") && sc.deployMode == "client") { + if (masterURL.startsWith("k8s") && + sc.deployMode == "client" && + !sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK).getOrElse(false)) { throw new SparkException("Client mode is currently not supported for Kubernetes.") } @@ -44,74 +46,23 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val sparkConf = sc.getConf - val initContainerConfigMap = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_NAME) - val initContainerConfigMapKey = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_KEY_CONF) - - if (initContainerConfigMap.isEmpty) { - logWarning("The executor's init-container config map is not specified. Executors will " + - "therefore not attempt to fetch remote or submitted dependencies.") - } - - if (initContainerConfigMapKey.isEmpty) { - logWarning("The executor's init-container config map key is not specified. Executors will " + - "therefore not attempt to fetch remote or submitted dependencies.") - } - - // Only set up the bootstrap if they've provided both the config map key and the config map - // name. The config map might not be provided if init-containers aren't being used to - // bootstrap dependencies. - val initContainerBootstrap = for { - configMap <- initContainerConfigMap - configMapKey <- initContainerConfigMapKey - } yield { - val initContainerImage = sparkConf - .get(INIT_CONTAINER_IMAGE) - .getOrElse(throw new SparkException( - "Must specify the init-container image when there are remote dependencies")) - new InitContainerBootstrap( - initContainerImage, - sparkConf.get(CONTAINER_IMAGE_PULL_POLICY), - sparkConf.get(JARS_DOWNLOAD_LOCATION), - sparkConf.get(FILES_DOWNLOAD_LOCATION), - configMap, - configMapKey, - SPARK_POD_EXECUTOR_ROLE, - sparkConf) - } - val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) val mountSecretBootstrap = if (executorSecretNamesToMountPaths.nonEmpty) { Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) } else { None } - // Mount user-specified executor secrets also into the executor's init-container. The - // init-container may need credentials in the secrets to be able to download remote - // dependencies. The executor's main container and its init-container share the secrets - // because the init-container is sort of an implementation details and this sharing - // avoids introducing a dedicated configuration property just for the init-container. - val initContainerMountSecretsBootstrap = if (initContainerBootstrap.nonEmpty && - executorSecretNamesToMountPaths.nonEmpty) { - Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) - } else { - None - } val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, - Some(sparkConf.get(KUBERNETES_NAMESPACE)), + Some(sc.conf.get(KUBERNETES_NAMESPACE)), KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, - sparkConf, + sc.conf, Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val executorPodFactory = new ExecutorPodFactory( - sparkConf, - mountSecretBootstrap, - initContainerBootstrap, - initContainerMountSecretsBootstrap) + val executorPodFactory = new ExecutorPodFactory(sc.conf, mountSecretBootstrap) val allocatorExecutor = ThreadUtils .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala deleted file mode 100644 index e0f29ecd0fb53..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import java.io.File -import java.util.UUID - -import com.google.common.base.Charsets -import com.google.common.io.Files -import org.mockito.Mockito -import org.scalatest.BeforeAndAfter -import org.scalatest.mockito.MockitoSugar._ - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.util.Utils - -class SparkPodInitContainerSuite extends SparkFunSuite with BeforeAndAfter { - - private val DOWNLOAD_JARS_SECRET_LOCATION = createTempFile("txt") - private val DOWNLOAD_FILES_SECRET_LOCATION = createTempFile("txt") - - private var downloadJarsDir: File = _ - private var downloadFilesDir: File = _ - private var downloadJarsSecretValue: String = _ - private var downloadFilesSecretValue: String = _ - private var fileFetcher: FileFetcher = _ - - override def beforeAll(): Unit = { - downloadJarsSecretValue = Files.toString( - new File(DOWNLOAD_JARS_SECRET_LOCATION), Charsets.UTF_8) - downloadFilesSecretValue = Files.toString( - new File(DOWNLOAD_FILES_SECRET_LOCATION), Charsets.UTF_8) - } - - before { - downloadJarsDir = Utils.createTempDir() - downloadFilesDir = Utils.createTempDir() - fileFetcher = mock[FileFetcher] - } - - after { - downloadJarsDir.delete() - downloadFilesDir.delete() - } - - test("Downloads from remote server should invoke the file fetcher") { - val sparkConf = getSparkConfForRemoteFileDownloads - val initContainerUnderTest = new SparkPodInitContainer(sparkConf, fileFetcher) - initContainerUnderTest.run() - Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/jar1.jar", downloadJarsDir) - Mockito.verify(fileFetcher).fetchFile("hdfs://localhost:9000/jar2.jar", downloadJarsDir) - Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/file.txt", downloadFilesDir) - } - - private def getSparkConfForRemoteFileDownloads: SparkConf = { - new SparkConf(true) - .set(INIT_CONTAINER_REMOTE_JARS, - "http://localhost:9000/jar1.jar,hdfs://localhost:9000/jar2.jar") - .set(INIT_CONTAINER_REMOTE_FILES, - "http://localhost:9000/file.txt") - .set(JARS_DOWNLOAD_LOCATION, downloadJarsDir.getAbsolutePath) - .set(FILES_DOWNLOAD_LOCATION, downloadFilesDir.getAbsolutePath) - } - - private def createTempFile(extension: String): String = { - val dir = Utils.createTempDir() - val file = new File(dir, s"${UUID.randomUUID().toString}.$extension") - Files.write(UUID.randomUUID().toString, file, Charsets.UTF_8) - file.getAbsolutePath - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index bf4ec04893204..6a501592f42a3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -38,6 +38,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val DRIVER_POD_UID = "pod-id" private val DRIVER_POD_API_VERSION = "v1" private val DRIVER_POD_KIND = "pod" + private val KUBERNETES_RESOURCE_PREFIX = "resource-example" private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ HasMetadata, Boolean] @@ -61,6 +62,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val submissionSteps = Seq(FirstTestConfigurationStep, SecondTestConfigurationStep) private var createdPodArgumentCaptor: ArgumentCaptor[Pod] = _ private var createdResourcesArgumentCaptor: ArgumentCaptor[HasMetadata] = _ + private var createdContainerArgumentCaptor: ArgumentCaptor[Container] = _ before { MockitoAnnotations.initMocks(this) @@ -94,7 +96,8 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { kubernetesClient, false, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() val createdPod = createdPodArgumentCaptor.getValue assert(createdPod.getMetadata.getName === FirstTestConfigurationStep.podName) @@ -108,62 +111,52 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { SecondTestConfigurationStep.containerName) } - test("The client should create the secondary Kubernetes resources.") { + test("The client should create Kubernetes resources") { + val EXAMPLE_JAVA_OPTS = "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails" + val EXPECTED_JAVA_OPTS = "-XX\\:+HeapDumpOnOutOfMemoryError -XX\\:+PrintGCDetails" val submissionClient = new Client( submissionSteps, - new SparkConf(false), + new SparkConf(false) + .set(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, EXAMPLE_JAVA_OPTS), kubernetesClient, false, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() val createdPod = createdPodArgumentCaptor.getValue val otherCreatedResources = createdResourcesArgumentCaptor.getAllValues - assert(otherCreatedResources.size === 1) - val createdResource = Iterables.getOnlyElement(otherCreatedResources).asInstanceOf[Secret] - assert(createdResource.getMetadata.getName === FirstTestConfigurationStep.secretName) - assert(createdResource.getData.asScala === + assert(otherCreatedResources.size === 2) + val secrets = otherCreatedResources.toArray + .filter(_.isInstanceOf[Secret]).map(_.asInstanceOf[Secret]) + val configMaps = otherCreatedResources.toArray + .filter(_.isInstanceOf[ConfigMap]).map(_.asInstanceOf[ConfigMap]) + assert(secrets.nonEmpty) + val secret = secrets.head + assert(secret.getMetadata.getName === FirstTestConfigurationStep.secretName) + assert(secret.getData.asScala === Map(FirstTestConfigurationStep.secretKey -> FirstTestConfigurationStep.secretData)) - val ownerReference = Iterables.getOnlyElement(createdResource.getMetadata.getOwnerReferences) + val ownerReference = Iterables.getOnlyElement(secret.getMetadata.getOwnerReferences) assert(ownerReference.getName === createdPod.getMetadata.getName) assert(ownerReference.getKind === DRIVER_POD_KIND) assert(ownerReference.getUid === DRIVER_POD_UID) assert(ownerReference.getApiVersion === DRIVER_POD_API_VERSION) - } - - test("The client should attach the driver container with the appropriate JVM options.") { - val sparkConf = new SparkConf(false) - .set("spark.logConf", "true") - .set( - org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, - "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails") - val submissionClient = new Client( - submissionSteps, - sparkConf, - kubernetesClient, - false, - "spark", - loggingPodStatusWatcher) - submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue + assert(configMaps.nonEmpty) + val configMap = configMaps.head + assert(configMap.getMetadata.getName === + s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map") + assert(configMap.getData.containsKey(SPARK_CONF_FILE_NAME)) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains(EXPECTED_JAVA_OPTS)) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains( + "spark.custom-conf=custom-conf-value")) val driverContainer = Iterables.getOnlyElement(createdPod.getSpec.getContainers) assert(driverContainer.getName === SecondTestConfigurationStep.containerName) - val driverJvmOptsEnvs = driverContainer.getEnv.asScala.filter { env => - env.getName.startsWith(ENV_JAVA_OPT_PREFIX) - }.sortBy(_.getName) - assert(driverJvmOptsEnvs.size === 4) - - val expectedJvmOptsValues = Seq( - "-Dspark.logConf=true", - s"-D${SecondTestConfigurationStep.sparkConfKey}=" + - s"${SecondTestConfigurationStep.sparkConfValue}", - "-XX:+HeapDumpOnOutOfMemoryError", - "-XX:+PrintGCDetails") - driverJvmOptsEnvs.zip(expectedJvmOptsValues).zipWithIndex.foreach { - case ((resolvedEnv, expectedJvmOpt), index) => - assert(resolvedEnv.getName === s"$ENV_JAVA_OPT_PREFIX$index") - assert(resolvedEnv.getValue === expectedJvmOpt) - } + val driverEnv = driverContainer.getEnv.asScala.head + assert(driverEnv.getName === ENV_SPARK_CONF_DIR) + assert(driverEnv.getValue === SPARK_CONF_DIR_INTERNAL) + val driverMount = driverContainer.getVolumeMounts.asScala.head + assert(driverMount.getName === SPARK_CONF_VOLUME) + assert(driverMount.getMountPath === SPARK_CONF_DIR_INTERNAL) } test("Waiting for app completion should stall on the watcher") { @@ -173,7 +166,8 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { kubernetesClient, true, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() verify(loggingPodStatusWatcher).awaitCompletion() } @@ -209,13 +203,11 @@ private object FirstTestConfigurationStep extends DriverConfigurationStep { } private object SecondTestConfigurationStep extends DriverConfigurationStep { - val annotationKey = "second-submit" val annotationValue = "submitted" val sparkConfKey = "spark.custom-conf" val sparkConfValue = "custom-conf-value" val containerName = "driverContainer" - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { val modifiedPod = new PodBuilder(driverSpec.driverPod) .editMetadata() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index 033d303e946fd..df34d2dbcb5be 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -25,7 +25,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { private val DRIVER_IMAGE = "driver-image" private val IC_IMAGE = "init-container-image" private val APP_ID = "spark-app-id" - private val LAUNCH_TIME = 975256L + private val KUBERNETES_RESOURCE_PREFIX = "example-prefix" private val APP_NAME = "spark" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" private val APP_ARGS = Array("arg1", "arg2") @@ -38,7 +38,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(mainAppResource), APP_NAME, MAIN_CLASS, @@ -49,15 +49,14 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { classOf[BasicDriverConfigurationStep], classOf[DriverServiceBootstrapStep], classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep] - ) + classOf[DependencyResolutionStep]) } test("Base submission steps without a main app resource.") { val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) val orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Option.empty, APP_NAME, MAIN_CLASS, @@ -67,31 +66,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { orchestrator, classOf[BasicDriverConfigurationStep], classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep] - ) - } - - test("Submission steps with an init-container.") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DRIVER_IMAGE) - .set(INIT_CONTAINER_IMAGE.key, IC_IMAGE) - .set("spark.jars", "hdfs://localhost:9000/var/apps/jars/jar1.jar") - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - LAUNCH_TIME, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep], - classOf[DriverInitContainerBootstrapStep]) + classOf[DriverKubernetesCredentialsStep]) } test("Submission steps with driver secrets to mount") { @@ -102,7 +77,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(mainAppResource), APP_NAME, MAIN_CLASS, @@ -122,7 +97,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { .set(CONTAINER_IMAGE, DRIVER_IMAGE) var orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(JavaMainAppResource("file:///var/apps/jars/main.jar")), APP_NAME, MAIN_CLASS, @@ -135,7 +110,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { sparkConf.set("spark.files", "/path/to/file1,/path/to/file2") orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(JavaMainAppResource("local:///var/apps/jars/main.jar")), APP_NAME, MAIN_CLASS, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index b136f2c02ffba..ce068531c7673 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -73,16 +73,13 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { assert(preparedDriverSpec.driverContainer.getImage === "spark-driver:latest") assert(preparedDriverSpec.driverContainer.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) - assert(preparedDriverSpec.driverContainer.getEnv.size === 7) + assert(preparedDriverSpec.driverContainer.getEnv.size === 4) val envs = preparedDriverSpec.driverContainer .getEnv .asScala .map(env => (env.getName, env.getValue)) .toMap assert(envs(ENV_CLASSPATH) === "/opt/spark/spark-examples.jar") - assert(envs(ENV_DRIVER_MEMORY) === "256M") - assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) - assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") @@ -112,7 +109,8 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { val expectedSparkConf = Map( KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", "spark.app.id" -> APP_ID, - KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX) + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true") assert(resolvedSparkConf === expectedSparkConf) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala index 991b03cafb76c..ca43fc97dc991 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala @@ -29,24 +29,17 @@ import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec class DependencyResolutionStepSuite extends SparkFunSuite { private val SPARK_JARS = Seq( - "hdfs://localhost:9000/apps/jars/jar1.jar", - "file:///home/user/apps/jars/jar2.jar", - "local:///var/apps/jars/jar3.jar") + "apps/jars/jar1.jar", + "local:///var/apps/jars/jar2.jar") private val SPARK_FILES = Seq( - "file:///home/user/apps/files/file1.txt", - "hdfs://localhost:9000/apps/files/file2.txt", - "local:///var/apps/files/file3.txt") - - private val JARS_DOWNLOAD_PATH = "/mnt/spark-data/jars" - private val FILES_DOWNLOAD_PATH = "/mnt/spark-data/files" + "apps/files/file1.txt", + "local:///var/apps/files/file2.txt") test("Added dependencies should be resolved in Spark configuration and environment") { val dependencyResolutionStep = new DependencyResolutionStep( SPARK_JARS, - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH) + SPARK_FILES) val driverPod = new PodBuilder().build() val baseDriverSpec = KubernetesDriverSpec( driverPod = driverPod, @@ -58,24 +51,19 @@ class DependencyResolutionStepSuite extends SparkFunSuite { assert(preparedDriverSpec.otherKubernetesResources.isEmpty) val resolvedSparkJars = preparedDriverSpec.driverSparkConf.get("spark.jars").split(",").toSet val expectedResolvedSparkJars = Set( - "hdfs://localhost:9000/apps/jars/jar1.jar", - s"$JARS_DOWNLOAD_PATH/jar2.jar", - "/var/apps/jars/jar3.jar") + "apps/jars/jar1.jar", + "/var/apps/jars/jar2.jar") assert(resolvedSparkJars === expectedResolvedSparkJars) val resolvedSparkFiles = preparedDriverSpec.driverSparkConf.get("spark.files").split(",").toSet val expectedResolvedSparkFiles = Set( - s"$FILES_DOWNLOAD_PATH/file1.txt", - s"hdfs://localhost:9000/apps/files/file2.txt", - s"/var/apps/files/file3.txt") + "apps/files/file1.txt", + "/var/apps/files/file2.txt") assert(resolvedSparkFiles === expectedResolvedSparkFiles) val driverEnv = preparedDriverSpec.driverContainer.getEnv.asScala assert(driverEnv.size === 1) assert(driverEnv.head.getName === ENV_MOUNTED_CLASSPATH) val resolvedDriverClasspath = driverEnv.head.getValue.split(File.pathSeparator).toSet - val expectedResolvedDriverClasspath = Set( - s"$JARS_DOWNLOAD_PATH/jar1.jar", - s"$JARS_DOWNLOAD_PATH/jar2.jar", - "/var/apps/jars/jar3.jar") + val expectedResolvedDriverClasspath = expectedResolvedSparkJars assert(resolvedDriverClasspath === expectedResolvedDriverClasspath) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala deleted file mode 100644 index 758871e2ba356..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.StringReader -import java.util.Properties - -import scala.collection.JavaConverters._ - -import com.google.common.collect.Maps -import io.fabric8.kubernetes.api.model.{ConfigMap, ContainerBuilder, HasMetadata, PodBuilder, SecretBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} -import org.apache.spark.util.Utils - -class DriverInitContainerBootstrapStepSuite extends SparkFunSuite { - - private val CONFIG_MAP_NAME = "spark-init-config-map" - private val CONFIG_MAP_KEY = "spark-init-config-map-key" - - test("The init container bootstrap step should use all of the init container steps") { - val baseDriverSpec = KubernetesDriverSpec( - driverPod = new PodBuilder().build(), - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val initContainerSteps = Seq( - FirstTestInitContainerConfigurationStep, - SecondTestInitContainerConfigurationStep) - val bootstrapStep = new DriverInitContainerBootstrapStep( - initContainerSteps, - CONFIG_MAP_NAME, - CONFIG_MAP_KEY) - - val preparedDriverSpec = bootstrapStep.configureDriver(baseDriverSpec) - - assert(preparedDriverSpec.driverPod.getMetadata.getLabels.asScala === - FirstTestInitContainerConfigurationStep.additionalLabels) - val additionalDriverEnv = preparedDriverSpec.driverContainer.getEnv.asScala - assert(additionalDriverEnv.size === 1) - assert(additionalDriverEnv.head.getName === - FirstTestInitContainerConfigurationStep.additionalMainContainerEnvKey) - assert(additionalDriverEnv.head.getValue === - FirstTestInitContainerConfigurationStep.additionalMainContainerEnvValue) - - assert(preparedDriverSpec.otherKubernetesResources.size === 2) - assert(preparedDriverSpec.otherKubernetesResources.contains( - FirstTestInitContainerConfigurationStep.additionalKubernetesResource)) - assert(preparedDriverSpec.otherKubernetesResources.exists { - case configMap: ConfigMap => - val hasMatchingName = configMap.getMetadata.getName == CONFIG_MAP_NAME - val configMapData = configMap.getData.asScala - val hasCorrectNumberOfEntries = configMapData.size == 1 - val initContainerPropertiesRaw = configMapData(CONFIG_MAP_KEY) - val initContainerProperties = new Properties() - Utils.tryWithResource(new StringReader(initContainerPropertiesRaw)) { - initContainerProperties.load(_) - } - val initContainerPropertiesMap = Maps.fromProperties(initContainerProperties).asScala - val expectedInitContainerProperties = Map( - SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyKey -> - SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyValue) - val hasMatchingProperties = initContainerPropertiesMap == expectedInitContainerProperties - hasMatchingName && hasCorrectNumberOfEntries && hasMatchingProperties - - case _ => false - }) - - val initContainers = preparedDriverSpec.driverPod.getSpec.getInitContainers - assert(initContainers.size() === 1) - val initContainerEnv = initContainers.get(0).getEnv.asScala - assert(initContainerEnv.size === 1) - assert(initContainerEnv.head.getName === - SecondTestInitContainerConfigurationStep.additionalInitContainerEnvKey) - assert(initContainerEnv.head.getValue === - SecondTestInitContainerConfigurationStep.additionalInitContainerEnvValue) - - val expectedSparkConf = Map( - INIT_CONTAINER_CONFIG_MAP_NAME.key -> CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY_CONF.key -> CONFIG_MAP_KEY, - SecondTestInitContainerConfigurationStep.additionalDriverSparkConfKey -> - SecondTestInitContainerConfigurationStep.additionalDriverSparkConfValue) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) - } -} - -private object FirstTestInitContainerConfigurationStep extends InitContainerConfigurationStep { - - val additionalLabels = Map("additionalLabelkey" -> "additionalLabelValue") - val additionalMainContainerEnvKey = "TEST_ENV_MAIN_KEY" - val additionalMainContainerEnvValue = "TEST_ENV_MAIN_VALUE" - val additionalKubernetesResource = new SecretBuilder() - .withNewMetadata() - .withName("test-secret") - .endMetadata() - .addToData("secret-key", "secret-value") - .build() - - override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { - val driverPod = new PodBuilder(initContainerSpec.driverPod) - .editOrNewMetadata() - .addToLabels(additionalLabels.asJava) - .endMetadata() - .build() - val mainContainer = new ContainerBuilder(initContainerSpec.driverContainer) - .addNewEnv() - .withName(additionalMainContainerEnvKey) - .withValue(additionalMainContainerEnvValue) - .endEnv() - .build() - initContainerSpec.copy( - driverPod = driverPod, - driverContainer = mainContainer, - dependentResources = initContainerSpec.dependentResources ++ - Seq(additionalKubernetesResource)) - } -} - -private object SecondTestInitContainerConfigurationStep extends InitContainerConfigurationStep { - val additionalInitContainerEnvKey = "TEST_ENV_INIT_KEY" - val additionalInitContainerEnvValue = "TEST_ENV_INIT_VALUE" - val additionalInitContainerPropertyKey = "spark.initcontainer.testkey" - val additionalInitContainerPropertyValue = "testvalue" - val additionalDriverSparkConfKey = "spark.driver.testkey" - val additionalDriverSparkConfValue = "spark.driver.testvalue" - - override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { - val initContainer = new ContainerBuilder(initContainerSpec.initContainer) - .addNewEnv() - .withName(additionalInitContainerEnvKey) - .withValue(additionalInitContainerEnvValue) - .endEnv() - .build() - val initContainerProperties = initContainerSpec.properties ++ - Map(additionalInitContainerPropertyKey -> additionalInitContainerPropertyValue) - val driverSparkConf = initContainerSpec.driverSparkConf ++ - Map(additionalDriverSparkConfKey -> additionalDriverSparkConfValue) - initContainerSpec.copy( - initContainer = initContainer, - properties = initContainerProperties, - driverSparkConf = driverSparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala deleted file mode 100644 index 4553f9f6b1d45..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Matchers.any -import org.mockito.Mockito.when -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} -import org.apache.spark.deploy.k8s.Config._ - -class BasicInitContainerConfigurationStepSuite extends SparkFunSuite with BeforeAndAfter { - - private val SPARK_JARS = Seq( - "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") - private val SPARK_FILES = Seq( - "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") - private val JARS_DOWNLOAD_PATH = "/var/data/jars" - private val FILES_DOWNLOAD_PATH = "/var/data/files" - private val POD_LABEL = Map("bootstrap" -> "true") - private val INIT_CONTAINER_NAME = "init-container" - private val DRIVER_CONTAINER_NAME = "driver-container" - - @Mock - private var podAndInitContainerBootstrap : InitContainerBootstrap = _ - - before { - MockitoAnnotations.initMocks(this) - when(podAndInitContainerBootstrap.bootstrapInitContainer( - any[PodWithDetachedInitContainer])).thenAnswer(new Answer[PodWithDetachedInitContainer] { - override def answer(invocation: InvocationOnMock) : PodWithDetachedInitContainer = { - val pod = invocation.getArgumentAt(0, classOf[PodWithDetachedInitContainer]) - pod.copy( - pod = new PodBuilder(pod.pod) - .withNewMetadata() - .addToLabels("bootstrap", "true") - .endMetadata() - .withNewSpec().endSpec() - .build(), - initContainer = new ContainerBuilder() - .withName(INIT_CONTAINER_NAME) - .build(), - mainContainer = new ContainerBuilder() - .withName(DRIVER_CONTAINER_NAME) - .build() - )}}) - } - - test("additionalDriverSparkConf with mix of remote files and jars") { - val baseInitStep = new BasicInitContainerConfigurationStep( - SPARK_JARS, - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - podAndInitContainerBootstrap) - val expectedDriverSparkConf = Map( - JARS_DOWNLOAD_LOCATION.key -> JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_LOCATION.key -> FILES_DOWNLOAD_PATH, - INIT_CONTAINER_REMOTE_JARS.key -> "hdfs://localhost:9000/app/jars/jar1.jar", - INIT_CONTAINER_REMOTE_FILES.key -> "hdfs://localhost:9000/app/files/file1.txt") - val initContainerSpec = InitContainerSpec( - Map.empty[String, String], - Map.empty[String, String], - new Container(), - new Container(), - new Pod, - Seq.empty[HasMetadata]) - val returnContainerSpec = baseInitStep.configureInitContainer(initContainerSpec) - assert(expectedDriverSparkConf === returnContainerSpec.properties) - assert(returnContainerSpec.initContainer.getName === INIT_CONTAINER_NAME) - assert(returnContainerSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) - assert(returnContainerSpec.driverPod.getMetadata.getLabels.asScala === POD_LABEL) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala deleted file mode 100644 index 09b42e4484d86..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -class InitContainerConfigOrchestratorSuite extends SparkFunSuite { - - private val DOCKER_IMAGE = "init-container" - private val SPARK_JARS = Seq( - "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") - private val SPARK_FILES = Seq( - "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") - private val JARS_DOWNLOAD_PATH = "/var/data/jars" - private val FILES_DOWNLOAD_PATH = "/var/data/files" - private val DOCKER_IMAGE_PULL_POLICY: String = "IfNotPresent" - private val CUSTOM_LABEL_KEY = "customLabel" - private val CUSTOM_LABEL_VALUE = "customLabelValue" - private val INIT_CONTAINER_CONFIG_MAP_NAME = "spark-init-config-map" - private val INIT_CONTAINER_CONFIG_MAP_KEY = "spark-init-config-map-key" - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" - - test("including basic configuration step") { - val sparkConf = new SparkConf(true) - .set(CONTAINER_IMAGE, DOCKER_IMAGE) - .set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$CUSTOM_LABEL_KEY", CUSTOM_LABEL_VALUE) - - val orchestrator = new InitContainerConfigOrchestrator( - SPARK_JARS.take(1), - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - DOCKER_IMAGE_PULL_POLICY, - INIT_CONTAINER_CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY, - sparkConf) - val initSteps = orchestrator.getAllConfigurationSteps - assert(initSteps.lengthCompare(1) == 0) - assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) - } - - test("including step to mount user-specified secrets") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DOCKER_IMAGE) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) - - val orchestrator = new InitContainerConfigOrchestrator( - SPARK_JARS.take(1), - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - DOCKER_IMAGE_PULL_POLICY, - INIT_CONTAINER_CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY, - sparkConf) - val initSteps = orchestrator.getAllConfigurationSteps - assert(initSteps.length === 2) - assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) - assert(initSteps(1).isInstanceOf[InitContainerMountSecretsStep]) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala deleted file mode 100644 index 7ac0bde80dfe6..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} - -import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} - -class InitContainerMountSecretsStepSuite extends SparkFunSuite { - - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" - - test("mounts all given secrets") { - val baseInitContainerSpec = InitContainerSpec( - Map.empty, - Map.empty, - new ContainerBuilder().build(), - new ContainerBuilder().build(), - new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), - Seq.empty) - val secretNamesToMountPaths = Map( - SECRET_FOO -> SECRET_MOUNT_PATH, - SECRET_BAR -> SECRET_MOUNT_PATH) - - val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) - val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap) - val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer( - baseInitContainerSpec) - val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer - - Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => - assert(SecretVolumeUtils.containerHasVolume( - initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH))) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index a3c615be031d2..7755b93835047 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -19,15 +19,13 @@ package org.apache.spark.scheduler.cluster.k8s import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ -import org.mockito.{AdditionalAnswers, MockitoAnnotations} -import org.mockito.Matchers.any -import org.mockito.Mockito._ +import org.mockito.MockitoAnnotations import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer, SecretVolumeUtils} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.MountSecretsBootstrap class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { @@ -55,10 +53,11 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) .set(CONTAINER_IMAGE, executorImage) + .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) } test("basic executor pod has reasonable defaults") { - val factory = new ExecutorPodFactory(baseConf, None, None, None) + val factory = new ExecutorPodFactory(baseConf, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -89,7 +88,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") - val factory = new ExecutorPodFactory(conf, None, None, None) + val factory = new ExecutorPodFactory(conf, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -101,7 +100,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") - val factory = new ExecutorPodFactory(conf, None, None, None) + val factory = new ExecutorPodFactory(conf, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) @@ -116,11 +115,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef val conf = baseConf.clone() val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - val factory = new ExecutorPodFactory( - conf, - Some(secretsBootstrap), - None, - None) + val factory = new ExecutorPodFactory(conf, Some(secretsBootstrap)) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -138,50 +133,6 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkOwnerReferences(executor, driverPodUid) } - test("init-container bootstrap step adds an init container") { - val conf = baseConf.clone() - val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) - when(initContainerBootstrap.bootstrapInitContainer( - any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - - val factory = new ExecutorPodFactory( - conf, - None, - Some(initContainerBootstrap), - None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getInitContainers.size() === 1) - checkOwnerReferences(executor, driverPodUid) - } - - test("init-container with secrets mount bootstrap") { - val conf = baseConf.clone() - val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) - when(initContainerBootstrap.bootstrapInitContainer( - any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - - val factory = new ExecutorPodFactory( - conf, - Some(secretsBootstrap), - Some(initContainerBootstrap), - Some(secretsBootstrap)) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getVolumes.size() === 1) - assert(SecretVolumeUtils.podHasVolume(executor, "secret1-volume")) - assert(SecretVolumeUtils.containerHasVolume( - executor.getSpec.getContainers.get(0), "secret1-volume", "/var/secret1")) - assert(executor.getSpec.getInitContainers.size() === 1) - assert(SecretVolumeUtils.containerHasVolume( - executor.getSpec.getInitContainers.get(0), "secret1-volume", "/var/secret1")) - - checkOwnerReferences(executor, driverPodUid) - } - // There is always exactly one controller reference, and it points to the driver pod. private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { assert(executor.getMetadata.getOwnerReferences.size() === 1) @@ -197,8 +148,8 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef ENV_EXECUTOR_CORES -> "1", ENV_EXECUTOR_MEMORY -> "1g", ENV_APPLICATION_ID -> "dummy", - ENV_EXECUTOR_POD_IP -> null, - ENV_MOUNTED_CLASSPATH -> "/var/spark-data/spark-jars/*") ++ additionalEnvVars + ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, + ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars assert(executor.getSpec.getContainers.size() === 1) assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 491b7cf692478..9badf8556afc3 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -40,7 +40,6 @@ RUN set -ex && \ COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin -COPY conf /opt/spark/conf COPY ${img_path}/spark/entrypoint.sh /opt/ COPY examples /opt/spark/examples COPY data /opt/spark/data diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index d0cf284f035ea..3e166116aa3fd 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -56,14 +56,10 @@ fi case "$SPARK_K8S_CMD" in driver) CMD=( - ${JAVA_HOME}/bin/java - "${SPARK_JAVA_OPTS[@]}" - -cp "$SPARK_CLASSPATH" - -Xms$SPARK_DRIVER_MEMORY - -Xmx$SPARK_DRIVER_MEMORY - -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS - $SPARK_DRIVER_CLASS - $SPARK_DRIVER_ARGS + "$SPARK_HOME/bin/spark-submit" + --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS" + --deploy-mode client + "$@" ) ;; @@ -83,14 +79,6 @@ case "$SPARK_K8S_CMD" in ) ;; - init) - CMD=( - "$SPARK_HOME/bin/spark-class" - "org.apache.spark.deploy.k8s.SparkPodInitContainer" - "$@" - ) - ;; - *) echo "Unknown command: $SPARK_K8S_CMD" 1>&2 exit 1 From 5f4deff19511b6870f056eba5489104b9cac05a9 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 19 Mar 2018 18:02:04 -0700 Subject: [PATCH 492/774] [SPARK-23660] Fix exception in yarn cluster mode when application ended fast ## What changes were proposed in this pull request? Yarn throws the following exception in cluster mode when the application is really small: ``` 18/03/07 23:34:22 WARN netty.NettyRpcEnv: Ignored failure: java.util.concurrent.RejectedExecutionException: Task java.util.concurrent.ScheduledThreadPoolExecutor$ScheduledFutureTask7c974942 rejected from java.util.concurrent.ScheduledThreadPoolExecutor1eea9d2d[Terminated, pool size = 0, active threads = 0, queued tasks = 0, completed tasks = 0] 18/03/07 23:34:22 ERROR yarn.ApplicationMaster: Uncaught exception: org.apache.spark.SparkException: Exception thrown in awaitResult: at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:205) at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:92) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:76) at org.apache.spark.deploy.yarn.YarnAllocator.(YarnAllocator.scala:102) at org.apache.spark.deploy.yarn.YarnRMClient.register(YarnRMClient.scala:77) at org.apache.spark.deploy.yarn.ApplicationMaster.registerAM(ApplicationMaster.scala:450) at org.apache.spark.deploy.yarn.ApplicationMaster.runDriver(ApplicationMaster.scala:493) at org.apache.spark.deploy.yarn.ApplicationMaster.org$apache$spark$deploy$yarn$ApplicationMaster$$runImpl(ApplicationMaster.scala:345) at org.apache.spark.deploy.yarn.ApplicationMaster$$anonfun$run$2.apply$mcV$sp(ApplicationMaster.scala:260) at org.apache.spark.deploy.yarn.ApplicationMaster$$anonfun$run$2.apply(ApplicationMaster.scala:260) at org.apache.spark.deploy.yarn.ApplicationMaster$$anonfun$run$2.apply(ApplicationMaster.scala:260) at org.apache.spark.deploy.yarn.ApplicationMaster$$anon$5.run(ApplicationMaster.scala:810) at java.security.AccessController.doPrivileged(Native Method) at javax.security.auth.Subject.doAs(Subject.java:422) at org.apache.hadoop.security.UserGroupInformation.doAs(UserGroupInformation.java:1920) at org.apache.spark.deploy.yarn.ApplicationMaster.doAsUser(ApplicationMaster.scala:809) at org.apache.spark.deploy.yarn.ApplicationMaster.run(ApplicationMaster.scala:259) at org.apache.spark.deploy.yarn.ApplicationMaster$.main(ApplicationMaster.scala:834) at org.apache.spark.deploy.yarn.ApplicationMaster.main(ApplicationMaster.scala) Caused by: org.apache.spark.rpc.RpcEnvStoppedException: RpcEnv already stopped. at org.apache.spark.rpc.netty.Dispatcher.postMessage(Dispatcher.scala:158) at org.apache.spark.rpc.netty.Dispatcher.postLocalMessage(Dispatcher.scala:135) at org.apache.spark.rpc.netty.NettyRpcEnv.ask(NettyRpcEnv.scala:229) at org.apache.spark.rpc.netty.NettyRpcEndpointRef.ask(NettyRpcEnv.scala:523) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:91) ... 17 more 18/03/07 23:34:22 INFO yarn.ApplicationMaster: Final app status: FAILED, exitCode: 13, (reason: Uncaught exception: org.apache.spark.SparkException: Exception thrown in awaitResult: ) ``` Example application: ``` object ExampleApp { def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("ExampleApp") val sc = new SparkContext(conf) try { // Do nothing } finally { sc.stop() } } ``` This PR pauses user class thread after `SparkContext` created and keeps it so until application master initialises properly. ## How was this patch tested? Automated: Existing unit tests Manual: Application submitted into small cluster Author: Gabor Somogyi Closes #20807 from gaborgsomogyi/SPARK-23660. --- .../spark/deploy/yarn/ApplicationMaster.scala | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 2f88feb0f1fdf..6e35d23def6f0 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -418,7 +418,19 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } private def sparkContextInitialized(sc: SparkContext) = { - sparkContextPromise.success(sc) + sparkContextPromise.synchronized { + // Notify runDriver function that SparkContext is available + sparkContextPromise.success(sc) + // Pause the user class thread in order to make proper initialization in runDriver function. + sparkContextPromise.wait() + } + } + + private def resumeDriver(): Unit = { + // When initialization in runDriver happened the user class thread has to be resumed. + sparkContextPromise.synchronized { + sparkContextPromise.notify() + } } private def registerAM( @@ -497,6 +509,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // if the user app did not create a SparkContext. throw new IllegalStateException("User did not initialize spark context!") } + resumeDriver() userClassThread.join() } catch { case e: SparkException if e.getCause().isInstanceOf[TimeoutException] => @@ -506,6 +519,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_SC_NOT_INITED, "Timed out waiting for SparkContext.") + } finally { + resumeDriver() } } From 566321852b2d60641fe86acbc8914b4a7063b58e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 19 Mar 2018 21:25:37 -0700 Subject: [PATCH 493/774] [SPARK-23691][PYTHON] Use sql_conf util in PySpark tests where possible ## What changes were proposed in this pull request? https://github.com/apache/spark/commit/d6632d185e147fcbe6724545488ad80dce20277e added an useful util ```python contextmanager def sql_conf(self, pairs): ... ``` to allow configuration set/unset within a block: ```python with self.sql_conf({"spark.blah.blah.blah", "blah"}) # test codes ``` This PR proposes to use this util where possible in PySpark tests. Note that there look already few places affecting tests without restoring the original value back in unittest classes. ## How was this patch tested? Manually tested via: ``` ./run-tests --modules=pyspark-sql --python-executables=python2 ./run-tests --modules=pyspark-sql --python-executables=python3 ``` Author: hyukjinkwon Closes #20830 from HyukjinKwon/cleanup-sql-conf. --- python/pyspark/sql/tests.py | 130 ++++++++++++++---------------------- 1 file changed, 50 insertions(+), 80 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a0d547ad620e5..39d6c5226f138 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2461,17 +2461,13 @@ def test_join_without_on(self): df1 = self.spark.range(1).toDF("a") df2 = self.spark.range(1).toDF("b") - try: - self.spark.conf.set("spark.sql.crossJoin.enabled", "false") + with self.sql_conf({"spark.sql.crossJoin.enabled": False}): self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect()) - self.spark.conf.set("spark.sql.crossJoin.enabled", "true") + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): actual = df1.join(df2, how="inner").collect() expected = [Row(a=0, b=0)] self.assertEqual(actual, expected) - finally: - # We should unset this. Otherwise, other tests are affected. - self.spark.conf.unset("spark.sql.crossJoin.enabled") # Regression test for invalid join methods when on is None, Spark-14761 def test_invalid_join_method(self): @@ -2943,21 +2939,18 @@ def test_create_dateframe_from_pandas_with_dst(self): self.assertPandasEqual(pdf, df.toPandas()) orig_env_tz = os.environ.get('TZ', None) - orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone') try: tz = 'America/Los_Angeles' os.environ['TZ'] = tz time.tzset() - self.spark.conf.set('spark.sql.session.timeZone', tz) - - df = self.spark.createDataFrame(pdf) - self.assertPandasEqual(pdf, df.toPandas()) + with self.sql_conf({'spark.sql.session.timeZone': tz}): + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) finally: del os.environ['TZ'] if orig_env_tz is not None: os.environ['TZ'] = orig_env_tz time.tzset() - self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz) class HiveSparkSubmitTests(SparkSubmitTests): @@ -3562,12 +3555,11 @@ def test_null_conversion(self): self.assertTrue(all([c == 1 for c in null_counts])) def _toPandas_arrow_toggle(self, df): - self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") - try: + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): pdf = df.toPandas() - finally: - self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + pdf_arrow = df.toPandas() + return pdf, pdf_arrow def test_toPandas_arrow_toggle(self): @@ -3579,16 +3571,17 @@ def test_toPandas_arrow_toggle(self): def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) - self.assertPandasEqual(pdf_arrow_la, pdf_la) - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf_arrow_la, pdf_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) self.assertPandasEqual(pdf_arrow_ny, pdf_ny) @@ -3601,8 +3594,6 @@ def test_toPandas_respect_session_timezone(self): pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( pdf_la_corrected[field.name], timezone) self.assertPandasEqual(pdf_ny, pdf_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_pandas_round_trip(self): pdf = self.create_pandas_data_frame() @@ -3618,12 +3609,11 @@ def test_filtered_frame(self): self.assertTrue(pdf.empty) def _createDataFrame_toggle(self, pdf, schema=None): - self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") - try: + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): df_no_arrow = self.spark.createDataFrame(pdf, schema=schema) - finally: - self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + df_arrow = self.spark.createDataFrame(pdf, schema=schema) + return df_no_arrow, df_arrow def test_createDataFrame_toggle(self): @@ -3634,18 +3624,18 @@ def test_createDataFrame_toggle(self): def test_createDataFrame_respect_session_timezone(self): from datetime import timedelta pdf = self.create_pandas_data_frame() - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) - result_la = df_no_arrow_la.collect() - result_arrow_la = df_arrow_la.collect() - self.assertEqual(result_la, result_arrow_la) - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) + result_la = df_no_arrow_la.collect() + result_arrow_la = df_arrow_la.collect() + self.assertEqual(result_la, result_arrow_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema) result_ny = df_no_arrow_ny.collect() result_arrow_ny = df_arrow_ny.collect() @@ -3658,8 +3648,6 @@ def test_createDataFrame_respect_session_timezone(self): for k, v in row.asDict().items()}) for row in result_la] self.assertEqual(result_ny, result_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_createDataFrame_with_schema(self): pdf = self.create_pandas_data_frame() @@ -4336,9 +4324,7 @@ def gen_timestamps(id): def test_vectorized_udf_check_config(self): from pyspark.sql.functions import pandas_udf, col import pandas as pd - orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None) - self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3) - try: + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}): df = self.spark.range(10, numPartitions=1) @pandas_udf(returnType=LongType()) @@ -4348,11 +4334,6 @@ def check_records_per_batch(x): result = df.select(check_records_per_batch(col("id"))).collect() for (r,) in result: self.assertTrue(r <= 3) - finally: - if orig_value is None: - self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") - else: - self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value) def test_vectorized_udf_timestamps_respect_session_timezone(self): from pyspark.sql.functions import pandas_udf, col @@ -4371,30 +4352,27 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): internal_value = pandas_udf( lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType()) - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ - .withColumn("internal_value", internal_value(col("timestamp"))) - result_la = df_la.select(col("idx"), col("internal_value")).collect() - # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - diff = 3 * 60 * 60 * 1000 * 1000 * 1000 - result_la_corrected = \ - df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ + .withColumn("internal_value", internal_value(col("timestamp"))) + result_la = df_la.select(col("idx"), col("internal_value")).collect() + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York + diff = 3 * 60 * 60 * 1000 * 1000 * 1000 + result_la_corrected = \ + df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ .withColumn("internal_value", internal_value(col("timestamp"))) result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect() self.assertNotEqual(result_ny, result_la) self.assertEqual(result_ny, result_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_nondeterministic_vectorized_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations @@ -5170,9 +5148,7 @@ def test_complex_expressions(self): def test_retain_group_columns(self): from pyspark.sql.functions import sum, lit, col - orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None) - self.spark.conf.set("spark.sql.retainGroupColumns", False) - try: + with self.sql_conf({"spark.sql.retainGroupColumns": False}): df = self.data sum_udf = self.pandas_agg_sum_udf @@ -5180,12 +5156,6 @@ def test_retain_group_columns(self): expected1 = df.groupby(df.id).agg(sum(df.v)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - finally: - if orig_value is None: - self.spark.conf.unset("spark.sql.retainGroupColumns") - else: - self.spark.conf.set("spark.sql.retainGroupColumns", orig_value) - def test_invalid_args(self): from pyspark.sql.functions import mean From 5e7bc2acef4a1e11d0d8056ef5c12cd5c8f220da Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 20 Mar 2018 10:34:56 -0700 Subject: [PATCH 494/774] [SPARK-23649][SQL] Skipping chars disallowed in UTF-8 ## What changes were proposed in this pull request? The mapping of UTF-8 char's first byte to char's size doesn't cover whole range 0-255. It is defined only for 0-253: https://github.com/apache/spark/blob/master/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java#L60-L65 https://github.com/apache/spark/blob/master/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java#L190 If the first byte of a char is 253-255, IndexOutOfBoundsException is thrown. Besides of that values for 244-252 are not correct according to recent unicode standard for UTF-8: http://www.unicode.org/versions/Unicode10.0.0/UnicodeStandard-10.0.pdf As a consequence of the exception above, the length of input string in UTF-8 encoding cannot be calculated if the string contains chars started from 253 code. It is visible on user's side as for example crashing of schema inferring of csv file which contains such chars but the file can be read if the schema is specified explicitly or if the mode set to multiline. The proposed changes build correct mapping of first byte of UTF-8 char to its size (now it covers all cases) and skip disallowed chars (counts it as one octet). ## How was this patch tested? Added a test and a file with a char which is disallowed in UTF-8 - 0xFF. Author: Maxim Gekk Closes #20796 from MaxGekk/skip-wrong-utf8-chars. --- .../apache/spark/unsafe/types/UTF8String.java | 48 +++++++++++++++---- .../spark/unsafe/types/UTF8StringSuite.java | 23 ++++++++- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index b0d0c44823e68..5d468aed42337 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -57,12 +57,43 @@ public final class UTF8String implements Comparable, Externalizable, public Object getBaseObject() { return base; } public long getBaseOffset() { return offset; } - private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, - 5, 5, 5, 5, - 6, 6}; + /** + * A char in UTF-8 encoding can take 1-4 bytes depending on the first byte which + * indicates the size of the char. See Unicode standard in page 126, Table 3-6: + * http://www.unicode.org/versions/Unicode10.0.0/UnicodeStandard-10.0.pdf + * + * Binary Hex Comments + * 0xxxxxxx 0x00..0x7F Only byte of a 1-byte character encoding + * 10xxxxxx 0x80..0xBF Continuation bytes (1-3 continuation bytes) + * 110xxxxx 0xC0..0xDF First byte of a 2-byte character encoding + * 1110xxxx 0xE0..0xEF First byte of a 3-byte character encoding + * 11110xxx 0xF0..0xF7 First byte of a 4-byte character encoding + * + * As a consequence of the well-formedness conditions specified in + * Table 3-7 (page 126), the following byte values are disallowed in UTF-8: + * C0–C1, F5–FF. + */ + private static byte[] bytesOfCodePointInUTF8 = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x00..0x0F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x10..0x1F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x20..0x2F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x30..0x3F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x40..0x4F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x50..0x5F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x60..0x6F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x70..0x7F + // Continuation bytes cannot appear as the first byte + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x80..0x8F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x90..0x9F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xA0..0xAF + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xB0..0xBF + 0, 0, // 0xC0..0xC1 - disallowed in UTF-8 + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xC2..0xCF + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xD0..0xDF + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // 0xE0..0xEF + 4, 4, 4, 4, 4, // 0xF0..0xF4 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 // 0xF5..0xFF - disallowed in UTF-8 + }; private static final boolean IS_LITTLE_ENDIAN = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; @@ -187,8 +218,9 @@ public void writeTo(OutputStream out) throws IOException { * @param b The first byte of a code point */ private static int numBytesForFirstByte(final byte b) { - final int offset = (b & 0xFF) - 192; - return (offset >= 0) ? bytesOfCodePointInUTF8[offset] : 1; + final int offset = b & 0xFF; + byte numBytes = bytesOfCodePointInUTF8[offset]; + return (numBytes == 0) ? 1: numBytes; // Skip the first byte disallowed in UTF-8 } /** diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 9b303fa5bc6c5..7c34d419574ef 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -58,8 +58,12 @@ private static void checkBasic(String str, int len) { @Test public void basicTest() { checkBasic("", 0); - checkBasic("hello", 5); + checkBasic("¡", 1); // 2 bytes char + checkBasic("ку", 2); // 2 * 2 bytes chars + checkBasic("hello", 5); // 5 * 1 byte chars checkBasic("大 千 世 界", 7); + checkBasic("︽﹋%", 3); // 3 * 3 bytes chars + checkBasic("\uD83E\uDD19", 1); // 4 bytes char } @Test @@ -791,4 +795,21 @@ public void trimRightWithTrimString() { assertEquals(fromString("头"), fromString("头a???/").trimRight(fromString("数?/*&^%a"))); assertEquals(fromString("头"), fromString("头数b数数 [").trimRight(fromString(" []数b"))); } + + @Test + public void skipWrongFirstByte() { + int[] wrongFirstBytes = { + 0x80, 0x9F, 0xBF, // Skip Continuation bytes + 0xC0, 0xC2, // 0xC0..0xC1 - disallowed in UTF-8 + // 0xF5..0xFF - disallowed in UTF-8 + 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, + 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF + }; + byte[] c = new byte[1]; + + for (int i = 0; i < wrongFirstBytes.length; ++i) { + c[0] = (byte)wrongFirstBytes[i]; + assertEquals(fromBytes(c).numChars(), 1); + } + } } From 7f5e8aa2606b0ee0297ceb6f4603bd368e3b0291 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 20 Mar 2018 11:14:34 -0700 Subject: [PATCH 495/774] [SPARK-21898][ML] Feature parity for KolmogorovSmirnovTest in MLlib ## What changes were proposed in this pull request? Feature parity for KolmogorovSmirnovTest in MLlib. Implement `DataFrame` interface for `KolmogorovSmirnovTest` in `mllib.stat`. ## How was this patch tested? Test suite added. Author: WeichenXu Author: jkbradley Closes #19108 from WeichenXu123/ml-ks-test. --- .../spark/ml/stat/KolmogorovSmirnovTest.scala | 113 ++++++++++++++ .../stat/JavaKolmogorovSmirnovTestSuite.java | 84 +++++++++++ .../ml/stat/KolmogorovSmirnovTestSuite.scala | 140 ++++++++++++++++++ 3 files changed, 337 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala new file mode 100644 index 0000000000000..8d80e7768cb6e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import scala.annotation.varargs + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.function.Function +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.stat.{Statistics => OldStatistics} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.col + +/** + * :: Experimental :: + * + * Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a + * continuous distribution. By comparing the largest difference between the empirical cumulative + * distribution of the sample data and the theoretical distribution we can provide a test for the + * the null hypothesis that the sample data comes from that theoretical distribution. + * For more information on KS Test: + * @see + * Kolmogorov-Smirnov test (Wikipedia) + */ +@Experimental +@Since("2.4.0") +object KolmogorovSmirnovTest { + + /** Used to construct output schema of test */ + private case class KolmogorovSmirnovTestResult( + pValue: Double, + statistic: Double) + + private def getSampleRDD(dataset: DataFrame, sampleCol: String): RDD[Double] = { + SchemaUtils.checkNumericType(dataset.schema, sampleCol) + import dataset.sparkSession.implicits._ + dataset.select(col(sampleCol).cast("double")).as[Double].rdd + } + + /** + * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a + * continuous distribution. By comparing the largest difference between the empirical cumulative + * distribution of the sample data and the theoretical distribution we can provide a test for the + * the null hypothesis that the sample data comes from that theoretical distribution. + * + * @param dataset a `DataFrame` containing the sample of data to test + * @param sampleCol Name of sample column in dataset, of any numerical type + * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value + * @return DataFrame containing the test result for the input sampled data. + * This DataFrame will contain a single Row with the following fields: + * - `pValue: Double` + * - `statistic: Double` + */ + @Since("2.4.0") + def test(dataset: DataFrame, sampleCol: String, cdf: Double => Double): DataFrame = { + val spark = dataset.sparkSession + + val rdd = getSampleRDD(dataset, sampleCol) + val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, cdf) + spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( + testResult.pValue, testResult.statistic))) + } + + /** + * Java-friendly version of `test(dataset: DataFrame, sampleCol: String, cdf: Double => Double)` + */ + @Since("2.4.0") + def test(dataset: DataFrame, sampleCol: String, + cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { + test(dataset, sampleCol, (x: Double) => cdf.call(x)) + } + + /** + * Convenience function to conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability + * distribution equality. Currently supports the normal distribution, taking as parameters + * the mean and standard deviation. + * + * @param dataset a `DataFrame` containing the sample of data to test + * @param sampleCol Name of sample column in dataset, of any numerical type + * @param distName a `String` name for a theoretical distribution, currently only support "norm". + * @param params `Double*` specifying the parameters to be used for the theoretical distribution + * @return DataFrame containing the test result for the input sampled data. + * This DataFrame will contain a single Row with the following fields: + * - `pValue: Double` + * - `statistic: Double` + */ + @Since("2.4.0") + @varargs + def test(dataset: DataFrame, sampleCol: String, distName: String, params: Double*): DataFrame = { + val spark = dataset.sparkSession + + val rdd = getSampleRDD(dataset, sampleCol) + val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, distName, params: _*) + spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( + testResult.pValue, testResult.statistic))) + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java new file mode 100644 index 0000000000000..021272dd5a40c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.Test; + +import org.apache.spark.SharedSparkSession; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; + + +public class JavaKolmogorovSmirnovTestSuite extends SharedSparkSession { + + private transient Dataset dataset; + + @Override + public void setUp() throws IOException { + super.setUp(); + List points = Arrays.asList(0.1, 1.1, 10.1, -1.1); + + dataset = spark.createDataset(points, Encoders.DOUBLE()).toDF("sample"); + } + + @Test + public void testKSTestCDF() { + // Create theoretical distributions + NormalDistribution stdNormalDist = new NormalDistribution(0, 1); + + // set seeds + Long seed = 10L; + stdNormalDist.reseedRandomGenerator(seed); + Function stdNormalCDF = (x) -> stdNormalDist.cumulativeProbability(x); + + double pThreshold = 0.05; + + // Comparing a standard normal sample to a standard normal distribution + Row results = KolmogorovSmirnovTest + .test(dataset, "sample", stdNormalCDF).head(); + double pValue1 = results.getDouble(0); + // Cannot reject null hypothesis + assert(pValue1 > pThreshold); + } + + @Test + public void testKSTestNamedDistribution() { + double pThreshold = 0.05; + + // Comparing a standard normal sample to a standard normal distribution + Row results = KolmogorovSmirnovTest + .test(dataset, "sample", "norm", 0.0, 1.0).head(); + double pValue1 = results.getDouble(0); + // Cannot reject null hypothesis + assert(pValue1 > pThreshold); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala new file mode 100644 index 0000000000000..1312de3a1b522 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import org.apache.commons.math3.distribution.{ExponentialDistribution, NormalDistribution, + RealDistribution, UniformRealDistribution} +import org.apache.commons.math3.stat.inference.{KolmogorovSmirnovTest => Math3KSTest} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row + +class KolmogorovSmirnovTestSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + def apacheCommonMath3EquivalenceTest( + sampleDist: RealDistribution, + theoreticalDist: RealDistribution, + theoreticalDistByName: (String, Array[Double]), + rejectNullHypothesis: Boolean): Unit = { + + // set seeds + val seed = 10L + sampleDist.reseedRandomGenerator(seed) + if (theoreticalDist != null) { + theoreticalDist.reseedRandomGenerator(seed) + } + + // Sample data from the distributions and parallelize it + val n = 100000 + val sampledArray = sampleDist.sample(n) + val sampledDF = sc.parallelize(sampledArray, 10).toDF("sample") + + // Use a apache math commons local KS test to verify calculations + val ksTest = new Math3KSTest() + val pThreshold = 0.05 + + // Comparing a standard normal sample to a standard normal distribution + val Row(pValue1: Double, statistic1: Double) = + if (theoreticalDist != null) { + val cdf = (x: Double) => theoreticalDist.cumulativeProbability(x) + KolmogorovSmirnovTest.test(sampledDF, "sample", cdf).head() + } else { + KolmogorovSmirnovTest.test(sampledDF, "sample", + theoreticalDistByName._1, + theoreticalDistByName._2: _* + ).head() + } + val theoreticalDistMath3 = if (theoreticalDist == null) { + assert(theoreticalDistByName._1 == "norm") + val params = theoreticalDistByName._2 + new NormalDistribution(params(0), params(1)) + } else { + theoreticalDist + } + val referenceStat1 = ksTest.kolmogorovSmirnovStatistic(theoreticalDistMath3, sampledArray) + val referencePVal1 = 1 - ksTest.cdf(referenceStat1, n) + // Verify vs apache math commons ks test + assert(statistic1 ~== referenceStat1 relTol 1e-4) + assert(pValue1 ~== referencePVal1 relTol 1e-4) + + if (rejectNullHypothesis) { + assert(pValue1 < pThreshold) + } else { + assert(pValue1 > pThreshold) + } + } + + test("1 sample Kolmogorov-Smirnov test: apache commons math3 implementation equivalence") { + // Create theoretical distributions + val stdNormalDist = new NormalDistribution(0.0, 1.0) + val expDist = new ExponentialDistribution(0.6) + val uniformDist = new UniformRealDistribution(0.0, 1.0) + val expDist2 = new ExponentialDistribution(0.2) + val stdNormByName = Tuple2("norm", Array(0.0, 1.0)) + + apacheCommonMath3EquivalenceTest(stdNormalDist, null, stdNormByName, false) + apacheCommonMath3EquivalenceTest(expDist, null, stdNormByName, true) + apacheCommonMath3EquivalenceTest(uniformDist, null, stdNormByName, true) + apacheCommonMath3EquivalenceTest(expDist, expDist2, null, true) + } + + test("1 sample Kolmogorov-Smirnov test: R implementation equivalence") { + /* + Comparing results with R's implementation of Kolmogorov-Smirnov for 1 sample + > sessionInfo() + R version 3.2.0 (2015-04-16) + Platform: x86_64-apple-darwin13.4.0 (64-bit) + > set.seed(20) + > v <- rnorm(20) + > v + [1] 1.16268529 -0.58592447 1.78546500 -1.33259371 -0.44656677 0.56960612 + [7] -2.88971761 -0.86901834 -0.46170268 -0.55554091 -0.02013537 -0.15038222 + [13] -0.62812676 1.32322085 -1.52135057 -0.43742787 0.97057758 0.02822264 + [19] -0.08578219 0.38921440 + > ks.test(v, pnorm, alternative = "two.sided") + + One-sample Kolmogorov-Smirnov test + + data: v + D = 0.18874, p-value = 0.4223 + alternative hypothesis: two-sided + */ + + val rKSStat = 0.18874 + val rKSPVal = 0.4223 + val rData = sc.parallelize( + Array( + 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501, + -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555, + -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, + -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, + 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 + ) + ).toDF("sample") + val Row(pValue: Double, statistic: Double) = KolmogorovSmirnovTest + .test(rData, "sample", "norm", 0, 1).head() + assert(statistic ~== rKSStat relTol 1e-4) + assert(pValue ~== rKSPVal relTol 1e-4) + } +} From 2c4b9962fdf8c1beb66126ca41628c72eb6c2383 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 20 Mar 2018 11:46:51 -0700 Subject: [PATCH 496/774] [SPARK-23574][SQL] Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory. ## What changes were proposed in this pull request? Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory. Note that this means reader factories end up being constructed as partitioning is checked; let me know if you think that could be a problem. ## How was this patch tested? existing unit tests Author: Jose Torres Author: Jose Torres Closes #20726 from jose-torres/SPARK-23574. --- .../v2/reader/SupportsReportPartitioning.java | 3 ++ .../datasources/v2/DataSourceRDD.scala | 4 +-- .../datasources/v2/DataSourceV2ScanExec.scala | 29 ++++++++++++++----- .../ContinuousDataSourceRDDIter.scala | 4 +-- .../sql/sources/v2/DataSourceV2Suite.scala | 20 ++++++++++++- .../sql/streaming/StreamingQuerySuite.scala | 4 +-- 6 files changed, 50 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index 5405a916951b8..607628746e873 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -23,6 +23,9 @@ /** * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report data partitioning and try to avoid shuffle at Spark side. + * + * Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid + * adding a shuffle even if the reader does not implement this interface. */ @InterfaceStability.Evolving public interface SupportsReportPartitioning extends DataSourceReader { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 5ed0ba71e94c7..f85971be394b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -29,11 +29,11 @@ class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: Da class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readerFactories: java.util.List[DataReaderFactory[T]]) + @transient private val readerFactories: Seq[DataReaderFactory[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerFactories.asScala.zipWithIndex.map { + readerFactories.zipWithIndex.map { case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index cb691ba297076..3a5e7bf89e142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -25,12 +25,14 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch /** * Physical plan node for scanning data from a data source. @@ -56,6 +58,15 @@ case class DataSourceV2ScanExec( } override def outputPartitioning: physical.Partitioning = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 => + SinglePartition + + case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 => + SinglePartition + + case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 => + SinglePartition + case s: SupportsReportPartitioning => new DataSourcePartitioning( s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) @@ -63,29 +74,33 @@ case class DataSourceV2ScanExec( case _ => super.outputPartitioning } - private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories() + private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala case _ => reader.createDataReaderFactories().asScala.map { new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow] - }.asJava + } } - private lazy val inputRDD: RDD[InternalRow] = reader match { + private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match { case r: SupportsScanColumnarBatch if r.enableBatchRead() => assert(!reader.isInstanceOf[ContinuousReader], "continuous stream reader does not support columnar read yet.") - new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories()) - .asInstanceOf[RDD[InternalRow]] + r.createBatchDataReaderFactories().asScala + } + private lazy val inputRDD: RDD[InternalRow] = reader match { case _: ContinuousReader => EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readerFactories.size())) + .askSync[Unit](SetReaderPartitions(readerFactories.size)) new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories) .asInstanceOf[RDD[InternalRow]] + case r: SupportsScanColumnarBatch if r.enableBatchRead() => + new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]] + case _ => new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index cf02c0dda25d7..06754f01657d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -35,14 +35,14 @@ import org.apache.spark.util.ThreadUtils class ContinuousDataSourceRDD( sc: SparkContext, sqlContext: SQLContext, - @transient private val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]]) + @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) extends RDD[UnsafeRow](sc, Nil) { private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs override protected def getPartitions: Array[Partition] = { - readerFactories.asScala.zipWithIndex.map { + readerFactories.zipWithIndex.map { case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 1157a350461d8..e0a53272cd222 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.functions._ import org.apache.spark.sql.sources.{Filter, GreaterThan} @@ -191,6 +191,11 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("SPARK-23574: no shuffle exchange with single partition") { + val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*")) + assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty) + } + test("simple writable data source") { // TODO: java implementation. Seq(classOf[SimpleWritableDataSource]).foreach { cls => @@ -336,6 +341,19 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } +class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceReader { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5)) + } + } + + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader +} + class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 3f9aa0d1fa5be..ebc9a87b23f84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -326,9 +326,9 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.durationMs.get("setOffsetRange") === 50) assert(progress.durationMs.get("getEndOffset") === 100) - assert(progress.durationMs.get("queryPlanning") === 0) + assert(progress.durationMs.get("queryPlanning") === 200) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("addBatch") === 350) + assert(progress.durationMs.get("addBatch") === 150) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) From 477d6bd7265e255fd43e53edda02019b32f29bb2 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Tue, 20 Mar 2018 13:27:50 -0700 Subject: [PATCH 497/774] [SPARK-23500][SQL] Fix complex type simplification rules to apply to entire plan ## What changes were proposed in this pull request? Complex type simplification optimizer rules were not applied to the entire plan, just the expressions reachable from the root node. This patch fixes the rules to transform the entire plan. ## How was this patch tested? New unit test + ran sql / core tests. Author: Henry Robinson Author: Henry Robinson Closes #20687 from henryr/spark-25000. --- .../sql/catalyst/optimizer/ComplexTypes.scala | 61 ++++++++----------- .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../optimizer/complexTypesSuite.scala | 55 +++++++++++++++-- 3 files changed, 76 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index be0009ec8c760..db7d6d3254bd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -18,39 +18,39 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule /** -* push down operations into [[CreateNamedStructLike]]. -*/ -object SimplifyCreateStructOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { - // push down field extraction + * Simplify redundant [[CreateNamedStructLike]], [[CreateArray]] and [[CreateMap]] expressions. + */ +object SimplifyExtractValueOps extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // One place where this optimization is invalid is an aggregation where the select + // list expression is a function of a grouping expression: + // + // SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b) + // + // cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this + // optimization for Aggregates (although this misses some cases where the optimization + // can be made). + case a: Aggregate => a + case p => p.transformExpressionsUp { + // Remove redundant field extraction. case GetStructField(createNamedStructLike: CreateNamedStructLike, ordinal, _) => createNamedStructLike.valExprs(ordinal) - } - } -} -/** -* push down operations into [[CreateArray]]. -*/ -object SimplifyCreateArrayOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { - // push down field selection (array of structs) - case GetArrayStructFields(CreateArray(elems), field, ordinal, numFields, containsNull) => - // instead f selecting the field on the entire array, - // select it from each member of the array. - // pushing down the operation this way open other optimizations opportunities - // (i.e. struct(...,x,...).x) + // Remove redundant array indexing. + case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) => + // Instead of selecting the field on the entire array, select it from each member + // of the array. Pushing down the operation this way may open other optimizations + // opportunities (i.e. struct(...,x,...).x) CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name)))) - // push down item selection. + + // Remove redundant map lookup. case ga @ GetArrayItem(CreateArray(elems), IntegerLiteral(idx)) => - // instead of creating the array and then selecting one row, - // remove array creation altgether. + // Instead of creating the array and then selecting one row, remove array creation + // altogether. if (idx >= 0 && idx < elems.size) { // valid index elems(idx) @@ -58,18 +58,7 @@ object SimplifyCreateArrayOps extends Rule[LogicalPlan] { // out of bounds, mimic the runtime behavior and return null Literal(null, ga.dataType) } - } - } -} - -/** -* push down operations into [[CreateMap]]. -*/ -object SimplifyCreateMapOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { case GetMapValue(CreateMap(elems), key) => CaseKeyWhen(key, elems) } } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 91208479be03b..2829d1d81eb1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -85,9 +85,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) EliminateSerialization, RemoveRedundantAliases, RemoveRedundantProject, - SimplifyCreateStructOps, - SimplifyCreateArrayOps, - SimplifyCreateMapOps, + SimplifyExtractValueOps, CombineConcats) ++ extendedOperatorOptimizationRules diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index de544ac314789..e44a6692ad8e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -44,14 +44,13 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { BooleanSimplification, SimplifyConditionals, SimplifyBinaryComparison, - SimplifyCreateStructOps, - SimplifyCreateArrayOps, - SimplifyCreateMapOps) :: Nil + SimplifyExtractValueOps) :: Nil } val idAtt = ('id).long.notNull + val nullableIdAtt = ('nullable_id).long - lazy val relation = LocalRelation(idAtt ) + lazy val relation = LocalRelation(idAtt, nullableIdAtt) test("explicit get from namedStruct") { val query = relation @@ -321,7 +320,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( CaseWhen(Seq( (EqualTo(2L, 'id), ('id + 1L)), - // these two are possible matches, we can't tell untill runtime + // these two are possible matches, we can't tell until runtime (EqualTo(2L, ('id + 1L)), ('id + 2L)), (EqualTo(2L, 'id + 2L), Literal.create(null, LongType)), // this is a definite match (two constants), @@ -331,4 +330,50 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .analyze comparePlans(Optimizer execute rel, expected) } + + test("SPARK-23500: Simplify complex ops that aren't at the plan root") { + val structRel = relation + .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo") + .groupBy($"foo")("1").analyze + val structExpected = relation + .select('nullable_id as "foo") + .groupBy($"foo")("1").analyze + comparePlans(Optimizer execute structRel, structExpected) + + // These tests must use nullable attributes from the base relation for the following reason: + // in the 'original' plans below, the Aggregate node produced by groupBy() has a + // nullable AttributeReference to a1, because both array indexing and map lookup are + // nullable expressions. After optimization, the same attribute is now non-nullable, + // but the AttributeReference is not updated to reflect this. In the 'expected' plans, + // the grouping expressions have the same nullability as the original attribute in the + // relation. If that attribute is non-nullable, the tests will fail as the plans will + // compare differently, so for these tests we must use a nullable attribute. See + // SPARK-23634. + val arrayRel = relation + .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") + .groupBy($"a1")("1").analyze + val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze + comparePlans(Optimizer execute arrayRel, arrayExpected) + + val mapRel = relation + .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1") + .groupBy($"m1")("1").analyze + val mapExpected = relation + .select('nullable_id as "m1") + .groupBy($"m1")("1").analyze + comparePlans(Optimizer execute mapRel, mapExpected) + } + + test("SPARK-23500: Ensure that aggregation expressions are not simplified") { + // Make sure that aggregation exprs are correctly ignored. Maps can't be used in + // grouping exprs so aren't tested here. + val structAggRel = relation.groupBy( + CreateNamedStruct(Seq("att1", 'nullable_id)))( + GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze + comparePlans(Optimizer execute structAggRel, structAggRel) + + val arrayAggRel = relation.groupBy( + CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze + comparePlans(Optimizer execute arrayAggRel, arrayAggRel) + } } From 983e8d9d64b6b1304c43ea6e5dffdc1078138ef9 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 20 Mar 2018 23:17:49 -0700 Subject: [PATCH 498/774] [SPARK-23666][SQL] Do not display exprIds of Alias in user-facing info. ## What changes were proposed in this pull request? To drop `exprId`s for `Alias` in user-facing info., this pr added an entry for `Alias` in `NonSQLExpression.sql` ## How was this patch tested? Added tests in `UDFSuite`. Author: Takeshi Yamamuro Closes #20827 from maropu/SPARK-23666. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/expressions/Expression.scala | 1 + .../scala/org/apache/spark/sql/UDFSuite.scala | 132 ++++++++++-------- 3 files changed, 78 insertions(+), 56 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 0e092e0e37ccf..5b47fd77f2cbc 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1806,6 +1806,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. + - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d7f9e38915dd5..38caf67d465d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -288,6 +288,7 @@ trait NonSQLExpression extends Expression { final override def sql: String = { transform { case a: Attribute => new PrettyAttribute(a) + case a: Alias => PrettyAttribute(a.sql, a.dataType) }.toString } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index af6a10b425b9f..21afdc7e2a33f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -144,73 +144,81 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a WHERE") { - spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + withTempView("integerData") { + spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) - val df = sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDF() - df.createOrReplaceTempView("integerData") + val df = sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.createOrReplaceTempView("integerData") - val result = - sql("SELECT * FROM integerData WHERE oneArgFilter(key)") - assert(result.count() === 20) + val result = + sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + assert(result.count() === 20) + } } test("UDF in a HAVING") { - spark.udf.register("havingFilter", (n: Long) => { n > 5 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT g, SUM(v) as s - | FROM groupData - | GROUP BY g - | HAVING havingFilter(s) - """.stripMargin) - - assert(result.count() === 2) + withTempView("groupData") { + spark.udf.register("havingFilter", (n: Long) => { n > 5 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT g, SUM(v) as s + | FROM groupData + | GROUP BY g + | HAVING havingFilter(s) + """.stripMargin) + + assert(result.count() === 2) + } } test("UDF in a GROUP BY") { - spark.udf.register("groupFunction", (n: Int) => { n > 10 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT SUM(v) - | FROM groupData - | GROUP BY groupFunction(v) - """.stripMargin) - assert(result.count() === 2) + withTempView("groupData") { + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT SUM(v) + | FROM groupData + | GROUP BY groupFunction(v) + """.stripMargin) + assert(result.count() === 2) + } } test("UDFs everywhere") { - spark.udf.register("groupFunction", (n: Int) => { n > 10 }) - spark.udf.register("havingFilter", (n: Long) => { n > 2000 }) - spark.udf.register("whereFilter", (n: Int) => { n < 150 }) - spark.udf.register("timesHundred", (n: Long) => { n * 100 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT timesHundred(SUM(v)) as v100 - | FROM groupData - | WHERE whereFilter(v) - | GROUP BY groupFunction(v) - | HAVING havingFilter(v100) - """.stripMargin) - assert(result.count() === 1) + withTempView("groupData") { + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) + spark.udf.register("havingFilter", (n: Long) => { n > 2000 }) + spark.udf.register("whereFilter", (n: Int) => { n < 150 }) + spark.udf.register("timesHundred", (n: Long) => { n * 100 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT timesHundred(SUM(v)) as v100 + | FROM groupData + | WHERE whereFilter(v) + | GROUP BY groupFunction(v) + | HAVING havingFilter(v100) + """.stripMargin) + assert(result.count() === 1) + } } test("struct UDF") { @@ -304,4 +312,16 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(explainStr(spark.range(1).select(udf1(udf2(functions.lit(1))))) .contains(s"UDF:$udf1Name(UDF:$udf2Name(1))")) } + + test("SPARK-23666 Do not display exprId in argument names") { + withTempView("x") { + Seq(((1, 2), 3)).toDF("a", "b").createOrReplaceTempView("x") + spark.udf.register("f", (a: Int) => a) + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + spark.sql("SELECT f(a._1) FROM x").show + } + assert(outputStream.toString.contains("UDF:f(a._1 AS `_1`)")) + } + } } From 500b21c3d6247015e550be7e144e9b4b26fe28be Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 21 Mar 2018 10:19:02 -0500 Subject: [PATCH 499/774] [SPARK-23568][ML] Use metadata numAttributes if available in Silhouette ## What changes were proposed in this pull request? Silhouette need to know the number of features. This was taken using `first` and checking the size of the vector. Despite this works fine, if the number of attributes is present in metadata, we can avoid to trigger a job for this and use the metadata value. This can help improving performances of course. ## How was this patch tested? existing UTs + added UT Author: Marco Gaido Closes #20719 from mgaido91/SPARK-23568. --- .../ml/evaluation/ClusteringEvaluator.scala | 22 ++++++++++++++---- .../evaluation/ClusteringEvaluatorSuite.scala | 23 ++++++++++++++++++- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index 8d4ae562b3d2b..4353c46781e9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} @@ -170,6 +171,15 @@ private[evaluation] abstract class Silhouette { def overallScore(df: DataFrame, scoreColumn: Column): Double = { df.select(avg(scoreColumn)).collect()(0).getDouble(0) } + + protected def getNumberOfFeatures(dataFrame: DataFrame, columnName: String): Int = { + val group = AttributeGroup.fromStructField(dataFrame.schema(columnName)) + if (group.size < 0) { + dataFrame.select(col(columnName)).first().getAs[Vector](0).size + } else { + group.size + } + } } /** @@ -360,7 +370,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette { df: DataFrame, predictionCol: String, featuresCol: String): Map[Double, ClusterStats] = { - val numFeatures = df.select(col(featuresCol)).first().getAs[Vector](0).size + val numFeatures = getNumberOfFeatures(df, featuresCol) val clustersStatsRDD = df.select( col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm")) .rdd @@ -552,8 +562,11 @@ private[evaluation] object CosineSilhouette extends Silhouette { * @return A [[scala.collection.immutable.Map]] which associates each cluster id to a * its statistics (ie. the precomputed values `N` and `$\Omega_{\Gamma}$`). */ - def computeClusterStats(df: DataFrame, predictionCol: String): Map[Double, (Vector, Long)] = { - val numFeatures = df.select(col(normalizedFeaturesColName)).first().getAs[Vector](0).size + def computeClusterStats( + df: DataFrame, + featuresCol: String, + predictionCol: String): Map[Double, (Vector, Long)] = { + val numFeatures = getNumberOfFeatures(df, featuresCol) val clustersStatsRDD = df.select( col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName)) .rdd @@ -626,7 +639,8 @@ private[evaluation] object CosineSilhouette extends Silhouette { normalizeFeatureUDF(col(featuresCol))) // compute aggregate values for clusters needed by the algorithm - val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, predictionCol) + val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, featuresCol, + predictionCol) // Silhouette is reasonable only when the number of clusters is greater then 1 assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 3bf34770f5687..2c175ff68e0b8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.TestingUtils._ @@ -100,4 +102,23 @@ class ClusteringEvaluatorSuite } } + test("SPARK-23568: we should use metadata to determine features number") { + val attributesNum = irisDataset.select("features").rdd.first().getAs[Vector](0).size + val attrGroup = new AttributeGroup("features", attributesNum) + val df = irisDataset.select($"features".as("features", attrGroup.toMetadata()), $"label") + require(AttributeGroup.fromStructField(df.schema("features")) + .numAttributes.isDefined, "numAttributes metadata should be defined") + val evaluator = new ClusteringEvaluator() + .setFeaturesCol("features") + .setPredictionCol("label") + + // with the proper metadata we compute correctly the result + assert(evaluator.evaluate(df) ~== 0.6564679231 relTol 1e-5) + + val wrongAttrGroup = new AttributeGroup("features", attributesNum + 1) + val dfWrong = irisDataset.select($"features".as("features", wrongAttrGroup.toMetadata()), + $"label") + // with wrong metadata the evaluator throws an Exception + intercept[SparkException](evaluator.evaluate(dfWrong)) + } } From bf09f2f71276d3b3a84a8f89109bd785a066c3e6 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 21 Mar 2018 09:39:14 -0700 Subject: [PATCH 500/774] [SPARK-10884][ML] Support prediction on single instance for regression and classification related models ## What changes were proposed in this pull request? Support prediction on single instance for regression and classification related models (i.e., PredictionModel, ClassificationModel and their sub classes). Add corresponding test cases. ## How was this patch tested? Test cases added. Author: WeichenXu Closes #19381 from WeichenXu123/single_prediction. --- .../scala/org/apache/spark/ml/Predictor.scala | 5 ++-- .../spark/ml/classification/Classifier.scala | 6 ++--- .../DecisionTreeClassifier.scala | 2 +- .../ml/classification/GBTClassifier.scala | 2 +- .../spark/ml/classification/LinearSVC.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../MultilayerPerceptronClassifier.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 2 +- .../GeneralizedLinearRegression.scala | 2 +- .../ml/regression/LinearRegression.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../DecisionTreeClassifierSuite.scala | 17 ++++++++++++- .../classification/GBTClassifierSuite.scala | 9 +++++++ .../ml/classification/LinearSVCSuite.scala | 6 +++++ .../LogisticRegressionSuite.scala | 9 +++++++ .../MultilayerPerceptronClassifierSuite.scala | 12 ++++++++++ .../ml/classification/NaiveBayesSuite.scala | 22 +++++++++++++++++ .../RandomForestClassifierSuite.scala | 16 +++++++++++++ .../DecisionTreeRegressorSuite.scala | 15 ++++++++++++ .../ml/regression/GBTRegressorSuite.scala | 8 +++++++ .../GeneralizedLinearRegressionSuite.scala | 8 +++++++ .../ml/regression/LinearRegressionSuite.scala | 7 ++++++ .../RandomForestRegressorSuite.scala | 24 +++++++++++++++---- .../org/apache/spark/ml/util/MLTest.scala | 15 ++++++++++-- 25 files changed, 176 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 08b0cb9b8f6a5..d8f3dfa874439 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -219,7 +219,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, /** * Predict label for the given features. - * This internal method is used to implement `transform()` and output [[predictionCol]]. + * This method is used to implement `transform()` and output [[predictionCol]]. */ - protected def predict(features: FeaturesType): Double + @Since("2.4.0") + def predict(features: FeaturesType): Double } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 9d1d5aa1e0cff..7e5790ab70ee9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, VectorUDT} @@ -192,12 +192,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur /** * Predict label for the given features. - * This internal method is used to implement `transform()` and output [[predictionCol]]. + * This method is used to implement `transform()` and output [[predictionCol]]. * * This default implementation for classification predicts the index of the maximum value * from `predictRaw()`. */ - override protected def predict(features: FeaturesType): Double = { + override def predict(features: FeaturesType): Double = { raw2prediction(predictRaw(features)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 9f60f0896ec52..65cce697d8202 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -181,7 +181,7 @@ class DecisionTreeClassificationModel private[ml] ( private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index f11bc1d8fe415..cd44489f618b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -267,7 +267,7 @@ class GBTClassificationModel private[ml]( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization if (isDefined(thresholds)) { super.predict(features) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index ce400f4f1faf7..8f950cd28c3aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -316,7 +316,7 @@ class LinearSVCModel private[classification] ( BLAS.dot(features, coefficients) + intercept } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { if (margin(features) > $(threshold)) 1.0 else 0.0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index fa191604218db..3ae4db3f3f965 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1090,7 +1090,7 @@ class LogisticRegressionModel private[spark] ( * Predict label for the given feature vector. * The behavior of this can be adjusted using `thresholds`. */ - override protected def predict(features: Vector): Double = if (isMultinomial) { + override def predict(features: Vector): Double = if (isMultinomial) { super.predict(features) } else { // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index fd4c98f22132f..af2e4699924e5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -322,7 +322,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( * Predict label for the given features. * This internal method is used to implement `transform()` and output [[predictionCol]]. */ - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { LabelConverter.decodeLabel(mlpModel.predict(features)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 0291a57487c47..ad154fcd010cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -178,7 +178,7 @@ class DecisionTreeRegressionModel private[ml] ( private[ml] def this(rootNode: Node, numFeatures: Int) = this(Identifiable.randomUID("dtr"), rootNode, numFeatures) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index f41d15b62dddd..6569ff2a5bfc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -230,7 +230,7 @@ class GBTRegressionModel private[ml]( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 917a4d238d467..9f1f2405c428e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1010,7 +1010,7 @@ class GeneralizedLinearRegressionModel private[ml] ( private lazy val familyAndLink = FamilyAndLink(this) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { predict(features, 0.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 6d3fe7a6c748c..92510154d500e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -699,7 +699,7 @@ class LinearRegressionModel private[ml] ( } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { dot(features, coefficients) + intercept } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 200b234b79978..2d594460c2475 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -199,7 +199,7 @@ class RandomForestRegressionModel private[ml] ( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index eeb0324187c5b..2930f4900d50e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -264,6 +264,21 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { Vector, DecisionTreeClassificationModel](this, newTree, newData) } + test("prediction on single instance") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + + val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + + testPredictionModelSinglePrediction(newTree, newData) + } + test("training with 1-category categorical feature") { val data = sc.parallelize(Seq( LabeledPoint(0, Vectors.dense(0, 2, 3)), diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 092b4a01d5b0d..57796069f6052 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -197,6 +197,15 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { Vector, GBTClassificationModel](this, gbtModel, validationDataset) } + test("prediction on single instance") { + + val gbt = new GBTClassifier().setSeed(123) + val trainingDataset = trainData.toDF("label", "features") + val gbtModel = gbt.fit(trainingDataset) + + testPredictionModelSinglePrediction(gbtModel, trainingDataset) + } + test("GBT parameter stepSize should be in interval (0, 1]") { withClue("GBT parameter stepSize should be in interval (0, 1]") { intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index a93825b8a812d..c05c896df5cb1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -201,6 +201,12 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest { dataset.as[LabeledPoint], estimator, modelEquals, 42L) } + test("prediction on single instance") { + val trainer = new LinearSVC() + val model = trainer.fit(smallBinaryDataset) + testPredictionModelSinglePrediction(model, smallBinaryDataset) + } + test("linearSVC comparison with R e1071 and scikit-learn") { val trainer1 = new LinearSVC() .setRegParam(0.00002) // set regParam = 2.0 / datasize / c diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 9987cbf6ba116..36b7e51f93d01 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -499,6 +499,15 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { Vector, LogisticRegressionModel](this, model, smallBinaryDataset) } + test("prediction on single instance") { + val blor = new LogisticRegression().setFamily("binomial") + val blorModel = blor.fit(smallBinaryDataset) + testPredictionModelSinglePrediction(blorModel, smallBinaryDataset) + val mlor = new LogisticRegression().setFamily("multinomial") + val mlorModel = mlor.fit(smallMultinomialDataset) + testPredictionModelSinglePrediction(mlorModel, smallMultinomialDataset) + } + test("coefficients and intercept methods") { val mlr = new LogisticRegression().setMaxIter(1).setFamily("multinomial") val mlrModel = mlr.fit(smallMultinomialDataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index daa58a56896d7..6b5fe6e49ffea 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -76,6 +76,18 @@ class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTe } } + test("prediction on single instance") { + val layers = Array[Int](2, 5, 2) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(123L) + .setMaxIter(100) + .setSolver("l-bfgs") + val model = trainer.fit(dataset) + testPredictionModelSinglePrediction(model, dataset) + } + test("Predicted class probabilities: calibration on toy dataset") { val layers = Array[Int](4, 5, 2) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 49115c8a4db30..5f9ab98a2c3ce 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -167,6 +167,28 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { Vector, NaiveBayesModel](this, model, testDataset) } + test("prediction on single instance") { + val nPoints = 1000 + val piArray = Array(0.5, 0.1, 0.4).map(math.log) + val thetaArray = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + val pi = Vectors.dense(piArray) + val theta = new DenseMatrix(3, 4, thetaArray.flatten, true) + + val trainDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, seed, "multinomial").toDF() + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") + val model = nb.fit(trainDataset) + + val validationDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() + + testPredictionModelSinglePrediction(model, validationDataset) + } + test("Naive Bayes with weighted samples") { val numClasses = 3 def modelEquals(m1: NaiveBayesModel, m2: NaiveBayesModel): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 02a9d5c2a18c0..ba4a9cf082785 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -155,6 +155,22 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { Vector, RandomForestClassificationModel](this, model, df) } + test("prediction on single instance") { + val rdd = orderedLabeledPoints5_20 + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setSeed(123) + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val model = rf.fit(df) + + testPredictionModelSinglePrediction(model, df) + } + test("Fitting without numClasses in metadata") { val df: DataFrame = TreeTests.featureImportanceData(sc).toDF() val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 68a1218c23ece..29a438396516b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -136,6 +136,21 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { assert(importances.toArray.forall(_ >= 0.0)) } + test("prediction on single instance") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(3) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) + + val model = dt.fit(df) + testPredictionModelSinglePrediction(model, df) + } + test("should support all NumericType labels and not support other types") { val dt = new DecisionTreeRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor]( diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 11c593b521e65..fad11d078250f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -99,6 +99,14 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { } } + test("prediction on single instance") { + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(2) + val model = gbt.fit(trainData.toDF()) + testPredictionModelSinglePrediction(model, validationData.toDF) + } + test("Checkpointing") { val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index ef2ff94a5e213..d5bcbb221783e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -211,6 +211,14 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest assert(model.getLink === "identity") } + test("prediction on single instance") { + val glr = new GeneralizedLinearRegression + val model = glr.setFamily("gaussian").setLink("identity") + .fit(datasetGaussianIdentity) + + testPredictionModelSinglePrediction(model, datasetGaussianIdentity) + } + test("generalized linear regression: gaussian family against glm") { /* R code: diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index d42cb1714478f..9b19f63eba1bd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -636,6 +636,13 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { } } + test("prediction on single instance") { + val trainer = new LinearRegression + val model = trainer.fit(datasetWithDenseFeature) + + testPredictionModelSinglePrediction(model, datasetWithDenseFeature) + } + test("linear regression model with constant label") { /* R code: diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 8b8e8a655f47b..e83c49f932973 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,22 +19,22 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestRegressor]]. */ -class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest{ +class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{ import RandomForestRegressorSuite.compareAPIs + import testImplicits._ private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ @@ -74,6 +74,20 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex regressionTestWithContinuousFeatures(rf) } + test("prediction on single instance") { + val rf = new RandomForestRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(10) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + + val df = orderedLabeledPoints50_1000.toDF() + val model = rf.fit(df) + testPredictionModelSinglePrediction(model, df) + } + test("Feature importance with toy data") { val rf = new RandomForestRegressor() .setImpurity("variance") diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 795fd0e2ac0e4..76d41f9b23715 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -22,8 +22,9 @@ import java.io.File import org.scalatest.Suite import org.apache.spark.SparkContext -import org.apache.spark.ml.Transformer -import org.apache.spark.sql.{DataFrame, Encoder, Row} +import org.apache.spark.ml.{PredictionModel, Transformer} +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.col import org.apache.spark.sql.streaming.StreamTest @@ -136,4 +137,14 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => assert(hasExpectedMessage(exceptionOnStreamData)) } } + + def testPredictionModelSinglePrediction(model: PredictionModel[Vector, _], + dataset: Dataset[_]): Unit = { + + model.transform(dataset).select(model.getFeaturesCol, model.getPredictionCol) + .collect().foreach { + case Row(features: Vector, prediction: Double) => + assert(prediction === model.predict(features)) + } + } } From 8d79113b812a91073d2c24a3a9ad94cc3b90b24a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 21 Mar 2018 09:46:47 -0700 Subject: [PATCH 501/774] [SPARK-23577][SQL] Supports custom line separator for text datasource ## What changes were proposed in this pull request? This PR proposes to add `lineSep` option for a configurable line separator in text datasource. It supports this option by using `LineRecordReader`'s functionality with passing it to the constructor. ## How was this patch tested? Manual tests and unit tests were added. Author: hyukjinkwon Closes #20727 from HyukjinKwon/linesep-text. --- python/pyspark/sql/readwriter.py | 14 ++++--- python/pyspark/sql/streaming.py | 8 +++- python/pyspark/sql/tests.py | 24 ++++++++++- .../apache/spark/sql/DataFrameReader.scala | 30 ++++++++------ .../apache/spark/sql/DataFrameWriter.scala | 2 + .../datasources/HadoopFileLinesReader.scala | 23 ++++++++++- .../datasources/text/TextFileFormat.scala | 16 ++++---- .../datasources/text/TextOptions.scala | 12 ++++++ .../sql/streaming/DataStreamReader.scala | 12 +++++- .../datasources/text/TextSuite.scala | 40 +++++++++++++++++++ 10 files changed, 147 insertions(+), 34 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index facc16bc53108..e5288636c596e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -304,16 +304,18 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) - def text(self, paths, wholetext=False): + def text(self, paths, wholetext=False, lineSep=None): """ Loads text files and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there are any. - Each line in the text file is a new row in the resulting DataFrame. + By default, each line in the text file is a new row in the resulting DataFrame. :param paths: string, or list of strings, for input path(s). :param wholetext: if true, read each file from input path(s) as a single row. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> df = spark.read.text('python/test_support/sql/text-test.txt') >>> df.collect() @@ -322,7 +324,7 @@ def text(self, paths, wholetext=False): >>> df.collect() [Row(value=u'hello\\nthis')] """ - self._set_opts(wholetext=wholetext) + self._set_opts(wholetext=wholetext, lineSep=lineSep) if isinstance(paths, basestring): paths = [paths] return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) @@ -804,18 +806,20 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None): self._jwrite.parquet(path) @since(1.6) - def text(self, path, compression=None): + def text(self, path, compression=None, lineSep=None): """Saves the content of the DataFrame in a text file at the specified path. :param path: the path in any Hadoop supported file system :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. The DataFrame must have only one column that is of string type. Each row becomes a new line in the output file. """ - self._set_opts(compression=compression) + self._set_opts(compression=compression, lineSep=lineSep) self._jwrite.text(path) @since(2.0) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index e8966c20a8f42..07f9ac1b5aa9e 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -531,17 +531,20 @@ def parquet(self, path): @ignore_unicode_prefix @since(2.0) - def text(self, path): + def text(self, path, wholetext=False, lineSep=None): """ Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there are any. - Each line in the text file is a new row in the resulting DataFrame. + By default, each line in the text file is a new row in the resulting DataFrame. .. note:: Evolving. :param paths: string, or list of strings, for input path(s). + :param wholetext: if true, read each file from input path(s) as a single row. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> text_sdf = spark.readStream.text(tempfile.mkdtemp()) >>> text_sdf.isStreaming @@ -549,6 +552,7 @@ def text(self, path): >>> "value" in str(text_sdf.schema) True """ + self._set_opts(wholetext=wholetext, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.text(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 39d6c5226f138..967cc83166f3f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -648,7 +648,29 @@ def test_non_existed_udaf(self): self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) - def test_multiLine_json(self): + def test_linesep_text(self): + df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",") + expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'), + Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'), + Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'), + Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')] + self.assertEqual(df.collect(), expected) + + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + df.write.text(tpath, lineSep="!") + expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'), + Row(value=u'Tom!30!"My name is Tom"'), + Row(value=u'Hyukjin!25!"I am Hyukjin'), + Row(value=u''), Row(value=u'I love Spark!"'), + Row(value=u'!')] + readback = self.spark.read.text(tpath) + self.assertEqual(readback.collect(), expected) + finally: + shutil.rmtree(tpath) + + def test_multiline_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", multiLine=True) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 0139913aaa4e2..1a5e47508c070 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -647,14 +647,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * Loads text files and returns a `DataFrame` whose schema starts with a string column named * "value", and followed by partitioned columns if there are any. * - * You can set the following text-specific option(s) for reading text files: - *
      - *
    • `wholetext` ( default `false`): If true, read a file as a single row and not split by "\n". - *
    • - *
    - * By default, each line in the text files is a new row in the resulting DataFrame. - * - * Usage example: + * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: * spark.read.text("/path/to/spark/README.md") @@ -663,6 +656,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * spark.read().text("/path/to/spark/README.md") * }}} * + * You can set the following text-specific option(s) for reading text files: + *
      + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • + *
    + * * @param paths input paths * @since 1.6.0 */ @@ -686,11 +687,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * If the directory structure of the text files contains partitioning information, those are * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. * - * You can set the following textFile-specific option(s) for reading text files: - *
      - *
    • `wholetext` ( default `false`): If true, read a file as a single row and not split by "\n". - *
    • - *
    * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: @@ -700,6 +696,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * spark.read().textFile("/path/to/spark/README.md") * }}} * + * You can set the following textFile-specific option(s) for reading text files: + *
      + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • + *
    + * * @param paths input path * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index ed7a9100cc7f1..bb93889dc55e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -587,6 +587,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
  • + *
  • `lineSep` (default `\n`): defines the line separator that should + * be used for writing.
  • * * * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala index 83cf26c63a175..00a78f7343c59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala @@ -30,9 +30,22 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl /** * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines * in that file. + * + * @param file A part (i.e. "block") of a single file that should be read line by line. + * @param lineSeparator A line separator that should be used for each line. If the value is `None`, + * it covers `\r`, `\r\n` and `\n`. + * @param conf Hadoop configuration + * + * @note The behavior when `lineSeparator` is `None` (covering `\r`, `\r\n` and `\n`) is defined + * by [[LineRecordReader]], not within Spark. */ class HadoopFileLinesReader( - file: PartitionedFile, conf: Configuration) extends Iterator[Text] with Closeable { + file: PartitionedFile, + lineSeparator: Option[Array[Byte]], + conf: Configuration) extends Iterator[Text] with Closeable { + + def this(file: PartitionedFile, conf: Configuration) = this(file, None, conf) + private val iterator = { val fileSplit = new FileSplit( new Path(new URI(file.filePath)), @@ -42,7 +55,13 @@ class HadoopFileLinesReader( Array.empty) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - val reader = new LineRecordReader() + + val reader = lineSeparator match { + case Some(sep) => new LineRecordReader(sep) + // If the line separator is `None`, it covers `\r`, `\r\n` and `\n`. + case _ => new LineRecordReader() + } + reader.initialize(fileSplit, hadoopAttemptContext) new RecordReaderIterator(reader) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index c661e9bd3b94c..9647f09867643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.execution.datasources.text -import java.io.Closeable - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.TaskContext @@ -89,7 +86,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new TextOutputWriter(path, dataSchema, context) + new TextOutputWriter(path, dataSchema, textOptions.lineSeparatorInWrite, context) } override def getFileExtension(context: TaskAttemptContext): String = { @@ -113,18 +110,18 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions.wholeText) + readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions) } private def readToUnsafeMem( conf: Broadcast[SerializableConfiguration], requiredSchema: StructType, - wholeTextMode: Boolean): (PartitionedFile) => Iterator[UnsafeRow] = { + textOptions: TextOptions): (PartitionedFile) => Iterator[UnsafeRow] = { (file: PartitionedFile) => { val confValue = conf.value.value - val reader = if (!wholeTextMode) { - new HadoopFileLinesReader(file, confValue) + val reader = if (!textOptions.wholeText) { + new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue) } else { new HadoopFileWholeTextReader(file, confValue) } @@ -152,6 +149,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { class TextOutputWriter( path: String, dataSchema: StructType, + lineSeparator: Array[Byte], context: TaskAttemptContext) extends OutputWriter { @@ -162,7 +160,7 @@ class TextOutputWriter( val utf8string = row.getUTF8String(0) utf8string.writeTo(writer) } - writer.write('\n') + writer.write(lineSeparator) } override def close(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 2a661561ab51e..18698df9fd8e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.text +import java.nio.charset.StandardCharsets + import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} /** @@ -39,9 +41,19 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti */ val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean + private val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { sep => + require(sep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") + sep + } + // Note that the option 'lineSep' uses a different default value in read and write. + val lineSeparatorInRead: Option[Array[Byte]] = + lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) + val lineSeparatorInWrite: Array[Byte] = + lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } private[text] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" + val LINE_SEPARATOR = "lineSep" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index c393dcdfdd7e5..9b17406a816b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -387,7 +387,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * Loads text files and returns a `DataFrame` whose schema starts with a string column named * "value", and followed by partitioned columns if there are any. * - * Each line in the text files is a new row in the resulting DataFrame. For example: + * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: * spark.readStream.text("/path/to/directory/") @@ -400,6 +400,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
      *
    • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
    • + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • *
    * * @since 2.0.0 @@ -413,7 +417,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * If the directory structure of the text files contains partitioning information, those are * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. * - * Each line in the text file is a new element in the resulting Dataset. For example: + * By default, each line in the text file is a new element in the resulting Dataset. For example: * {{{ * // Scala: * spark.readStream.textFile("/path/to/spark/README.md") @@ -426,6 +430,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
      *
    • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
    • + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • *
    * * @param path input path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 33287044f279e..e8a5299d6ba9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -18,10 +18,13 @@ package org.apache.spark.sql.execution.datasources.text import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.Files import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.spark.TestUtils import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -172,6 +175,43 @@ class TextSuite extends QueryTest with SharedSQLContext { } } + def testLineSeparator(lineSep: String): Unit = { + test(s"SPARK-23577: Support line separator - lineSep: '$lineSep'") { + // Read + val values = Seq("a", "b", "\nc") + val data = values.mkString(lineSep) + val dataWithTrailingLineSep = s"$data$lineSep" + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("lineSep", lineSep).text(path.getAbsolutePath) + checkAnswer(df, Seq("a", "b", "\nc").toDF()) + } + } + + // Write + withTempPath { path => + values.toDF().coalesce(1) + .write.option("lineSep", lineSep).text(path.getAbsolutePath) + val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8) + assert(readBack === s"a${lineSep}b${lineSep}\nc${lineSep}") + } + + // Roundtrip + withTempPath { path => + val df = values.toDF() + df.write.option("lineSep", lineSep).text(path.getAbsolutePath) + val readBack = spark.read.option("lineSep", lineSep).text(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString).foreach { lineSep => + testLineSeparator(lineSep) + } + private def testFile: String = { Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString } From 98d0ea3f6091730285293321a50148f69e94c9cd Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 21 Mar 2018 09:52:28 -0700 Subject: [PATCH 502/774] [SPARK-23264][SQL] Fix scala.MatchError in literals.sql.out ## What changes were proposed in this pull request? To fix `scala.MatchError` in `literals.sql.out`, this pr added an entry for `CalendarIntervalType` in `QueryExecution.toHiveStructString`. ## How was this patch tested? Existing tests and added tests in `literals.sql` Author: Takeshi Yamamuro Closes #20872 from maropu/FixIntervalTests. --- .../spark/sql/execution/QueryExecution.scala | 2 ++ .../resources/sql-tests/inputs/literals.sql | 3 +++ .../sql-tests/results/literals.sql.out | 20 ++++++++++++------- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 7cae24bf5976c..15379a0663f7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -155,6 +155,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { case (null, _) => "null" case (s: String, StringType) => "\"" + s + "\"" case (decimal, DecimalType()) => decimal.toString + case (interval, CalendarIntervalType) => interval.toString case (other, tpe) if primitiveTypes contains tpe => other.toString } @@ -178,6 +179,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)) case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal) + case (interval, CalendarIntervalType) => interval.toString case (other, tpe) if primitiveTypes.contains(tpe) => other.toString } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql index 37b4b7606d12b..a743cf1ec2cde 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/literals.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql @@ -105,3 +105,6 @@ select X'XuZ'; -- Hive literal_double test. SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8; + +-- map + interval test +select map(1, interval 1 day, 2, interval 3 week); diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index 95d4413148f64..b8c91dc8b59a4 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 43 +-- Number of queries: 44 -- !query 0 @@ -323,19 +323,17 @@ select timestamp '2016-33-11 20:54:00.000' -- !query 34 select interval 13.123456789 seconds, interval -13.123456789 second -- !query 34 schema -struct<> +struct -- !query 34 output -scala.MatchError -(interval 13 seconds 123 milliseconds 456 microseconds,CalendarIntervalType) (of class scala.Tuple2) +interval 13 seconds 123 milliseconds 456 microseconds interval -12 seconds -876 milliseconds -544 microseconds -- !query 35 select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond, 9 microsecond -- !query 35 schema -struct<> +struct -- !query 35 output -scala.MatchError -(interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds,CalendarIntervalType) (of class scala.Tuple2) +interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds 9 -- !query 36 @@ -416,3 +414,11 @@ SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8 struct<3.14:decimal(3,2),-3.14:decimal(3,2),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10),-3.14E+8:decimal(3,-6),-3.14E-8:decimal(10,10),3.14E+8:decimal(3,-6),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10)> -- !query 42 output 3.14 -3.14 314000000 0.0000000314 -314000000 -0.0000000314 314000000 314000000 0.0000000314 + + +-- !query 43 +select map(1, interval 1 day, 2, interval 3 week) +-- !query 43 schema +struct> +-- !query 43 output +{1:interval 1 days,2:interval 3 weeks} From 918c7e99afdcea05c36626e230636c4f8aabf82c Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 21 Mar 2018 10:06:26 -0700 Subject: [PATCH 503/774] [SPARK-23288][SS] Fix output metrics with parquet sink ## What changes were proposed in this pull request? Output metrics were not filled when parquet sink used. This PR fixes this problem by passing a `BasicWriteJobStatsTracker` in `FileStreamSink`. ## How was this patch tested? Additional unit test added. Author: Gabor Somogyi Closes #20745 from gaborgsomogyi/SPARK-23288. --- .../command/DataWritingCommand.scala | 11 +--- .../datasources/BasicWriteStatsTracker.scala | 25 +++++++-- .../execution/streaming/FileStreamSink.scala | 10 +++- .../sql/streaming/FileStreamSinkSuite.scala | 52 +++++++++++++++++++ 4 files changed, 82 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index e56f8105fc9a7..e11dbd201004d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command import org.apache.hadoop.conf.Configuration -import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} @@ -45,15 +44,7 @@ trait DataWritingCommand extends Command { // Output columns of the analyzed input query plan def outputColumns: Seq[Attribute] - lazy val metrics: Map[String, SQLMetric] = { - val sparkContext = SparkContext.getActive.get - Map( - "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), - "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "bytes of written output"), - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numParts" -> SQLMetrics.createMetric(sparkContext, "number of dynamic part") - ) - } + lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics def basicWriteJobStatsTracker(hadoopConf: Configuration): BasicWriteJobStatsTracker = { val serializableHadoopConf = new SerializableConfiguration(hadoopConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index 9dbbe9946ee99..69c03d862391e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -153,12 +153,29 @@ class BasicWriteJobStatsTracker( totalNumOutput += summary.numRows } - metrics("numFiles").add(numFiles) - metrics("numOutputBytes").add(totalNumBytes) - metrics("numOutputRows").add(totalNumOutput) - metrics("numParts").add(numPartitions) + metrics(BasicWriteJobStatsTracker.NUM_FILES_KEY).add(numFiles) + metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_BYTES_KEY).add(totalNumBytes) + metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_ROWS_KEY).add(totalNumOutput) + metrics(BasicWriteJobStatsTracker.NUM_PARTS_KEY).add(numPartitions) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toList) } } + +object BasicWriteJobStatsTracker { + private val NUM_FILES_KEY = "numFiles" + private val NUM_OUTPUT_BYTES_KEY = "numOutputBytes" + private val NUM_OUTPUT_ROWS_KEY = "numOutputRows" + private val NUM_PARTS_KEY = "numParts" + + def metrics: Map[String, SQLMetric] = { + val sparkContext = SparkContext.getActive.get + Map( + NUM_FILES_KEY -> SQLMetrics.createMetric(sparkContext, "number of written files"), + NUM_OUTPUT_BYTES_KEY -> SQLMetrics.createMetric(sparkContext, "bytes of written output"), + NUM_OUTPUT_ROWS_KEY -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + NUM_PARTS_KEY -> SQLMetrics.createMetric(sparkContext, "number of dynamic part") + ) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 87a17cebdc10c..b3d12f67b5d63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -26,7 +26,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.{FileFormat, FileFormatWriter} +import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormat, FileFormatWriter} +import org.apache.spark.util.SerializableConfiguration object FileStreamSink extends Logging { // The name of the subdirectory that is used to store metadata about which files are valid. @@ -97,6 +98,11 @@ class FileStreamSink( new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString) private val hadoopConf = sparkSession.sessionState.newHadoopConf() + private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = { + val serializableHadoopConf = new SerializableConfiguration(hadoopConf) + new BasicWriteJobStatsTracker(serializableHadoopConf, BasicWriteJobStatsTracker.metrics) + } + override def addBatch(batchId: Long, data: DataFrame): Unit = { if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) { logInfo(s"Skipping already committed batch $batchId") @@ -131,7 +137,7 @@ class FileStreamSink( hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = None, - statsTrackers = Nil, + statsTrackers = Seq(basicWriteJobStatsTracker), options = options) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 31e5527d7366a..cf41d7e0e4fe1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.hadoop.fs.Path +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ @@ -405,4 +406,55 @@ class FileStreamSinkSuite extends StreamTest { } } } + + test("SPARK-23288 writing and checking output metrics") { + Seq("parquet", "orc", "text", "json").foreach { format => + val inputData = MemoryStream[String] + val df = inputData.toDF() + + withTempDir { outputDir => + withTempDir { checkpointDir => + + var query: StreamingQuery = null + + var numTasks = 0 + var recordsWritten: Long = 0L + var bytesWritten: Long = 0L + try { + spark.sparkContext.addSparkListener(new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val outputMetrics = taskEnd.taskMetrics.outputMetrics + recordsWritten += outputMetrics.recordsWritten + bytesWritten += outputMetrics.bytesWritten + numTasks += 1 + } + }) + + query = + df.writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format(format) + .start(outputDir.getCanonicalPath) + + inputData.addData("1", "2", "3") + inputData.addData("4", "5") + + failAfter(streamingTimeout) { + query.processAllAvailable() + } + spark.sparkContext.listenerBus.waitUntilEmpty(streamingTimeout.toMillis) + + assert(numTasks > 0) + assert(recordsWritten === 5) + // This is heavily file type/version specific but should be filled + assert(bytesWritten > 0) + } finally { + if (query != null) { + query.stop() + } + } + } + } + } + } } From 2b89e4aa2e8bd8b88f6e5eb60d95c1a58e5c4ace Mon Sep 17 00:00:00 2001 From: akonopko Date: Wed, 21 Mar 2018 14:40:21 -0500 Subject: [PATCH 504/774] [SPARK-18580][DSTREAM][KAFKA] Add spark.streaming.backpressure.initialRate to direct Kafka streams ## What changes were proposed in this pull request? Add `spark.streaming.backpressure.initialRate` to direct Kafka Streams for Kafka 0.8 and 0.10 This is required in order to be able to use backpressure with huge lags, which cannot be processed at once. Without this parameter `DirectKafkaInputDStream` with backpressure enabled would try to get all the possible data from Kafka before adjusting consumption rate ## How was this patch tested? - Tests added to `org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala` and `org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala` - Manual tests on YARN cluster Author: akonopko Author: Alexander Konopko Closes #19431 from akonopko/SPARK-18580-initialrate. --- .../kafka010/DirectKafkaInputDStream.scala | 8 ++- .../kafka010/DirectKafkaStreamSuite.scala | 51 +++++++++++++++- .../kafka/DirectKafkaInputDStream.scala | 9 ++- .../kafka/DirectKafkaStreamSuite.scala | 59 ++++++++++++++++++- 4 files changed, 120 insertions(+), 7 deletions(-) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 9cb2448fea0f4..215b7cab703fb 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -56,6 +56,9 @@ private[spark] class DirectKafkaInputDStream[K, V]( ppc: PerPartitionConfig ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets { + private val initialRate = context.sparkContext.getConf.getLong( + "spark.streaming.backpressure.initialRate", 0) + val executorKafkaParams = { val ekp = new ju.HashMap[String, Object](consumerStrategy.executorKafkaParams) KafkaUtils.fixKafkaParams(ekp) @@ -126,7 +129,10 @@ private[spark] class DirectKafkaInputDStream[K, V]( protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate()) + val estimatedRateLimit = rateController.map { x => { + val lr = x.getLatestRate() + if (lr > 0) lr else initialRate + }} // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 8524743ee2846..35e4678f2e3c8 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka010 import java.io.File import java.lang.{ Long => JLong } -import java.util.{ Arrays, HashMap => JHashMap, Map => JMap } +import java.util.{ Arrays, HashMap => JHashMap, Map => JMap, UUID } import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicLong @@ -34,7 +34,7 @@ import org.apache.kafka.common.serialization.StringDeserializer import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} @@ -617,6 +617,53 @@ class DirectKafkaStreamSuite ssc.stop() } + test("backpressure.initialRate should honor maxRatePerPartition") { + backpressureTest(maxRatePerPartition = 1000, initialRate = 500, maxMessagesPerPartition = 250) + } + + test("use backpressure.initialRate with backpressure") { + backpressureTest(maxRatePerPartition = 300, initialRate = 1000, maxMessagesPerPartition = 150) + } + + private def backpressureTest( + maxRatePerPartition: Int, + initialRate: Int, + maxMessagesPerPartition: Int) = { + + val topic = UUID.randomUUID().toString + val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.backpressure.enabled", "true") + .set("spark.streaming.backpressure.initialRate", initialRate.toString) + .set("spark.streaming.kafka.maxRatePerPartition", maxRatePerPartition.toString) + + val messages = Map("foo" -> 5000) + kafkaTestUtils.sendMessages(topic, messages) + + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + + val kafkaStream = withClue("Error creating direct stream") { + new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) + } + kafkaStream.start() + + val input = Map(new TopicPartition(topic, 0) -> 1000L) + + assert(kafkaStream.maxMessagesPerPartition(input).get == + Map(new TopicPartition(topic, 0) -> maxMessagesPerPartition)) // we run for half a second + + kafkaStream.stop() + } + test("maxMessagesPerPartition with zero offset and rate equal to one") { val topic = "backpressure" val kafkaParams = getKafkaParams() diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index d6dd0744441e4..9297c39d170c4 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -91,9 +91,16 @@ class DirectKafkaInputDStream[ private val maxRateLimitPerPartition: Long = context.sparkContext.getConf.getLong( "spark.streaming.kafka.maxRatePerPartition", 0) + private val initialRate = context.sparkContext.getConf.getLong( + "spark.streaming.backpressure.initialRate", 0) + protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate()) + + val estimatedRateLimit = rateController.map { x => { + val lr = x.getLatestRate() + if (lr > 0) lr else initialRate + }} // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 3fea6cfd910bf..ecca38784e777 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.kafka import java.io.File -import java.util.Arrays +import java.util.{ Arrays, UUID } import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicLong @@ -32,12 +32,11 @@ import kafka.serializer.StringDecoder import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.Utils @@ -456,6 +455,60 @@ class DirectKafkaStreamSuite ssc.stop() } + test("use backpressure.initialRate with backpressure") { + backpressureTest(maxRatePerPartition = 1000, initialRate = 500, maxMessagesPerPartition = 250) + } + + test("backpressure.initialRate should honor maxRatePerPartition") { + backpressureTest(maxRatePerPartition = 300, initialRate = 1000, maxMessagesPerPartition = 150) + } + + private def backpressureTest( + maxRatePerPartition: Int, + initialRate: Int, + maxMessagesPerPartition: Int) = { + + val topic = UUID.randomUUID().toString + val topicPartitions = Set(TopicAndPartition(topic, 0)) + kafkaTestUtils.createTopic(topic, 1) + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.backpressure.enabled", "true") + .set("spark.streaming.backpressure.initialRate", initialRate.toString) + .set("spark.streaming.kafka.maxRatePerPartition", maxRatePerPartition.toString) + + val messages = Map("foo" -> 5000) + kafkaTestUtils.sendMessages(topic, messages) + + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + + val kafkaStream = withClue("Error creating direct stream") { + val kc = new KafkaCluster(kafkaParams) + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val m = kc.getEarliestLeaderOffsets(topicPartitions) + .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset)) + + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, m, messageHandler) + } + kafkaStream.start() + + val input = Map(new TopicAndPartition(topic, 0) -> 1000L) + + assert(kafkaStream.maxMessagesPerPartition(input).get == + Map(new TopicAndPartition(topic, 0) -> maxMessagesPerPartition)) + + kafkaStream.stop() + } + test("maxMessagesPerPartition with zero offset and rate equal to one") { val topic = "backpressure" val kafkaParams = Map( From a091ee676b8707819e94d92693956237310a6145 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 21 Mar 2018 13:52:03 -0700 Subject: [PATCH 505/774] [MINOR] Fix Java lint from new JavaKolmogorovSmirnovTestSuite ## What changes were proposed in this pull request? Fix lint-java from https://github.com/apache/spark/pull/19108 addition of JavaKolmogorovSmirnovTestSuite Author: Joseph K. Bradley Closes #20875 from jkbradley/kstest-lint-fix. --- .../spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java index 021272dd5a40c..830f668fe07b8 100644 --- a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java @@ -18,18 +18,11 @@ package org.apache.spark.ml.stat; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.apache.commons.math3.distribution.NormalDistribution; -import org.apache.spark.ml.linalg.VectorUDT; -import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; -import org.apache.spark.sql.types.DoubleType; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; import org.junit.Test; import org.apache.spark.SharedSparkSession; From 0604beaff2baa2d0fed86c0c87fd2a16a1838b5f Mon Sep 17 00:00:00 2001 From: Mihaly Toth Date: Wed, 21 Mar 2018 17:05:39 -0700 Subject: [PATCH 506/774] [SPARK-23729][CORE] Respect URI fragment when resolving globs Firstly, glob resolution will not result in swallowing the remote name part (that is preceded by the `#` sign) in case of `--files` or `--archives` options Moreover in the special case of multiple resolutions when the remote naming does not make sense and error is returned. Enhanced current test and wrote additional test for the error case Author: Mihaly Toth Closes #20853 from misutoth/glob-with-remote-name. --- .../apache/spark/deploy/DependencyUtils.scala | 34 +++++++++++---- .../org/apache/spark/deploy/SparkSubmit.scala | 13 ++++++ .../spark/deploy/SparkSubmitSuite.scala | 41 +++++++++++++++---- 3 files changed, 72 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index ecc82d7ac8001..ab319c860ee69 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -18,12 +18,13 @@ package org.apache.spark.deploy import java.io.File +import java.net.URI import org.apache.commons.lang3.StringUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.util.{MutableURLClassLoader, Utils} private[deploy] object DependencyUtils { @@ -137,16 +138,31 @@ private[deploy] object DependencyUtils { def resolveGlobPaths(paths: String, hadoopConf: Configuration): String = { require(paths != null, "paths cannot be null.") Utils.stringToSeq(paths).flatMap { path => - val uri = Utils.resolveURI(path) - uri.getScheme match { - case "local" | "http" | "https" | "ftp" => Array(path) - case _ => - val fs = FileSystem.get(uri, hadoopConf) - Option(fs.globStatus(new Path(uri))).map { status => - status.filter(_.isFile).map(_.getPath.toUri.toString) - }.getOrElse(Array(path)) + val (base, fragment) = splitOnFragment(path) + (resolveGlobPath(base, hadoopConf), fragment) match { + case (resolved, Some(_)) if resolved.length > 1 => throw new SparkException( + s"${base.toString} resolves ambiguously to multiple files: ${resolved.mkString(",")}") + case (resolved, Some(namedAs)) => resolved.map(_ + "#" + namedAs) + case (resolved, _) => resolved } }.mkString(",") } + private def splitOnFragment(path: String): (URI, Option[String]) = { + val uri = Utils.resolveURI(path) + val withoutFragment = new URI(uri.getScheme, uri.getSchemeSpecificPart, null) + (withoutFragment, Option(uri.getFragment)) + } + + private def resolveGlobPath(uri: URI, hadoopConf: Configuration): Array[String] = { + uri.getScheme match { + case "local" | "http" | "https" | "ftp" => Array(uri.toString) + case _ => + val fs = FileSystem.get(uri, hadoopConf) + Option(fs.globStatus(new Path(uri))).map { status => + status.filter(_.isFile).map(_.getPath.toUri.toString) + }.getOrElse(Array(uri.toString)) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 329bde08718fe..3965f17f4b56e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -245,6 +245,19 @@ object SparkSubmit extends CommandLineUtils with Logging { args: SparkSubmitArguments, conf: Option[HadoopConfiguration] = None) : (Seq[String], Seq[String], SparkConf, String) = { + try { + doPrepareSubmitEnvironment(args, conf) + } catch { + case e: SparkException => + printErrorAndExit(e.getMessage) + throw e + } + } + + private def doPrepareSubmitEnvironment( + args: SparkSubmitArguments, + conf: Option[HadoopConfiguration] = None) + : (Seq[String], Seq[String], SparkConf, String) = { // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d265643a80b4e..2d0c192db4915 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy import java.io._ import java.net.URI import java.nio.charset.StandardCharsets -import java.nio.file.Files +import java.nio.file.{Files, Paths} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -606,10 +606,13 @@ class SparkSubmitSuite } test("resolves command line argument paths correctly") { - val jars = "/jar1,/jar2" // --jars - val files = "local:/file1,file2" // --files - val archives = "file:/archive1,archive2" // --archives - val pyFiles = "py-file1,py-file2" // --py-files + val dir = Utils.createTempDir() + val archive = Paths.get(dir.toPath.toString, "single.zip") + Files.createFile(archive) + val jars = "/jar1,/jar2" + val files = "local:/file1,file2" + val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3" + val pyFiles = "py-file1,py-file2" // Test jars and files val clArgs = Seq( @@ -636,9 +639,10 @@ class SparkSubmitSuite val appArgs2 = new SparkSubmitArguments(clArgs2) val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) appArgs2.files should be (Utils.resolveURIs(files)) - appArgs2.archives should be (Utils.resolveURIs(archives)) + appArgs2.archives should fullyMatch regex ("file:/archive1,file:.*#archive3") conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) - conf2.get("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives)) + conf2.get("spark.yarn.dist.archives") should fullyMatch regex + ("file:/archive1,file:.*#archive3") // Test python files val clArgs3 = Seq( @@ -657,6 +661,29 @@ class SparkSubmitSuite conf3.get(PYSPARK_PYTHON.key) should be ("python3.5") } + test("ambiguous archive mapping results in error message") { + val dir = Utils.createTempDir() + val archive1 = Paths.get(dir.toPath.toString, "first.zip") + val archive2 = Paths.get(dir.toPath.toString, "second.zip") + Files.createFile(archive1) + Files.createFile(archive2) + val jars = "/jar1,/jar2" + val files = "local:/file1,file2" + val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3" + val pyFiles = "py-file1,py-file2" + + // Test files and archives (Yarn) + val clArgs2 = Seq( + "--master", "yarn", + "--class", "org.SomeClass", + "--files", files, + "--archives", archives, + "thejar.jar" + ) + + testPrematureExit(clArgs2.toArray, "resolves ambiguously to multiple files") + } + test("resolves config paths correctly") { val jars = "/jar1,/jar2" // spark.jars val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files From 95e51ff849a4c46cae463636b1ee393042469e7b Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Wed, 21 Mar 2018 21:21:36 -0700 Subject: [PATCH 507/774] [SPARK-23760][SQL] CodegenContext.withSubExprEliminationExprs should save/restore CSE state correctly ## What changes were proposed in this pull request? Fixed `CodegenContext.withSubExprEliminationExprs()` so that it saves/restores CSE state correctly. ## How was this patch tested? Added new unit test to verify that the old CSE state is indeed saved and restored around the `withSubExprEliminationExprs()` call. Manually verified that this test fails without this patch. Author: Kris Mok Closes #20870 from rednaxelafx/codegen-subexpr-fix. --- .../expressions/codegen/CodeGenerator.scala | 16 +++---- .../expressions/CodeGenerationSuite.scala | 44 +++++++++++++++++++ 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index fe5e63ec0a2bb..84b1e3fbda876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -402,7 +402,7 @@ class CodegenContext { val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions // Foreach expression that is participating in subexpression elimination, the state to use. - val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState] // The collection of sub-expression result resetting methods that need to be called on each row. val subexprFunctions = mutable.ArrayBuffer.empty[String] @@ -921,14 +921,12 @@ class CodegenContext { newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])( f: => Seq[ExprCode]): Seq[ExprCode] = { val oldsubExprEliminationExprs = subExprEliminationExprs - subExprEliminationExprs.clear - newSubExprEliminationExprs.foreach(subExprEliminationExprs += _) + subExprEliminationExprs = newSubExprEliminationExprs val genCodes = f // Restore previous subExprEliminationExprs - subExprEliminationExprs.clear - oldsubExprEliminationExprs.foreach(subExprEliminationExprs += _) + subExprEliminationExprs = oldsubExprEliminationExprs genCodes } @@ -942,7 +940,7 @@ class CodegenContext { def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { // Create a clear EquivalentExpressions and SubExprEliminationState mapping val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] // Add each expression tree and compute the common subexpressions. expressions.foreach(equivalentExpressions.addExprTree) @@ -955,10 +953,10 @@ class CodegenContext { // Generate the code for this expression tree. val eval = expr.genCode(this) val state = SubExprEliminationState(eval.isNull, eval.value) - e.foreach(subExprEliminationExprs.put(_, state)) + e.foreach(localSubExprEliminationExprs.put(_, state)) eval.code.trim } - SubExprCodes(codes, subExprEliminationExprs.toMap) + SubExprCodes(codes, localSubExprEliminationExprs.toMap) } /** @@ -1006,7 +1004,7 @@ class CodegenContext { subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) - e.foreach(subExprEliminationExprs.put(_, state)) + subExprEliminationExprs ++= e.map(_ -> state).toMap } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 64c13e8972036..398b6767654fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -442,4 +442,48 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(CodeGenerator.calculateParamLength( Seq.range(0, 100).map(x => Literal(x.toLong))) == 201) } + + test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") { + + val ref = BoundReference(0, IntegerType, true) + val add1 = Add(ref, ref) + val add2 = Add(add1, add1) + + // raw testing of basic functionality + { + val ctx = new CodegenContext + val e = ref.genCode(ctx) + // before + ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value) + assert(ctx.subExprEliminationExprs.contains(ref)) + // call withSubExprEliminationExprs + ctx.withSubExprEliminationExprs(Map(add1 -> SubExprEliminationState("dummy", "dummy"))) { + assert(ctx.subExprEliminationExprs.contains(add1)) + assert(!ctx.subExprEliminationExprs.contains(ref)) + Seq.empty + } + // after + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(ref)) + assert(!ctx.subExprEliminationExprs.contains(add1)) + } + + // emulate an actual codegen workload + { + val ctx = new CodegenContext + // before + ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE + assert(ctx.subExprEliminationExprs.contains(add1)) + // call withSubExprEliminationExprs + ctx.withSubExprEliminationExprs(Map(ref -> SubExprEliminationState("dummy", "dummy"))) { + assert(ctx.subExprEliminationExprs.contains(ref)) + assert(!ctx.subExprEliminationExprs.contains(add1)) + Seq.empty + } + // after + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(add1)) + assert(!ctx.subExprEliminationExprs.contains(ref)) + } + } } From 5c9eaa6b585e9febd782da8eb6490b24d0d39ff3 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 21 Mar 2018 21:49:02 -0700 Subject: [PATCH 508/774] [SPARK-23372][SQL] Writing empty struct in parquet fails during execution. It should fail earlier in the processing. ## What changes were proposed in this pull request? Currently we allow writing data frames with empty schema into a file based datasource for certain file formats such as JSON, ORC etc. For formats such as Parquet and Text, we raise error at different times of execution. For text format, we return error from the driver early on in processing where as for format such as parquet, the error is raised from executor. **Example** spark.emptyDataFrame.write.format("parquet").mode("overwrite").save(path) **Results in** ``` SQL org.apache.parquet.schema.InvalidSchemaException: Cannot write a schema with an empty group: message spark_schema { } at org.apache.parquet.schema.TypeUtil$1.visit(TypeUtil.java:27) at org.apache.parquet.schema.TypeUtil$1.visit(TypeUtil.java:37) at org.apache.parquet.schema.MessageType.accept(MessageType.java:58) at org.apache.parquet.schema.TypeUtil.checkValidWriteSchema(TypeUtil.java:23) at org.apache.parquet.hadoop.ParquetFileWriter.(ParquetFileWriter.java:225) at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:342) at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:302) at org.apache.spark.sql.execution.datasources.parquet.ParquetOutputWriter.(ParquetOutputWriter.scala:37) at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anon$1.newInstance(ParquetFileFormat.scala:151) at org.apache.spark.sql.execution.datasources.FileFormatWriter$SingleDirectoryWriteTask.newOutputWriter(FileFormatWriter.scala:376) at org.apache.spark.sql.execution.datasources.FileFormatWriter$SingleDirectoryWriteTask.execute(FileFormatWriter.scala:387) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask$3.apply(FileFormatWriter.scala:278) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask$3.apply(FileFormatWriter.scala:276) at org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Utils.scala:1411) at org.apache.spark.sql.execution.datasources.FileFormatWriter$.org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask(FileFormatWriter.scala:281) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:206) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:205) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87) at org.apache.spark.scheduler.Task.run(Task.scala:109) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread. ``` In this PR, we unify the error processing and raise error on attempt to write empty schema based dataframes into file based datasource (orc, parquet, text , csv, json etc) early on in the processing. ## How was this patch tested? Unit tests added in FileBasedDatasourceSuite. Author: Dilip Biswal Closes #20579 from dilipbiswal/spark-23372. --- docs/sql-programming-guide.md | 1 + .../execution/datasources/DataSource.scala | 26 ++++++++++++++++- .../spark/sql/FileBasedDataSourceSuite.scala | 28 +++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 5b47fd77f2cbc..421e2eaf62bfb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1807,6 +1807,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. + - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 35fcff69b14d8..31fa89b4570a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.sql.types.{CalendarIntervalType, StructField, StructType} import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.Utils @@ -546,6 +546,7 @@ case class DataSource( case dataSource: CreatableRelationProvider => SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => + DataSource.validateSchema(data.schema) planForWritingFileFormat(format, mode, data) case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") @@ -719,4 +720,27 @@ object DataSource extends Logging { } globPath } + + /** + * Called before writing into a FileFormat based data source to make sure the + * supplied schema is not empty. + * @param schema + */ + private def validateSchema(schema: StructType): Unit = { + def hasEmptySchema(schema: StructType): Boolean = { + schema.size == 0 || schema.find { + case StructField(_, b: StructType, _, _) => hasEmptySchema(b) + case _ => false + }.isDefined + } + + + if (hasEmptySchema(schema)) { + throw new AnalysisException( + s""" + |Datasource does not support writing empty or nested empty schemas. + |Please make sure the data schema has at least one or more column(s). + """.stripMargin) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index bd3071bcf9010..06303099f5310 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { @@ -107,6 +108,33 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } + allFileBasedDataSources.foreach { format => + test(s"SPARK-23372 error while writing empty schema files using $format") { + withTempPath { outputPath => + val errMsg = intercept[AnalysisException] { + spark.emptyDataFrame.write.format(format).save(outputPath.toString) + } + assert(errMsg.getMessage.contains( + "Datasource does not support writing empty or nested empty schemas")) + } + + // Nested empty schema + withTempPath { outputPath => + val schema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", StructType(Nil)), + StructField("c", IntegerType) + )) + val df = spark.createDataFrame(sparkContext.emptyRDD[Row], schema) + val errMsg = intercept[AnalysisException] { + df.write.format(format).save(outputPath.toString) + } + assert(errMsg.getMessage.contains( + "Datasource does not support writing empty or nested empty schemas")) + } + } + } + allFileBasedDataSources.foreach { format => test(s"SPARK-22146 read files containing special characters using $format") { withTempDir { dir => From 4d37008c78d7d6b8f8a649b375ecc090700eca4f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Mar 2018 19:57:32 +0100 Subject: [PATCH 509/774] [SPARK-23599][SQL] Use RandomUUIDGenerator in Uuid expression ## What changes were proposed in this pull request? As stated in Jira, there are problems with current `Uuid` expression which uses `java.util.UUID.randomUUID` for UUID generation. This patch uses the newly added `RandomUUIDGenerator` for UUID generation. So we can make `Uuid` deterministic between retries. ## How was this patch tested? Added unit tests. Author: Liang-Chi Hsieh Closes #20861 from viirya/SPARK-23599-2. --- .../sql/catalyst/analysis/Analyzer.scala | 16 ++++ .../spark/sql/catalyst/expressions/misc.scala | 26 +++++-- .../ResolvedUuidExpressionsSuite.scala | 73 +++++++++++++++++++ .../expressions/ExpressionEvalHelper.scala | 5 +- .../expressions/MiscExpressionsSuite.scala | 19 ++++- .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++ 6 files changed, 136 insertions(+), 9 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7848f88bda1c9..e821e96522f7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable.ArrayBuffer +import scala.util.Random import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ @@ -177,6 +178,7 @@ class Analyzer( TimeWindowing :: ResolveInlineTables(conf) :: ResolveTimeZone(conf) :: + ResolvedUuidExpressions :: TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -1994,6 +1996,20 @@ class Analyzer( } } + /** + * Set the seed for random number generation in Uuid expressions. + */ + object ResolvedUuidExpressions extends Rule[LogicalPlan] { + private lazy val random = new Random() + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + case p if p.resolved => p + case p => p transformExpressionsUp { + case Uuid(None) => Uuid(Some(random.nextLong())) + } + } + } + /** * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the * null check. When user defines a UDF with primitive parameters, there is no way to tell if the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 38e4fe44b15ab..ec93620038cff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -122,18 +123,33 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { 46707d92-02f4-4817-8116-a4c3b23e6266 """) // scalastyle:on line.size.limit -case class Uuid() extends LeafExpression { +case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Nondeterministic { - override lazy val deterministic: Boolean = false + def this() = this(None) + + override lazy val resolved: Boolean = randomSeed.isDefined override def nullable: Boolean = false override def dataType: DataType = StringType - override def eval(input: InternalRow): Any = UTF8String.fromString(UUID.randomUUID().toString) + @transient private[this] var randomGenerator: RandomUUIDGenerator = _ + + override protected def initializeInternal(partitionIndex: Int): Unit = + randomGenerator = RandomUUIDGenerator(randomSeed.get + partitionIndex) + + override protected def evalInternal(input: InternalRow): Any = + randomGenerator.getNextUUIDUTF8String() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.copy(code = s"final UTF8String ${ev.value} = " + - s"UTF8String.fromString(java.util.UUID.randomUUID().toString());", isNull = "false") + val randomGen = ctx.freshName("randomGen") + ctx.addMutableState("org.apache.spark.sql.catalyst.util.RandomUUIDGenerator", randomGen, + forceInline = true, + useFreshName = false) + ctx.addPartitionInitializationStatement(s"$randomGen = " + + "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + + s"${randomSeed.get}L + partitionIndex);") + ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", + isNull = "false") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala new file mode 100644 index 0000000000000..fe57c199b8744 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} + +/** + * Test suite for resolving Uuid expressions. + */ +class ResolvedUuidExpressionsSuite extends AnalysisTest { + + private lazy val a = 'a.int + private lazy val r = LocalRelation(a) + private lazy val uuid1 = Uuid().as('_uuid1) + private lazy val uuid2 = Uuid().as('_uuid2) + private lazy val uuid3 = Uuid().as('_uuid3) + private lazy val uuid1Ref = uuid1.toAttribute + + private val analyzer = getAnalyzer(caseSensitive = true) + + private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = { + plan.flatMap { + case p => + p.expressions.flatMap(_.collect { + case u: Uuid => u + }) + } + } + + test("analyzed plan sets random seed for Uuid expression") { + val plan = r.select(a, uuid1) + val resolvedPlan = analyzer.executeAndCheck(plan) + getUuidExpressions(resolvedPlan).foreach { u => + assert(u.resolved) + assert(u.randomSeed.isDefined) + } + } + + test("Uuid expressions should have different random seeds") { + val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) + val resolvedPlan = analyzer.executeAndCheck(plan) + assert(getUuidExpressions(resolvedPlan).map(_.randomSeed.get).distinct.length == 3) + } + + test("Different analyzed plans should have different random seeds in Uuids") { + val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) + val resolvedPlan1 = analyzer.executeAndCheck(plan) + val resolvedPlan2 = analyzer.executeAndCheck(plan) + val uuids1 = getUuidExpressions(resolvedPlan1) + val uuids2 = getUuidExpressions(resolvedPlan2) + assert(uuids1.distinct.length == 3) + assert(uuids2.distinct.length == 3) + assert(uuids1.intersect(uuids2).length == 0) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index c6343b1cbf600..3828f172a15cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -176,7 +176,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - private def evaluateWithGeneratedMutableProjection( + protected def evaluateWithGeneratedMutableProjection( expression: Expression, inputRow: InternalRow = EmptyRow): Any = { val plan = generateProject( @@ -220,7 +220,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - private def evaluateWithUnsafeProjection( + protected def evaluateWithUnsafeProjection( expression: Expression, inputRow: InternalRow = EmptyRow, factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { @@ -233,6 +233,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { Alias(expression, s"Optimized($expression)2")() :: Nil), expression) + plan.initialize(0) plan(inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index c3d08bf68c7bb..3383d421f5616 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.io.PrintStream +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -42,8 +44,21 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("uuid") { - checkEvaluation(Length(Uuid()), 36) - assert(evaluateWithoutCodegen(Uuid()) !== evaluateWithoutCodegen(Uuid())) + checkEvaluation(Length(Uuid(Some(0))), 36) + val r = new Random() + val seed1 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Uuid(seed1)) === evaluateWithoutCodegen(Uuid(seed1))) + assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) === + evaluateWithGeneratedMutableProjection(Uuid(seed1))) + assert(evaluateWithUnsafeProjection(Uuid(seed1)) === + evaluateWithUnsafeProjection(Uuid(seed1))) + + val seed2 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Uuid(seed1)) !== evaluateWithoutCodegen(Uuid(seed2))) + assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) !== + evaluateWithGeneratedMutableProjection(Uuid(seed2))) + assert(evaluateWithUnsafeProjection(Uuid(seed1)) !== + evaluateWithUnsafeProjection(Uuid(seed2))) } test("PrintToStderr") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8b66f77b2f923..f7b3393f65cb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -2264,4 +2265,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(0, 10) :: Nil) assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) } + + test("Uuid expressions should produce same results at retries in the same DataFrame") { + val df = spark.range(1).select($"id", new Column(Uuid())) + checkAnswer(df, df.collect()) + } } From a649fcf32a7e610da2a2b4e3d94f5d1372c825d6 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 22 Mar 2018 21:20:41 -0700 Subject: [PATCH 510/774] [MINOR][PYTHON] Remove unused codes in schema parsing logics of PySpark ## What changes were proposed in this pull request? This PR proposes to remove out unused codes, `_ignore_brackets_split` and `_BRACKETS`. `_ignore_brackets_split` was introduced in https://github.com/apache/spark/commit/d57daf1f7732a7ac54a91fe112deeda0a254f9ef to refactor and support `toDF("...")`; however, https://github.com/apache/spark/commit/ebc124d4c44d4c84f7868f390f778c0ff5cd66cb replaced the logics here. Seems `_ignore_brackets_split` is not referred anymore. `_BRACKETS` was introduced in https://github.com/apache/spark/commit/880eabec37c69ce4e9594d7babfac291b0f93f50; however, all other usages were removed out in https://github.com/apache/spark/commit/648a8626b82d27d84db3e48bccfd73d020828586. This is rather a followup for https://github.com/apache/spark/commit/ebc124d4c44d4c84f7868f390f778c0ff5cd66cb which I missed in that PR. ## How was this patch tested? Manually tested. Existing tests should cover this. I also double checked by `grep` in the whole repo. Author: hyukjinkwon Closes #20878 from HyukjinKwon/minor-remove-unused. --- python/pyspark/sql/types.py | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 826aab97e58db..5d5919e451b46 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -752,41 +752,6 @@ def __eq__(self, other): _FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)") -_BRACKETS = {'(': ')', '[': ']', '{': '}'} - - -def _ignore_brackets_split(s, separator): - """ - Splits the given string by given separator, but ignore separators inside brackets pairs, e.g. - given "a,b" and separator ",", it will return ["a", "b"], but given "a, d", it will return - ["a", "d"]. - """ - parts = [] - buf = "" - level = 0 - for c in s: - if c in _BRACKETS.keys(): - level += 1 - buf += c - elif c in _BRACKETS.values(): - if level == 0: - raise ValueError("Brackets are not correctly paired: %s" % s) - level -= 1 - buf += c - elif c == separator and level > 0: - buf += c - elif c == separator: - parts.append(buf) - buf = "" - else: - buf += c - - if len(buf) == 0: - raise ValueError("The %s cannot be the last char: %s" % (separator, s)) - parts.append(buf) - return parts - - def _parse_datatype_string(s): """ Parses the given data type string to a :class:`DataType`. The data type string format equals From b2edc30db1dcc6102687d20c158a2700965fdf51 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Mar 2018 21:23:25 -0700 Subject: [PATCH 511/774] [SPARK-23614][SQL] Fix incorrect reuse exchange when caching is used ## What changes were proposed in this pull request? We should provide customized canonicalize plan for `InMemoryRelation` and `InMemoryTableScanExec`. Otherwise, we can wrongly treat two different cached plans as same result. It causes wrongly reused exchange then. For a test query like this: ```scala val cached = spark.createDataset(Seq(TestDataUnion(1, 2, 3), TestDataUnion(4, 5, 6))).cache() val group1 = cached.groupBy("x").agg(min(col("y")) as "value") val group2 = cached.groupBy("x").agg(min(col("z")) as "value") group1.union(group2) ``` Canonicalized plans before: First exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(1) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(1) InMemoryTableScan [none#0, none#1] +- InMemoryRelation [x#4253, y#4254, z#4255], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- LocalTableScan [x#4253, y#4254, z#4255] ``` Second exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(3) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(3) InMemoryTableScan [none#0, none#1] +- InMemoryRelation [x#4253, y#4254, z#4255], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- LocalTableScan [x#4253, y#4254, z#4255] ``` You can find that they have the canonicalized plans are the same, although we use different columns in two `InMemoryTableScan`s. Canonicalized plan after: First exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(1) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(1) InMemoryTableScan [none#0, none#1] +- InMemoryRelation [none#0, none#1, none#2], true, 10000, StorageLevel(memory, 1 replicas) +- LocalTableScan [none#0, none#1, none#2] ``` Second exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(3) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(3) InMemoryTableScan [none#0, none#2] +- InMemoryRelation [none#0, none#1, none#2], true, 10000, StorageLevel(memory, 1 replicas) +- LocalTableScan [none#0, none#1, none#2] ``` ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20831 from viirya/SPARK-23614. --- .../execution/columnar/InMemoryRelation.scala | 10 ++++++++++ .../columnar/InMemoryTableScanExec.scala | 19 +++++++++++++------ .../org/apache/spark/sql/DatasetSuite.scala | 9 +++++++++ .../spark/sql/execution/ExchangeSuite.scala | 7 +++++++ 4 files changed, 39 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 22e16913d4da9..2579046e30708 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics} import org.apache.spark.sql.execution.SparkPlan @@ -68,6 +69,15 @@ case class InMemoryRelation( override protected def innerChildren: Seq[SparkPlan] = Seq(child) + override def doCanonicalize(): logical.LogicalPlan = + copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)), + storageLevel = StorageLevel.NONE, + child = child.canonicalized, + tableName = None)( + _cachedColumnBuffers, + sizeInBytesStats, + statsOfPlanToCache) + override def producedAttributes: AttributeSet = outputSet @transient val partitionStatistics = new PartitionStatistics(output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index a93e8a1ad954d..e73e1378d52e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} -import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -38,6 +38,11 @@ case class InMemoryTableScanExec( override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren + override def doCanonicalize(): SparkPlan = + copy(attributes = attributes.map(QueryPlan.normalizeExprId(_, relation.output)), + predicates = predicates.map(QueryPlan.normalizeExprId(_, relation.output)), + relation = relation.canonicalized.asInstanceOf[InMemoryRelation]) + override def vectorTypes: Option[Seq[String]] = Option(Seq.fill(attributes.length)( if (!conf.offHeapColumnVectorEnabled) { @@ -169,11 +174,13 @@ case class InMemoryTableScanExec( override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) - private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) + // Keeps relation's partition statistics because we don't serialize relation. + private val stats = relation.partitionStatistics + private def statsFor(a: Attribute) = stats.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. - @transient val buildFilter: PartialFunction[Expression, Expression] = { + @transient lazy val buildFilter: PartialFunction[Expression, Expression] = { case And(lhs: Expression, rhs: Expression) if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) => (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _) @@ -213,14 +220,14 @@ case class InMemoryTableScanExec( l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) } - val partitionFilters: Seq[Expression] = { + lazy val partitionFilters: Seq[Expression] = { predicates.flatMap { p => val filter = buildFilter.lift(p) val boundFilter = filter.map( BindReferences.bindReference( _, - relation.partitionStatistics.schema, + stats.schema, allowFailures = true)) boundFilter.foreach(_ => @@ -243,7 +250,7 @@ case class InMemoryTableScanExec( private def filteredCachedBatches(): RDD[CachedBatch] = { // Using these variables here to avoid serialization of entire objects (if referenced directly) // within the map Partitions closure. - val schema = relation.partitionStatistics.schema + val schema = stats.schema val schemaIndex = schema.zipWithIndex val buffers = relation.cachedColumnBuffers diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 49c59cf695dc1..9b745befcb611 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1446,8 +1446,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val data = Seq(("a", null)) checkDataset(data.toDS(), data: _*) } + + test("SPARK-23614: Union produces incorrect results when caching is used") { + val cached = spark.createDataset(Seq(TestDataUnion(1, 2, 3), TestDataUnion(4, 5, 6))).cache() + val group1 = cached.groupBy("x").agg(min(col("y")) as "value") + val group2 = cached.groupBy("x").agg(min(col("z")) as "value") + checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil) + } } +case class TestDataUnion(x: Int, y: Int, z: Int) + case class SingleData(id: Int) case class DoubleData(id: Int, val1: String) case class TripleData(id: Int, val1: String, val2: Long) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 697d7e6520713..bde2de5b39fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -125,4 +125,11 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assertConsistency(spark.range(10000).map(i => Random.nextInt(1000).toLong)) } } + + test("SPARK-23614: Fix incorrect reuse exchange when caching is used") { + val cached = spark.createDataset(Seq((1, 2, 3), (4, 5, 6))).cache() + val projection1 = cached.select("_1", "_2").queryExecution.executedPlan + val projection2 = cached.select("_1", "_3").queryExecution.executedPlan + assert(!projection1.sameResult(projection2)) + } } From 5fa438471110afbf4e2174df449ac79e292501f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 23 Mar 2018 13:59:21 +0800 Subject: [PATCH 512/774] [SPARK-23361][YARN] Allow AM to restart after initial tokens expire. Currently, the Spark AM relies on the initial set of tokens created by the submission client to be able to talk to HDFS and other services that require delegation tokens. This means that after those tokens expire, a new AM will fail to start (e.g. when there is an application failure and re-attempts are enabled). This PR makes it so that the first thing the AM does when the user provides a principal and keytab is to create new delegation tokens for use. This makes sure that the AM can be started irrespective of how old the original token set is. It also allows all of the token management to be done by the AM - there is no need for the submission client to set configuration values to tell the AM when to renew tokens. Note that even though in this case the AM will not be using the delegation tokens created by the submission client, those tokens still need to be provided to YARN, since they are used to do log aggregation. To be able to re-use the code in the AMCredentialRenewal for the above purposes, I refactored that class a bit so that it can fetch tokens into a pre-defined UGI, insted of always logging in. Another issue with re-attempts is that, after the fix that allows the AM to restart correctly, new executors would get confused about when to update credentials, because the credential updater used the update time initially set up by the submission code. This could make the executor fail to update credentials in time, since that value would be very out of date in the situation described in the bug. To fix that, I changed the YARN code to use the new RPC-based mechanism for distributing tokens to executors. This allowed the old credential updater code to be removed, and a lot of code in the renewer to be simplified. I also made two currently hardcoded values (the renewal time ratio, and the retry wait) configurable; while this probably never needs to be set by anyone in a production environment, it helps with testing; that's also why they're not documented. Tested on real cluster with a specially crafted application to test this functionality: checked proper access to HDFS, Hive and HBase in cluster mode with token renewal on and AM restarts. Tested things still work in client mode too. Author: Marcelo Vanzin Closes #20657 from vanzin/SPARK-23361. --- .../scala/org/apache/spark/SparkConf.scala | 12 +- .../apache/spark/deploy/SparkHadoopUtil.scala | 32 +- .../CoarseGrainedExecutorBackend.scala | 12 - .../spark/internal/config/package.scala | 12 + .../MesosHadoopDelegationTokenManager.scala | 11 +- .../spark/deploy/yarn/ApplicationMaster.scala | 117 +++---- .../org/apache/spark/deploy/yarn/Client.scala | 102 ++---- .../deploy/yarn/YarnSparkHadoopUtil.scala | 20 -- .../org/apache/spark/deploy/yarn/config.scala | 25 -- .../yarn/security/AMCredentialRenewer.scala | 291 +++++++----------- .../yarn/security/CredentialUpdater.scala | 131 -------- .../YARNHadoopDelegationTokenManager.scala | 9 +- .../cluster/YarnClientSchedulerBackend.scala | 9 +- .../cluster/YarnSchedulerBackend.scala | 10 +- ...ARNHadoopDelegationTokenManagerSuite.scala | 7 +- .../apache/spark/streaming/Checkpoint.scala | 3 - 16 files changed, 238 insertions(+), 565 deletions(-) delete mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index f53b2bed74c6e..129956e9f9ffa 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -603,13 +603,15 @@ private[spark] object SparkConf extends Logging { "Please use spark.kryoserializer.buffer instead. The default value for " + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + "are no longer accepted. To specify the equivalent now, one may use '64k'."), - DeprecatedConfig("spark.rpc", "2.0", "Not used any more."), + DeprecatedConfig("spark.rpc", "2.0", "Not used anymore."), DeprecatedConfig("spark.scheduler.executorTaskBlacklistTime", "2.1.0", "Please use the new blacklisting options, spark.blacklist.*"), - DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used any more"), - DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more"), + DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used anymore"), + DeprecatedConfig("spark.executor.port", "2.0.0", "Not used anymore"), DeprecatedConfig("spark.shuffle.service.index.cache.entries", "2.3.0", - "Not used any more. Please use spark.shuffle.service.index.cache.size") + "Not used anymore. Please use spark.shuffle.service.index.cache.size"), + DeprecatedConfig("spark.yarn.credentials.file.retention.count", "2.4.0", "Not used anymore."), + DeprecatedConfig("spark.yarn.credentials.file.retention.days", "2.4.0", "Not used anymore.") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) @@ -748,7 +750,7 @@ private[spark] object SparkConf extends Logging { } if (key.startsWith("spark.akka") || key.startsWith("spark.ssl.akka")) { logWarning( - s"The configuration key $key is not supported any more " + + s"The configuration key $key is not supported anymore " + s"because Spark doesn't use Akka since 2.0") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 177295fb7af0f..8353e64a619cf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -40,6 +40,7 @@ import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdenti import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** @@ -146,7 +147,8 @@ class SparkHadoopUtil extends Logging { private[spark] def addDelegationTokens(tokens: Array[Byte], sparkConf: SparkConf) { UserGroupInformation.setConfiguration(newConfiguration(sparkConf)) val creds = deserialize(tokens) - logInfo(s"Adding/updating delegation tokens ${dumpTokens(creds)}") + logInfo("Updating delegation tokens for current user.") + logDebug(s"Adding/updating delegation tokens ${dumpTokens(creds)}") addCurrentUserCredentials(creds) } @@ -321,19 +323,6 @@ class SparkHadoopUtil extends Logging { } } - /** - * Return a fresh Hadoop configuration, bypassing the HDFS cache mechanism. - * This is to prevent the DFSClient from using an old cached token to connect to the NameNode. - */ - private[spark] def getConfBypassingFSCache( - hadoopConf: Configuration, - scheme: String): Configuration = { - val newConf = new Configuration(hadoopConf) - val confKey = s"fs.${scheme}.impl.disable.cache" - newConf.setBoolean(confKey, true) - newConf - } - /** * Dump the credentials' tokens to string values. * @@ -447,16 +436,17 @@ object SparkHadoopUtil { def get: SparkHadoopUtil = instance /** - * Given an expiration date (e.g. for Hadoop Delegation Tokens) return a the date - * when a given fraction of the duration until the expiration date has passed. - * Formula: current time + (fraction * (time until expiration)) + * Given an expiration date for the current set of credentials, calculate the time when new + * credentials should be created. + * * @param expirationDate Drop-dead expiration date - * @param fraction fraction of the time until expiration return - * @return Date when the fraction of the time until expiration has passed + * @param conf Spark configuration + * @return Timestamp when new credentials should be created. */ - private[spark] def getDateOfNextUpdate(expirationDate: Long, fraction: Double): Long = { + private[spark] def nextCredentialRenewalTime(expirationDate: Long, conf: SparkConf): Long = { val ct = System.currentTimeMillis - (ct + (fraction * (expirationDate - ct))).toLong + val ratio = conf.get(CREDENTIALS_RENEWAL_INTERVAL_RATIO) + (ct + (ratio * (expirationDate - ct))).toLong } /** diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 9b62e4b1b7150..48d3630abd1f9 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -213,13 +213,6 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { driverConf.set(key, value) } } - if (driverConf.contains("spark.yarn.credentials.file")) { - logInfo("Will periodically update credentials from: " + - driverConf.get("spark.yarn.credentials.file")) - Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") - .getMethod("startCredentialUpdater", classOf[SparkConf]) - .invoke(null, driverConf) - } cfg.hadoopDelegationCreds.foreach { tokens => SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) @@ -234,11 +227,6 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } env.rpcEnv.awaitTermination() - if (driverConf.contains("spark.yarn.credentials.file")) { - Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") - .getMethod("stopCredentialUpdater") - .invoke(null) - } } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a313ad0554a3a..407545aa4a47a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -525,4 +525,16 @@ package object config { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("1g") + private[spark] val CREDENTIALS_RENEWAL_INTERVAL_RATIO = + ConfigBuilder("spark.security.credentials.renewalRatio") + .doc("Ratio of the credential's expiration time when Spark should fetch new credentials.") + .doubleConf + .createWithDefault(0.75d) + + private[spark] val CREDENTIALS_RENEWAL_RETRY_WAIT = + ConfigBuilder("spark.security.credentials.retryWait") + .doc("How long to wait before retrying to fetch new credentials after a failure.") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("1h") + } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala index 7165bfae18a5e..a1bf4f0c048fe 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala @@ -29,6 +29,7 @@ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.{config, Logging} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.ui.UIUtils import org.apache.spark.util.ThreadUtils @@ -63,7 +64,7 @@ private[spark] class MesosHadoopDelegationTokenManager( val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) val rt = tokenManager.obtainDelegationTokens(hadoopConf, creds) logInfo(s"Initialized tokens: ${SparkHadoopUtil.get.dumpTokens(creds)}") - (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.getDateOfNextUpdate(rt, 0.75)) + (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.nextCredentialRenewalTime(rt, conf)) } catch { case e: Exception => logError(s"Failed to fetch Hadoop delegation tokens $e") @@ -104,8 +105,10 @@ private[spark] class MesosHadoopDelegationTokenManager( } catch { case e: Exception => // Log the error and try to write new tokens back in an hour - logWarning("Couldn't broadcast tokens, trying again in an hour", e) - credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS) + val delay = TimeUnit.SECONDS.toMillis(conf.get(config.CREDENTIALS_RENEWAL_RETRY_WAIT)) + logWarning( + s"Couldn't broadcast tokens, trying again in ${UIUtils.formatDuration(delay)}", e) + credentialRenewerThread.schedule(this, delay, TimeUnit.MILLISECONDS) return } scheduleRenewal(this) @@ -135,7 +138,7 @@ private[spark] class MesosHadoopDelegationTokenManager( "related configurations in the target services.") currTime } else { - SparkHadoopUtil.getDateOfNextUpdate(nextRenewalTime, 0.75) + SparkHadoopUtil.nextCredentialRenewalTime(nextRenewalTime, conf) } logInfo(s"Time of next renewal is in ${timeOfNextRenewal - System.currentTimeMillis()} ms") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 6e35d23def6f0..d04989e138f83 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -29,7 +29,6 @@ import scala.concurrent.duration.Duration import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ @@ -41,7 +40,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, YARNHadoopDelegationTokenManager} +import org.apache.spark.deploy.yarn.security.AMCredentialRenewer import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc._ @@ -79,42 +78,43 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends private val yarnConf = new YarnConfiguration(SparkHadoopUtil.newConfiguration(sparkConf)) - private val ugi = { - val original = UserGroupInformation.getCurrentUser() - - // If a principal and keytab were provided, log in to kerberos, and set up a thread to - // renew the kerberos ticket when needed. Because the UGI API does not expose the TTL - // of the TGT, use a configuration to define how often to check that a relogin is necessary. - // checkTGTAndReloginFromKeytab() is a no-op if the relogin is not yet needed. - val principal = sparkConf.get(PRINCIPAL).orNull - val keytab = sparkConf.get(KEYTAB).orNull - if (principal != null && keytab != null) { - UserGroupInformation.loginUserFromKeytab(principal, keytab) - - val renewer = new Thread() { - override def run(): Unit = Utils.tryLogNonFatalError { - while (true) { - TimeUnit.SECONDS.sleep(sparkConf.get(KERBEROS_RELOGIN_PERIOD)) - UserGroupInformation.getCurrentUser().checkTGTAndReloginFromKeytab() - } - } + private val userClassLoader = { + val classpath = Client.getUserClasspath(sparkConf) + val urls = classpath.map { entry => + new URL("file:" + new File(entry.getPath()).getAbsolutePath()) + } + + if (isClusterMode) { + if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { + new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) + } else { + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } - renewer.setName("am-kerberos-renewer") - renewer.setDaemon(true) - renewer.start() - - // Transfer the original user's tokens to the new user, since that's needed to connect to - // YARN. It also copies over any delegation tokens that might have been created by the - // client, which will then be transferred over when starting executors (until new ones - // are created by the periodic task). - val newUser = UserGroupInformation.getCurrentUser() - SparkHadoopUtil.get.transferCredentials(original, newUser) - newUser } else { - SparkHadoopUtil.get.createSparkUser() + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } } + private val credentialRenewer: Option[AMCredentialRenewer] = sparkConf.get(KEYTAB).map { _ => + new AMCredentialRenewer(sparkConf, yarnConf) + } + + private val ugi = credentialRenewer match { + case Some(cr) => + // Set the context class loader so that the token renewer has access to jars distributed + // by the user. + val currentLoader = Thread.currentThread().getContextClassLoader() + Thread.currentThread().setContextClassLoader(userClassLoader) + try { + cr.start() + } finally { + Thread.currentThread().setContextClassLoader(currentLoader) + } + + case _ => + SparkHadoopUtil.get.createSparkUser() + } + private val client = doAsUser { new YarnRMClient() } // Default to twice the number of executors (twice the maximum number of executors if dynamic @@ -148,23 +148,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // A flag to check whether user has initialized spark context @volatile private var registered = false - private val userClassLoader = { - val classpath = Client.getUserClasspath(sparkConf) - val urls = classpath.map { entry => - new URL("file:" + new File(entry.getPath()).getAbsolutePath()) - } - - if (isClusterMode) { - if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { - new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } else { - new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } - } else { - new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } - } - // Lock for controlling the allocator (heartbeat) thread. private val allocatorLock = new Object() @@ -189,8 +172,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // In cluster mode, used to tell the AM when the user's SparkContext has been initialized. private val sparkContextPromise = Promise[SparkContext]() - private var credentialRenewer: AMCredentialRenewer = _ - // Load the list of localized files set by the client. This is used when launching executors, // and is loaded here so that these configs don't pollute the Web UI's environment page in // cluster mode. @@ -316,31 +297,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - // If the credentials file config is present, we must periodically renew tokens. So create - // a new AMDelegationTokenRenewer - if (sparkConf.contains(CREDENTIALS_FILE_PATH)) { - // Start a short-lived thread for AMCredentialRenewer, the only purpose is to set the - // classloader so that main jar and secondary jars could be used by AMCredentialRenewer. - val credentialRenewerThread = new Thread { - setName("AMCredentialRenewerStarter") - setContextClassLoader(userClassLoader) - - override def run(): Unit = { - val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - yarnConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - - val credentialRenewer = - new AMCredentialRenewer(sparkConf, yarnConf, credentialManager) - credentialRenewer.scheduleLoginFromKeytab() - } - } - - credentialRenewerThread.start() - credentialRenewerThread.join() - } - if (isClusterMode) { runDriver() } else { @@ -409,9 +365,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logDebug("shutting down user thread") userClassThread.interrupt() } - if (!inShutdown && credentialRenewer != null) { - credentialRenewer.stop() - credentialRenewer = null + if (!inShutdown) { + credentialRenewer.foreach(_.stop()) } } } @@ -468,6 +423,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends securityMgr, localResources) + credentialRenewer.foreach(_.setDriverRef(driverRef)) + // Initialize the AM endpoint *after* the allocator has been initialized. This ensures // that when the driver sends an initial executor request (e.g. after an AM restart), // the allocator is ready to service requests. diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 28087dee831d1..5763c3dbc5a8a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -93,11 +93,21 @@ private[spark] class Client( private val distCacheMgr = new ClientDistributedCacheManager() - private var loginFromKeytab = false - private var principal: String = null - private var keytab: String = null - private var credentials: Credentials = null - private var amKeytabFileName: String = null + private val principal = sparkConf.get(PRINCIPAL).orNull + private val keytab = sparkConf.get(KEYTAB).orNull + private val loginFromKeytab = principal != null + private val amKeytabFileName: String = { + require((principal == null) == (keytab == null), + "Both principal and keytab must be defined, or neither.") + if (loginFromKeytab) { + logInfo(s"Kerberos credentials: principal = $principal, keytab = $keytab") + // Generate a file name that can be used for the keytab file, that does not conflict + // with any user file. + new File(keytab).getName() + "-" + UUID.randomUUID().toString + } else { + null + } + } private val launcherBackend = new LauncherBackend() { override protected def conf: SparkConf = sparkConf @@ -120,11 +130,6 @@ private[spark] class Client( private val appStagingBaseDir = sparkConf.get(STAGING_DIR).map { new Path(_) } .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory()) - private val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - def reportLauncherState(state: SparkAppHandle.State): Unit = { launcherBackend.setState(state) } @@ -145,9 +150,6 @@ private[spark] class Client( var appId: ApplicationId = null try { launcherBackend.connect() - // Setup the credentials before doing anything else, - // so we have don't have issues at any point. - setupCredentials() yarnClient.init(hadoopConf) yarnClient.start() @@ -288,8 +290,26 @@ private[spark] class Client( appContext } - /** Set up security tokens for launching our ApplicationMaster container. */ + /** + * Set up security tokens for launching our ApplicationMaster container. + * + * This method will obtain delegation tokens from all the registered providers, and set them in + * the AM's launch context. + */ private def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = { + val credentials = UserGroupInformation.getCurrentUser().getCredentials() + val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) + credentialManager.obtainDelegationTokens(hadoopConf, credentials) + + // When using a proxy user, copy the delegation tokens to the user's credentials. Avoid + // that for regular users, since in those case the user already has access to the TGT, + // and adding delegation tokens could lead to expired or cancelled tokens being used + // later, as reported in SPARK-15754. + val currentUser = UserGroupInformation.getCurrentUser() + if (SparkHadoopUtil.get.isProxyUser(currentUser)) { + currentUser.addCredentials(credentials) + } + val dob = new DataOutputBuffer credentials.writeTokenStorageToStream(dob) amContainer.setTokens(ByteBuffer.wrap(dob.getData)) @@ -384,36 +404,6 @@ private[spark] class Client( // and add them as local resources to the application master. val fs = destDir.getFileSystem(hadoopConf) - // Merge credentials obtained from registered providers - val nearestTimeOfNextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, credentials) - - if (credentials != null) { - // Add credentials to current user's UGI, so that following operations don't need to use the - // Kerberos tgt to get delegations again in the client side. - val currentUser = UserGroupInformation.getCurrentUser() - if (SparkHadoopUtil.get.isProxyUser(currentUser)) { - currentUser.addCredentials(credentials) - } - logDebug(SparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) - } - - // If we use principal and keytab to login, also credentials can be renewed some time - // after current time, we should pass the next renewal and updating time to credential - // renewer and updater. - if (loginFromKeytab && nearestTimeOfNextRenewal > System.currentTimeMillis() && - nearestTimeOfNextRenewal != Long.MaxValue) { - - // Valid renewal time is 75% of next renewal time, and the valid update time will be - // slightly later then renewal time (80% of next renewal time). This is to make sure - // credentials are renewed and updated before expired. - val currTime = System.currentTimeMillis() - val renewalTime = (nearestTimeOfNextRenewal - currTime) * 0.75 + currTime - val updateTime = (nearestTimeOfNextRenewal - currTime) * 0.8 + currTime - - sparkConf.set(CREDENTIALS_RENEWAL_TIME, renewalTime.toLong) - sparkConf.set(CREDENTIALS_UPDATE_TIME, updateTime.toLong) - } - // Used to keep track of URIs added to the distributed cache. If the same URI is added // multiple times, YARN will fail to launch containers for the app with an internal // error. @@ -793,11 +783,6 @@ private[spark] class Client( populateClasspath(args, hadoopConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH)) env("SPARK_YARN_STAGING_DIR") = stagingDirPath.toString env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() - if (loginFromKeytab) { - val credentialsFile = "credentials-" + UUID.randomUUID().toString - sparkConf.set(CREDENTIALS_FILE_PATH, new Path(stagingDirPath, credentialsFile).toString) - logInfo(s"Credentials file set to: $credentialsFile") - } // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* val amEnvPrefix = "spark.yarn.appMasterEnv." @@ -1014,25 +999,6 @@ private[spark] class Client( amContainer } - def setupCredentials(): Unit = { - loginFromKeytab = sparkConf.contains(PRINCIPAL.key) - if (loginFromKeytab) { - principal = sparkConf.get(PRINCIPAL).get - keytab = sparkConf.get(KEYTAB).orNull - - require(keytab != null, "Keytab must be specified when principal is specified.") - logInfo("Attempting to login to the Kerberos" + - s" using principal: $principal and keytab: $keytab") - val f = new File(keytab) - // Generate a file name that can be used for the keytab file, that does not conflict - // with any user file. - amKeytabFileName = f.getName + "-" + UUID.randomUUID().toString - sparkConf.set(PRINCIPAL.key, principal) - } - // Defensive copy of the credentials - credentials = new Credentials(UserGroupInformation.getCurrentUser.getCredentials) - } - /** * Report the state of an application until it has exited, either successfully or * due to some failure, then return a pair of the yarn application state (FINISHED, FAILED, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index f406fabd61860..8eda6cb1277c5 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.CredentialUpdater import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.config._ import org.apache.spark.launcher.YarnCommandBuilderUtils @@ -38,8 +37,6 @@ import org.apache.spark.util.Utils object YarnSparkHadoopUtil { - private var credentialUpdater: CredentialUpdater = _ - // Additional memory overhead // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering // the common cases. Memory overhead tends to grow with container size. @@ -206,21 +203,4 @@ object YarnSparkHadoopUtil { filesystemsToAccess + stagingFS } - def startCredentialUpdater(sparkConf: SparkConf): Unit = { - val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) - val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - credentialUpdater = new CredentialUpdater(sparkConf, hadoopConf, credentialManager) - credentialUpdater.start() - } - - def stopCredentialUpdater(): Unit = { - if (credentialUpdater != null) { - credentialUpdater.stop() - credentialUpdater = null - } - } - } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 3ba3ae5ab4401..1a99b3bd57672 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -231,16 +231,6 @@ package object config { /* Security configuration. */ - private[spark] val CREDENTIAL_FILE_MAX_COUNT = - ConfigBuilder("spark.yarn.credentials.file.retention.count") - .intConf - .createWithDefault(5) - - private[spark] val CREDENTIALS_FILE_MAX_RETENTION = - ConfigBuilder("spark.yarn.credentials.file.retention.days") - .intConf - .createWithDefault(5) - private[spark] val NAMENODES_TO_ACCESS = ConfigBuilder("spark.yarn.access.namenodes") .doc("Extra NameNode URLs for which to request delegation tokens. The NameNode that hosts " + "fs.defaultFS does not need to be listed here.") @@ -271,11 +261,6 @@ package object config { /* Private configs. */ - private[spark] val CREDENTIALS_FILE_PATH = ConfigBuilder("spark.yarn.credentials.file") - .internal() - .stringConf - .createWithDefault(null) - // Internal config to propagate the location of the user's jar to the driver/executors private[spark] val APP_JAR = ConfigBuilder("spark.yarn.user.jar") .internal() @@ -329,16 +314,6 @@ package object config { .stringConf .createOptional - private[spark] val CREDENTIALS_RENEWAL_TIME = ConfigBuilder("spark.yarn.credentials.renewalTime") - .internal() - .timeConf(TimeUnit.MILLISECONDS) - .createWithDefault(Long.MaxValue) - - private[spark] val CREDENTIALS_UPDATE_TIME = ConfigBuilder("spark.yarn.credentials.updateTime") - .internal() - .timeConf(TimeUnit.MILLISECONDS) - .createWithDefault(Long.MaxValue) - private[spark] val KERBEROS_RELOGIN_PERIOD = ConfigBuilder("spark.yarn.kerberos.relogin.period") .timeConf(TimeUnit.SECONDS) .createWithDefaultString("1m") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala index eaf2cff111a49..bc8d47dbd54c6 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala @@ -18,221 +18,160 @@ package org.apache.spark.deploy.yarn.security import java.security.PrivilegedExceptionAction import java.util.concurrent.{ScheduledExecutorService, TimeUnit} +import java.util.concurrent.atomic.AtomicReference import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.ui.UIUtils import org.apache.spark.util.ThreadUtils /** - * The following methods are primarily meant to make sure long-running apps like Spark - * Streaming apps can run without interruption while accessing secured services. The - * scheduleLoginFromKeytab method is called on the AM to get the new credentials. - * This method wakes up a thread that logs into the KDC - * once 75% of the renewal interval of the original credentials used for the container - * has elapsed. It then obtains new credentials and writes them to HDFS in a - * pre-specified location - the prefix of which is specified in the sparkConf by - * spark.yarn.credentials.file (so the file(s) would be named c-timestamp1-1, c-timestamp2-2 etc. - * - each update goes to a new file, with a monotonically increasing suffix), also the - * timestamp1, timestamp2 here indicates the time of next update for CredentialUpdater. - * After this, the credentials are renewed once 75% of the new tokens renewal interval has elapsed. + * A manager tasked with periodically updating delegation tokens needed by the application. * - * On the executor and driver (yarn client mode) side, the updateCredentialsIfRequired method is - * called once 80% of the validity of the original credentials has elapsed. At that time the - * executor finds the credentials file with the latest timestamp and checks if it has read those - * credentials before (by keeping track of the suffix of the last file it read). If a new file has - * appeared, it will read the credentials and update the currently running UGI with it. This - * process happens again once 80% of the validity of this has expired. + * This manager is meant to make sure long-running apps (such as Spark Streaming apps) can run + * without interruption while accessing secured services. It periodically logs in to the KDC with + * user-provided credentials, and contacts all the configured secure services to obtain delegation + * tokens to be distributed to the rest of the application. + * + * This class will manage the kerberos login, by renewing the TGT when needed. Because the UGI API + * does not expose the TTL of the TGT, a configuration controls how often to check that a relogin is + * necessary. This is done reasonably often since the check is a no-op when the relogin is not yet + * needed. The check period can be overridden in the configuration. + * + * New delegation tokens are created once 75% of the renewal interval of the original tokens has + * elapsed. The new tokens are sent to the Spark driver endpoint once it's registered with the AM. + * The driver is tasked with distributing the tokens to other processes that might need them. */ private[yarn] class AMCredentialRenewer( sparkConf: SparkConf, - hadoopConf: Configuration, - credentialManager: YARNHadoopDelegationTokenManager) extends Logging { + hadoopConf: Configuration) extends Logging { - private var lastCredentialsFileSuffix = 0 + private val principal = sparkConf.get(PRINCIPAL).get + private val keytab = sparkConf.get(KEYTAB).get + private val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) - private val credentialRenewerThread: ScheduledExecutorService = + private val renewalExecutor: ScheduledExecutorService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Refresh Thread") - private val hadoopUtil = SparkHadoopUtil.get + private val driverRef = new AtomicReference[RpcEndpointRef]() - private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH) - private val daysToKeepFiles = sparkConf.get(CREDENTIALS_FILE_MAX_RETENTION) - private val numFilesToKeep = sparkConf.get(CREDENTIAL_FILE_MAX_COUNT) - private val freshHadoopConf = - hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme) + private val renewalTask = new Runnable() { + override def run(): Unit = { + updateTokensTask() + } + } - @volatile private var timeOfNextRenewal: Long = sparkConf.get(CREDENTIALS_RENEWAL_TIME) + def setDriverRef(ref: RpcEndpointRef): Unit = { + driverRef.set(ref) + } /** - * Schedule a login from the keytab and principal set using the --principal and --keytab - * arguments to spark-submit. This login happens only when the credentials of the current user - * are about to expire. This method reads spark.yarn.principal and spark.yarn.keytab from - * SparkConf to do the login. This method is a no-op in non-YARN mode. + * Start the token renewer. Upon start, the renewer will: * + * - log in the configured user, and set up a task to keep that user's ticket renewed + * - obtain delegation tokens from all available providers + * - schedule a periodic task to update the tokens when needed. + * + * @return The newly logged in user. */ - private[spark] def scheduleLoginFromKeytab(): Unit = { - val principal = sparkConf.get(PRINCIPAL).get - val keytab = sparkConf.get(KEYTAB).get - - /** - * Schedule re-login and creation of new credentials. If credentials have already expired, this - * method will synchronously create new ones. - */ - def scheduleRenewal(runnable: Runnable): Unit = { - // Run now! - val remainingTime = timeOfNextRenewal - System.currentTimeMillis() - if (remainingTime <= 0) { - logInfo("Credentials have expired, creating new ones now.") - runnable.run() - } else { - logInfo(s"Scheduling login from keytab in $remainingTime millis.") - credentialRenewerThread.schedule(runnable, remainingTime, TimeUnit.MILLISECONDS) + def start(): UserGroupInformation = { + val originalCreds = UserGroupInformation.getCurrentUser().getCredentials() + val ugi = doLogin() + + val tgtRenewalTask = new Runnable() { + override def run(): Unit = { + ugi.checkTGTAndReloginFromKeytab() } } + val tgtRenewalPeriod = sparkConf.get(KERBEROS_RELOGIN_PERIOD) + renewalExecutor.scheduleAtFixedRate(tgtRenewalTask, tgtRenewalPeriod, tgtRenewalPeriod, + TimeUnit.SECONDS) - // This thread periodically runs on the AM to update the credentials on HDFS. - val credentialRenewerRunnable = - new Runnable { - override def run(): Unit = { - try { - writeNewCredentialsToHDFS(principal, keytab) - cleanupOldFiles() - } catch { - case e: Exception => - // Log the error and try to write new tokens back in an hour - logWarning("Failed to write out new credentials to HDFS, will try again in an " + - "hour! If this happens too often tasks will fail.", e) - credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS) - return - } - scheduleRenewal(this) - } - } - // Schedule update of credentials. This handles the case of updating the credentials right now - // as well, since the renewal interval will be 0, and the thread will get scheduled - // immediately. - scheduleRenewal(credentialRenewerRunnable) + val creds = obtainTokensAndScheduleRenewal(ugi) + ugi.addCredentials(creds) + + // Transfer the original user's tokens to the new user, since that's needed to connect to + // YARN. Explicitly avoid overwriting tokens that already exist in the current user's + // credentials, since those were freshly obtained above (see SPARK-23361). + val existing = ugi.getCredentials() + existing.mergeAll(originalCreds) + ugi.addCredentials(existing) + + ugi + } + + def stop(): Unit = { + renewalExecutor.shutdown() + } + + private def scheduleRenewal(delay: Long): Unit = { + val _delay = math.max(0, delay) + logInfo(s"Scheduling login from keytab in ${UIUtils.formatDuration(delay)}.") + renewalExecutor.schedule(renewalTask, _delay, TimeUnit.MILLISECONDS) } - // Keeps only files that are newer than daysToKeepFiles days, and deletes everything else. At - // least numFilesToKeep files are kept for safety - private def cleanupOldFiles(): Unit = { - import scala.concurrent.duration._ + /** + * Periodic task to login to the KDC and create new delegation tokens. Re-schedules itself + * to fetch the next set of tokens when needed. + */ + private def updateTokensTask(): Unit = { try { - val remoteFs = FileSystem.get(freshHadoopConf) - val credentialsPath = new Path(credentialsFile) - val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles.days).toMillis - hadoopUtil.listFilesSorted( - remoteFs, credentialsPath.getParent, - credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .dropRight(numFilesToKeep) - .takeWhile(_.getModificationTime < thresholdTime) - .foreach(x => remoteFs.delete(x.getPath, true)) + val freshUGI = doLogin() + val creds = obtainTokensAndScheduleRenewal(freshUGI) + val tokens = SparkHadoopUtil.get.serialize(creds) + + val driver = driverRef.get() + if (driver != null) { + logInfo("Updating delegation tokens.") + driver.send(UpdateDelegationTokens(tokens)) + } else { + // This shouldn't really happen, since the driver should register way before tokens expire + // (or the AM should time out the application). + logWarning("Delegation tokens close to expiration but no driver has registered yet.") + SparkHadoopUtil.get.addDelegationTokens(tokens, sparkConf) + } } catch { - // Such errors are not fatal, so don't throw. Make sure they are logged though case e: Exception => - logWarning("Error while attempting to cleanup old credentials. If you are seeing many " + - "such warnings there may be an issue with your HDFS cluster.", e) + val delay = TimeUnit.SECONDS.toMillis(sparkConf.get(CREDENTIALS_RENEWAL_RETRY_WAIT)) + logWarning(s"Failed to update tokens, will try again in ${UIUtils.formatDuration(delay)}!" + + " If this happens too often tasks will fail.", e) + scheduleRenewal(delay) } } - private def writeNewCredentialsToHDFS(principal: String, keytab: String): Unit = { - // Keytab is copied by YARN to the working directory of the AM, so full path is - // not needed. - - // HACK: - // HDFS will not issue new delegation tokens, if the Credentials object - // passed in already has tokens for that FS even if the tokens are expired (it really only - // checks if there are tokens for the service, and not if they are valid). So the only real - // way to get new tokens is to make sure a different Credentials object is used each time to - // get new tokens and then the new tokens are copied over the current user's Credentials. - // So: - // - we login as a different user and get the UGI - // - use that UGI to get the tokens (see doAs block below) - // - copy the tokens over to the current user's credentials (this will overwrite the tokens - // in the current user's Credentials object for this FS). - // The login to KDC happens each time new tokens are required, but this is rare enough to not - // have to worry about (like once every day or so). This makes this code clearer than having - // to login and then relogin every time (the HDFS API may not relogin since we don't use this - // UGI directly for HDFS communication. - logInfo(s"Attempting to login to KDC using principal: $principal") - val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) - logInfo("Successfully logged into KDC.") - val tempCreds = keytabLoggedInUGI.getCredentials - val credentialsPath = new Path(credentialsFile) - val dst = credentialsPath.getParent - var nearestNextRenewalTime = Long.MaxValue - keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { - // Get a copy of the credentials - override def run(): Void = { - nearestNextRenewalTime = credentialManager.obtainDelegationTokens( - freshHadoopConf, - tempCreds) - null + /** + * Obtain new delegation tokens from the available providers. Schedules a new task to fetch + * new tokens before the new set expires. + * + * @return Credentials containing the new tokens. + */ + private def obtainTokensAndScheduleRenewal(ugi: UserGroupInformation): Credentials = { + ugi.doAs(new PrivilegedExceptionAction[Credentials]() { + override def run(): Credentials = { + val creds = new Credentials() + val nextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, creds) + + val timeToWait = SparkHadoopUtil.nextCredentialRenewalTime(nextRenewal, sparkConf) - + System.currentTimeMillis() + scheduleRenewal(timeToWait) + creds } }) - - val currTime = System.currentTimeMillis() - val timeOfNextUpdate = if (nearestNextRenewalTime <= currTime) { - // If next renewal time is earlier than current time, we set next renewal time to current - // time, this will trigger next renewal immediately. Also set next update time to current - // time. There still has a gap between token renewal and update will potentially introduce - // issue. - logWarning(s"Next credential renewal time ($nearestNextRenewalTime) is earlier than " + - s"current time ($currTime), which is unexpected, please check your credential renewal " + - "related configurations in the target services.") - timeOfNextRenewal = currTime - currTime - } else { - // Next valid renewal time is about 75% of credential renewal time, and update time is - // slightly later than valid renewal time (80% of renewal time). - timeOfNextRenewal = - SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.75) - SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.8) - } - - // Add the temp credentials back to the original ones. - UserGroupInformation.getCurrentUser.addCredentials(tempCreds) - val remoteFs = FileSystem.get(freshHadoopConf) - // If lastCredentialsFileSuffix is 0, then the AM is either started or restarted. If the AM - // was restarted, then the lastCredentialsFileSuffix might be > 0, so find the newest file - // and update the lastCredentialsFileSuffix. - if (lastCredentialsFileSuffix == 0) { - hadoopUtil.listFilesSorted( - remoteFs, credentialsPath.getParent, - credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .lastOption.foreach { status => - lastCredentialsFileSuffix = hadoopUtil.getSuffixForCredentialsPath(status.getPath) - } - } - val nextSuffix = lastCredentialsFileSuffix + 1 - - val tokenPathStr = - credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + - timeOfNextUpdate.toLong.toString + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + - nextSuffix - val tokenPath = new Path(tokenPathStr) - val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - - logInfo("Writing out delegation tokens to " + tempTokenPath.toString) - val credentials = UserGroupInformation.getCurrentUser.getCredentials - credentials.writeTokenStorageFile(tempTokenPath, freshHadoopConf) - logInfo(s"Delegation Tokens written out successfully. Renaming file to $tokenPathStr") - remoteFs.rename(tempTokenPath, tokenPath) - logInfo("Delegation token file rename complete.") - lastCredentialsFileSuffix = nextSuffix } - def stop(): Unit = { - credentialRenewerThread.shutdown() + private def doLogin(): UserGroupInformation = { + logInfo(s"Attempting to login to KDC using principal: $principal") + val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) + logInfo("Successfully logged into KDC.") + ugi } + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala deleted file mode 100644 index fe173dffc22a8..0000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import java.util.concurrent.{Executors, TimeUnit} - -import scala.util.control.NonFatal - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.{Credentials, UserGroupInformation} - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.internal.Logging -import org.apache.spark.util.{ThreadUtils, Utils} - -private[spark] class CredentialUpdater( - sparkConf: SparkConf, - hadoopConf: Configuration, - credentialManager: YARNHadoopDelegationTokenManager) extends Logging { - - @volatile private var lastCredentialsFileSuffix = 0 - - private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH) - private val freshHadoopConf = - SparkHadoopUtil.get.getConfBypassingFSCache( - hadoopConf, new Path(credentialsFile).toUri.getScheme) - - private val credentialUpdater = - Executors.newSingleThreadScheduledExecutor( - ThreadUtils.namedThreadFactory("Credential Refresh Thread")) - - // This thread wakes up and picks up new credentials from HDFS, if any. - private val credentialUpdaterRunnable = - new Runnable { - override def run(): Unit = Utils.logUncaughtExceptions(updateCredentialsIfRequired()) - } - - /** Start the credential updater task */ - def start(): Unit = { - val startTime = sparkConf.get(CREDENTIALS_UPDATE_TIME) - val remainingTime = startTime - System.currentTimeMillis() - if (remainingTime <= 0) { - credentialUpdater.schedule(credentialUpdaterRunnable, 1, TimeUnit.MINUTES) - } else { - logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime ms.") - credentialUpdater.schedule(credentialUpdaterRunnable, remainingTime, TimeUnit.MILLISECONDS) - } - } - - private def updateCredentialsIfRequired(): Unit = { - val timeToNextUpdate = try { - val credentialsFilePath = new Path(credentialsFile) - val remoteFs = FileSystem.get(freshHadoopConf) - SparkHadoopUtil.get.listFilesSorted( - remoteFs, credentialsFilePath.getParent, - credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .lastOption.map { credentialsStatus => - val suffix = SparkHadoopUtil.get.getSuffixForCredentialsPath(credentialsStatus.getPath) - if (suffix > lastCredentialsFileSuffix) { - logInfo("Reading new credentials from " + credentialsStatus.getPath) - val newCredentials = getCredentialsFromHDFSFile(remoteFs, credentialsStatus.getPath) - lastCredentialsFileSuffix = suffix - UserGroupInformation.getCurrentUser.addCredentials(newCredentials) - logInfo("Credentials updated from credentials file.") - - val remainingTime = (getTimeOfNextUpdateFromFileName(credentialsStatus.getPath) - - System.currentTimeMillis()) - if (remainingTime <= 0) TimeUnit.MINUTES.toMillis(1) else remainingTime - } else { - // If current credential file is older than expected, sleep 1 hour and check again. - TimeUnit.HOURS.toMillis(1) - } - }.getOrElse { - // Wait for 1 minute to check again if there's no credential file currently - TimeUnit.MINUTES.toMillis(1) - } - } catch { - // Since the file may get deleted while we are reading it, catch the Exception and come - // back in an hour to try again - case NonFatal(e) => - logWarning("Error while trying to update credentials, will try again in 1 hour", e) - TimeUnit.HOURS.toMillis(1) - } - - logInfo(s"Scheduling credentials refresh from HDFS in $timeToNextUpdate ms.") - credentialUpdater.schedule( - credentialUpdaterRunnable, timeToNextUpdate, TimeUnit.MILLISECONDS) - } - - private def getCredentialsFromHDFSFile(remoteFs: FileSystem, tokenPath: Path): Credentials = { - val stream = remoteFs.open(tokenPath) - try { - val newCredentials = new Credentials() - newCredentials.readTokenStorageStream(stream) - newCredentials - } finally { - stream.close() - } - } - - private def getTimeOfNextUpdateFromFileName(credentialsPath: Path): Long = { - val name = credentialsPath.getName - val index = name.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM) - val slice = name.substring(0, index) - val last2index = slice.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM) - name.substring(last2index + 1, index).toLong - } - - def stop(): Unit = { - credentialUpdater.shutdown() - } - -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala index 163cfb4eb8624..d4eeb6bbcf886 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -22,11 +22,11 @@ import java.util.ServiceLoader import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.Credentials import org.apache.spark.SparkConf import org.apache.spark.deploy.security.HadoopDelegationTokenManager +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -37,11 +37,10 @@ import org.apache.spark.util.Utils */ private[yarn] class YARNHadoopDelegationTokenManager( sparkConf: SparkConf, - hadoopConf: Configuration, - fileSystems: Configuration => Set[FileSystem]) extends Logging { + hadoopConf: Configuration) extends Logging { - private val delegationTokenManager = - new HadoopDelegationTokenManager(sparkConf, hadoopConf, fileSystems) + private val delegationTokenManager = new HadoopDelegationTokenManager(sparkConf, hadoopConf, + conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) // public for testing val credentialProviders = getCredentialProviders diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 0c6206eebe41d..06e54a2eaf95a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} +import org.apache.spark.deploy.yarn.{Client, ClientArguments} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkAppHandle @@ -62,12 +62,6 @@ private[spark] class YarnClientSchedulerBackend( super.start() waitForApplication() - // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver - // reads the credentials from HDFS, just like the executors and updates its own credentials - // cache. - if (conf.contains("spark.yarn.credentials.file")) { - YarnSparkHadoopUtil.startCredentialUpdater(conf) - } monitorThread = asyncMonitorApplication() monitorThread.start() } @@ -153,7 +147,6 @@ private[spark] class YarnClientSchedulerBackend( client.reportLauncherState(SparkAppHandle.State.FINISHED) super.stop() - YarnSparkHadoopUtil.stopCredentialUpdater() client.stop() logInfo("Stopped") } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index bb615c36cd97f..63bea3e7a5003 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -24,9 +24,11 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.util.{Failure, Success} import scala.util.control.NonFatal +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} import org.apache.spark.SparkContext +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.rpc._ import org.apache.spark.scheduler._ @@ -70,6 +72,7 @@ private[spark] abstract class YarnSchedulerBackend( /** Scheduler extension services. */ private val services: SchedulerExtensionServices = new SchedulerExtensionServices() + /** * Bind to YARN. This *must* be done before calling [[start()]]. * @@ -263,8 +266,13 @@ private[spark] abstract class YarnSchedulerBackend( logWarning(s"Requesting driver to remove executor $executorId for reason $reason") driverEndpoint.send(r) } - } + case u @ UpdateDelegationTokens(tokens) => + // Add the tokens to the current user and send a message to the scheduler so that it + // notifies all registered executors of the new tokens. + SparkHadoopUtil.get.addDelegationTokens(tokens, sc.conf) + driverEndpoint.send(u) + } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala index 3c7cdc0f1dab8..9fa749b14c98c 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala @@ -22,7 +22,6 @@ import org.apache.hadoop.security.Credentials import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { private var credentialManager: YARNHadoopDelegationTokenManager = null @@ -36,11 +35,7 @@ class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers } test("Correctly loads credential providers") { - credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - + credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) credentialManager.credentialProviders.get("yarn-test") should not be (None) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index aed67a5027433..3703a87cdb9ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -57,9 +57,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.yarn.jars", "spark.yarn.keytab", "spark.yarn.principal", - "spark.yarn.credentials.file", - "spark.yarn.credentials.renewalTime", - "spark.yarn.credentials.updateTime", "spark.ui.filters", "spark.mesos.driver.frameworkId") From 92e952557dbd8a170d66d615e25c6c6a8399dd43 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 23 Mar 2018 21:01:07 +0900 Subject: [PATCH 513/774] [MINOR][R] Fix R lint failure ## What changes were proposed in this pull request? The lint failure bugged me: ```R R/SQLContext.R:715:97: style: Trailing whitespace is superfluous. #' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to ^ tests/fulltests/test_streaming.R:239:45: style: Commas should always have a space after. expect_equal(times[order(times$eventTime),][1, 2], 2) ^ lintr checks failed. ``` and I actually saw https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.6-ubuntu-test/500/console too. If I understood correctly, there is a try about moving to Unbuntu one. ## How was this patch tested? Manually tested by `./dev/lint-r`: ``` ... lintr checks passed. ``` Author: hyukjinkwon Closes #20879 from HyukjinKwon/minor-r-lint. --- R/pkg/R/SQLContext.R | 2 +- R/pkg/tests/fulltests/test_streaming.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index ebec0ce3d1920..429dd5d565492 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -712,7 +712,7 @@ read.jdbc <- function(url, tableName, #' @param schema The data schema defined in structType or a DDL-formatted string, this is #' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for -#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to +#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to #' parse timestamps in the JSON/CSV data sources or partition values; If it isn't set, it #' uses the default value, session local timezone. #' @return SparkDataFrame diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index a354d50c6b54e..bfb1a046490ec 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -236,7 +236,7 @@ test_that("Watermark", { times <- collect(sql("SELECT * FROM times")) # looks like write timing can affect the first bucket; but it should be t - expect_equal(times[order(times$eventTime),][1, 2], 2) + expect_equal(times[order(times$eventTime), ][1, 2], 2) stopQuery(q) unlink(parquetPath) From 6ac4fba69290e1c7de2c0a5863f224981dedb919 Mon Sep 17 00:00:00 2001 From: arucard21 Date: Fri, 23 Mar 2018 21:02:34 +0900 Subject: [PATCH 514/774] [SPARK-23769][CORE] Remove comments that unnecessarily disable Scalastyle check ## What changes were proposed in this pull request? We re-enabled the Scalastyle checker on a line of code. It was previously disabled, but it does not violate any of the rules. So there's no reason to disable the Scalastyle checker here. ## How was this patch tested? We tested this by running `build/mvn scalastyle:check` after removing the comments that disable the checker. This check passed with no errors or warnings for Spark Core ``` [INFO] [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Core 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- scalastyle-maven-plugin:1.0.0:check (default-cli) spark-core_2.11 --- Saving to outputFile=/spark/core/target/scalastyle-output.xml Processed 485 file(s) Found 0 errors Found 0 warnings Found 0 infos ``` We did not run all tests (with `dev/run-tests`) since this Scalastyle check seemed sufficient. ## Co-contributors: chialun-yeh Hrayo712 vpourquie Author: arucard21 Closes #20880 from arucard21/scalastyle_util. --- .../org/apache/spark/storage/BlockReplicationPolicy.scala | 4 +--- .../main/scala/org/apache/spark/util/CompletionIterator.scala | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala index 353eac60df171..0bacc34cdfd90 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -54,10 +54,9 @@ trait BlockReplicationPolicy { } object BlockReplicationUtils { - // scalastyle:off line.size.limit /** * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while - * minimizing space usage. Please see + * minimizing space usage. Please see * here. * * @param n total number of indices @@ -65,7 +64,6 @@ object BlockReplicationUtils { * @param r random number generator * @return list of m random unique indices */ - // scalastyle:on line.size.limit private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { val indices = (n - m + 1 to n).foldLeft(mutable.LinkedHashSet.empty[Int]) {case (set, i) => val t = r.nextInt(i) + 1 diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index 31d230d0fec8e..21acaa95c5645 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -22,9 +22,7 @@ package org.apache.spark.util * through all the elements. */ private[spark] -// scalastyle:off abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] { -// scalastyle:on private[this] var completed = false def next(): A = sub.next() From 8b56f16640fc4156aa7bd529c54469d27635b951 Mon Sep 17 00:00:00 2001 From: bag_of_tricks Date: Fri, 23 Mar 2018 10:36:23 -0700 Subject: [PATCH 515/774] [SPARK-23759][UI] Unable to bind Spark UI to specific host name / IP ## What changes were proposed in this pull request? Fixes SPARK-23759 by moving connector.start() after connector.setHost() Problem was created due connector.setHost(hostName) call was after connector.start() ## How was this patch tested? Patch was tested after build and deployment. This patch requires SPARK_LOCAL_IP environment variable to be set on spark-env.sh Author: bag_of_tricks Closes #20883 from felixalbani/SPARK-23759. --- core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 0adeb4058b6e4..0e8a6307de6a8 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -343,12 +343,13 @@ private[spark] object JettyUtils extends Logging { -1, connectionFactories: _*) connector.setPort(port) - connector.start() + connector.setHost(hostName) // Currently we only use "SelectChannelConnector" // Limit the max acceptor number to 8 so that we don't waste a lot of threads connector.setAcceptQueueSize(math.min(connector.getAcceptors, 8)) - connector.setHost(hostName) + + connector.start() // The number of selectors always equals to the number of acceptors minThreads += connector.getAcceptors * 2 From cb43bbe13606673349511829fd71d1f34fc39c45 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 23 Mar 2018 11:42:40 -0700 Subject: [PATCH 516/774] [SPARK-21685][PYTHON][ML] PySpark Params isSet state should not change after transform ## What changes were proposed in this pull request? Currently when a PySpark Model is transformed, default params that have not been explicitly set are then set on the Java side on the call to `wrapper._transfer_values_to_java`. This incorrectly changes the state of the Param as it should still be marked as a default value only. ## How was this patch tested? Added a new test to verify that when transferring Params to Java, default params have their state preserved. Author: Bryan Cutler Closes #18982 from BryanCutler/pyspark-ml-param-to-java-defaults-SPARK-21685. --- python/pyspark/ml/tests.py | 20 +++++++++++++++++++- python/pyspark/ml/wrapper.py | 13 ++++++++++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index fd45fd00b270b..080119959a4e8 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -369,7 +369,7 @@ def test_property(self): raise RuntimeError("Test property to raise error when invoked") -class ParamTests(PySparkTestCase): +class ParamTests(SparkSessionTestCase): def test_copy_new_parent(self): testParams = TestParams() @@ -514,6 +514,24 @@ def test_logistic_regression_check_thresholds(self): LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] ) + def test_preserve_set_state(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + self.assertFalse(binarizer.isSet("threshold")) + binarizer.transform(dataset) + binarizer._transfer_params_from_java() + self.assertFalse(binarizer.isSet("threshold"), + "Params not explicitly set should remain unset after transform") + + def test_default_params_transferred(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + # intentionally change the pyspark default, but don't set it + binarizer._defaultParamMap[binarizer.outputCol] = "my_default" + result = binarizer.transform(dataset).select("my_default").collect() + self.assertFalse(binarizer.isSet(binarizer.outputCol)) + self.assertEqual(result[0][0], 1.0) + @staticmethod def check_params(test_self, py_stage, check_params_exist=True): """ diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 5061f6434794a..d325633195ddb 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -118,11 +118,18 @@ def _transfer_params_to_java(self): """ Transforms the embedded params to the companion Java object. """ - paramMap = self.extractParamMap() + pair_defaults = [] for param in self.params: - if param in paramMap: - pair = self._make_java_param_pair(param, paramMap[param]) + if self.isSet(param): + pair = self._make_java_param_pair(param, self._paramMap[param]) self._java_obj.set(pair) + if self.hasDefault(param): + pair = self._make_java_param_pair(param, self._defaultParamMap[param]) + pair_defaults.append(pair) + if len(pair_defaults) > 0: + sc = SparkContext._active_spark_context + pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults) + self._java_obj.setDefault(pair_defaults_seq) def _transfer_param_map_to_java(self, pyParamMap): """ From 95c03cbd27cea2255d9d748f9a84a0a38e54594d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 23 Mar 2018 11:56:17 -0700 Subject: [PATCH 517/774] [SPARK-23783][SPARK-11239][ML] Add PMML export to Spark ML pipelines ## What changes were proposed in this pull request? Adds PMML export support to Spark ML pipelines in the style of Spark's DataSource API to allow library authors to add their own model export formats. Includes a specific implementation for Spark ML linear regression PMML export. In addition to adding PMML to reach parity with our current MLlib implementation, this approach will allow other libraries & formats (like PFA) to implement and export models with a unified API. ## How was this patch tested? Basic unit test. Author: Holden Karau Author: Holden Karau Closes #19876 from holdenk/SPARK-11171-SPARK-11237-Add-PMML-export-for-ML-KMeans-r2. --- .../org.apache.spark.ml.util.MLFormatRegister | 2 + .../ml/regression/LinearRegression.scala | 70 ++++--- .../org/apache/spark/ml/util/ReadWrite.scala | 173 +++++++++++++++++- .../org.apache.spark.ml.util.MLFormatRegister | 3 + .../ml/regression/LinearRegressionSuite.scala | 27 ++- .../spark/ml/util/PMMLReadWriteTest.scala | 55 ++++++ .../org/apache/spark/ml/util/PMMLUtils.scala | 43 +++++ .../apache/spark/ml/util/ReadWriteSuite.scala | 132 +++++++++++++ 8 files changed, 474 insertions(+), 31 deletions(-) create mode 100644 mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister create mode 100644 mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister new file mode 100644 index 0000000000000..5e5484fd8784d --- /dev/null +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -0,0 +1,2 @@ +org.apache.spark.ml.regression.InternalLinearRegressionModelWriter +org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 92510154d500e..f67d9d831f327 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging -import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.{PipelineStage, PredictorParams} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ @@ -39,10 +39,11 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel} import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -643,7 +644,7 @@ class LinearRegressionModel private[ml] ( @Since("1.3.0") val intercept: Double, @Since("2.3.0") val scale: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with MLWritable { + with LinearRegressionParams with GeneralMLWritable { private[ml] def this(uid: String, coefficients: Vector, intercept: Double) = this(uid, coefficients, intercept, 1.0) @@ -710,7 +711,7 @@ class LinearRegressionModel private[ml] ( } /** - * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. * * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. @@ -718,7 +719,50 @@ class LinearRegressionModel private[ml] ( * This also does not save the [[parent]] currently. */ @Since("1.6.0") - override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) + override def write: GeneralMLWriter = new GeneralMLWriter(this) +} + +/** A writer for LinearRegression that handles the "internal" (or default) format */ +private class InternalLinearRegressionModelWriter + extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "internal" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + + private case class Data(intercept: Double, coefficients: Vector, scale: Double) + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[LinearRegressionModel] + val sc = sparkSession.sparkContext + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients, scale + val data = Data(instance.intercept, instance.coefficients, instance.scale) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } +} + +/** A writer for LinearRegression that handles the "pmml" format */ +private class PMMLLinearRegressionModelWriter + extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "pmml" + + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + + private case class Data(intercept: Double, coefficients: Vector) + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val sc = sparkSession.sparkContext + // Construct the MLLib model which knows how to write to PMML. + val instance = stage.asInstanceOf[LinearRegressionModel] + val oldModel = new OldLinearRegressionModel(instance.coefficients, instance.intercept) + // Save PMML + oldModel.toPMML(sc, path) + } } @Since("1.6.0") @@ -730,22 +774,6 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { @Since("1.6.0") override def load(path: String): LinearRegressionModel = super.load(path) - /** [[MLWriter]] instance for [[LinearRegressionModel]] */ - private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) - extends MLWriter with Logging { - - private case class Data(intercept: Double, coefficients: Vector, scale: Double) - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: intercept, coefficients, scale - val data = Data(instance.intercept, instance.coefficients, instance.scale) - val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) - } - } - private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index a616907800969..7edcd498678cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -18,9 +18,11 @@ package org.apache.spark.ml.util import java.io.IOException -import java.util.Locale +import java.util.{Locale, ServiceLoader} +import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.util.{Failure, Success, Try} import org.apache.hadoop.fs.Path import org.json4s._ @@ -28,8 +30,8 @@ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} @@ -86,7 +88,82 @@ private[util] sealed trait BaseReadWrite { } /** - * Abstract class for utility classes that can save ML instances. + * Abstract class to be implemented by objects that provide ML exportability. + * + * A new instance of this class will be instantiated each time a save call is made. + * + * Must have a valid zero argument constructor which will be called to instantiate. + * + * @since 2.4.0 + */ +@InterfaceStability.Unstable +@Since("2.4.0") +trait MLWriterFormat { + /** + * Function to write the provided pipeline stage out. + * + * @param path The path to write the result out to. + * @param session SparkSession associated with the write request. + * @param optionMap User provided options stored as strings. + * @param stage The pipeline stage to be saved. + */ + @Since("2.4.0") + def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String], + stage: PipelineStage): Unit +} + +/** + * ML export formats for should implement this trait so that users can specify a shortname rather + * than the fully qualified class name of the exporter. + * + * A new instance of this class will be instantiated each time a save call is made. + * + * @since 2.4.0 + */ +@InterfaceStability.Unstable +@Since("2.4.0") +trait MLFormatRegister extends MLWriterFormat { + /** + * The string that represents the format that this format provider uses. This is, along with + * stageName, is overridden by children to provide a nice alias for the writer. For example: + * + * {{{ + * override def format(): String = + * "pmml" + * }}} + * Indicates that this format is capable of saving a pmml model. + * + * Must have a valid zero argument constructor which will be called to instantiate. + * + * Format discovery is done using a ServiceLoader so make sure to list your format in + * META-INF/services. + * @since 2.4.0 + */ + @Since("2.4.0") + def format(): String + + /** + * The string that represents the stage type that this writer supports. This is, along with + * format, is overridden by children to provide a nice alias for the writer. For example: + * + * {{{ + * override def stageName(): String = + * "org.apache.spark.ml.regression.LinearRegressionModel" + * }}} + * Indicates that this format is capable of saving Spark's own PMML model. + * + * Format discovery is done using a ServiceLoader so make sure to list your format in + * META-INF/services. + * @since 2.4.0 + */ + @Since("2.4.0") + def stageName(): String + + private[ml] def shortName(): String = s"${format()}+${stageName()}" +} + +/** + * Abstract class for utility classes that can save ML instances in Spark's internal format. */ @Since("1.6.0") abstract class MLWriter extends BaseReadWrite with Logging { @@ -110,6 +187,15 @@ abstract class MLWriter extends BaseReadWrite with Logging { @Since("1.6.0") protected def saveImpl(path: String): Unit + /** + * Overwrites if the output path already exists. + */ + @Since("1.6.0") + def overwrite(): this.type = { + shouldOverwrite = true + this + } + /** * Map to store extra options for this writer. */ @@ -126,15 +212,73 @@ abstract class MLWriter extends BaseReadWrite with Logging { this } + // override for Java compatibility + @Since("1.6.0") + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + @Since("1.6.0") + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) +} + +/** + * A ML Writer which delegates based on the requested format. + */ +@InterfaceStability.Unstable +@Since("2.4.0") +class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { + private var source: String = "internal" + /** - * Overwrites if the output path already exists. + * Specifies the format of ML export (e.g. "pmml", "internal", or + * the fully qualified class name for export). */ - @Since("1.6.0") - def overwrite(): this.type = { - shouldOverwrite = true + @Since("2.4.0") + def format(source: String): this.type = { + this.source = source this } + /** + * Dispatches the save to the correct MLFormat. + */ + @Since("2.4.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + @throws[SparkException]("If multiple sources for a given short name format are found.") + override protected def saveImpl(path: String): Unit = { + val loader = Utils.getContextOrSparkClassLoader + val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader) + val stageName = stage.getClass.getName + val targetName = s"$source+$stageName" + val formats = serviceLoader.asScala.toList + val shortNames = formats.map(_.shortName()) + val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match { + // requested name did not match any given registered alias + case Nil => + Try(loader.loadClass(source)) match { + case Success(writer) => + // Found the ML writer using the fully qualified path + writer + case Failure(error) => + throw new SparkException( + s"Could not load requested format $source for $stageName ($targetName) had $formats" + + s"supporting $shortNames", error) + } + case head :: Nil => + head.getClass + case _ => + // Multiple sources + throw new SparkException( + s"Multiple writers found for $source+$stageName, try using the class name of the writer") + } + if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) { + val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat] + writer.write(path, sparkSession, optionMap, stage) + } else { + throw new SparkException(s"ML source $source is not a valid MLWriterFormat") + } + } + // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) @@ -162,6 +306,19 @@ trait MLWritable { def save(path: String): Unit = write.save(path) } +/** + * Trait for classes that provide `GeneralMLWriter`. + */ +@Since("2.4.0") +@InterfaceStability.Unstable +trait GeneralMLWritable extends MLWritable { + /** + * Returns an `MLWriter` instance for this ML instance. + */ + @Since("2.4.0") + override def write: GeneralMLWriter +} + /** * :: DeveloperApi :: * diff --git a/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister new file mode 100644 index 0000000000000..100ef2545418f --- /dev/null +++ b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -0,0 +1,3 @@ +org.apache.spark.ml.util.DuplicateLinearRegressionWriter1 +org.apache.spark.ml.util.DuplicateLinearRegressionWriter2 +org.apache.spark.ml.util.FakeLinearRegressionWriterWithName diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 9b19f63eba1bd..90ceb7dee38f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,18 +17,23 @@ package org.apache.spark.ml.regression +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Random +import org.dmg.pmml.{OpType, PMML, RegressionModel => PMMLRegressionModel} + import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.ml.util._ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { + +class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest { import testImplicits._ @@ -1052,6 +1057,24 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { LinearRegressionSuite.allParamSettings, checkModelData) } + test("pmml export") { + val lr = new LinearRegression() + val model = lr.fit(datasetWithWeight) + def checkModel(pmml: PMML): Unit = { + val dd = pmml.getDataDictionary + assert(dd.getNumberOfFields === 3) + val fields = dd.getDataFields.asScala + assert(fields(0).getName().toString === "field_0") + assert(fields(0).getOpType() == OpType.CONTINUOUS) + val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel] + val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors + val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList + assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3) + assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3) + } + testPMMLWrite(sc, model, checkModel) + } + test("should support all NumericType labels and weights, and not support other types") { for (solver <- Seq("auto", "l-bfgs", "normal")) { val lr = new LinearRegression().setMaxIter(1).setSolver(solver) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala new file mode 100644 index 0000000000000..d2c4832b12bac --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.{File, IOException} + +import org.dmg.pmml.PMML +import org.scalatest.Suite + +import org.apache.spark.SparkContext +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + +trait PMMLReadWriteTest extends TempDirectory { self: Suite => + /** + * Test PMML export. Requires exported model is small enough to be loaded locally. + * Checks that the model can be exported and the result is valid PMML, but does not check + * the specific contents of the model. + */ + def testPMMLWrite[T <: Params with GeneralMLWritable](sc: SparkContext, instance: T, + checkModelData: PMML => Unit): Unit = { + val uid = instance.uid + val subdirName = Identifiable.randomUID("pmml-") + + val subdir = new File(tempDir, subdirName) + val path = new File(subdir, uid).getPath + + instance.write.format("pmml").save(path) + intercept[IOException] { + instance.write.format("pmml").save(path) + } + instance.write.format("pmml").overwrite().save(path) + val pmmlStr = sc.textFile(path).collect.mkString("\n") + val pmmlModel = PMMLUtils.loadFromString(pmmlStr) + assert(pmmlModel.getHeader().getApplication().getName().startsWith("Apache Spark")) + checkModelData(pmmlModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala new file mode 100644 index 0000000000000..dbdc69f95d841 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.util + +import java.io.StringReader +import javax.xml.bind.Unmarshaller +import javax.xml.transform.Source + +import org.dmg.pmml._ +import org.jpmml.model.{ImportFilter, JAXBUtil} +import org.xml.sax.InputSource + +/** + * Testing utils for working with PMML. + * Predictive Model Markup Language (PMML) is an XML-based file format + * developed by the Data Mining Group (www.dmg.org). + */ +private[spark] object PMMLUtils { + /** + * :: Experimental :: + * Load a PMML model from a string. Note: for testing only, PMML model evaluation is supported + * through external spark-packages. + */ + def loadFromString(input: String): PMML = { + val is = new StringReader(input) + val transformed = ImportFilter.apply(new InputSource(is)) + JAXBUtil.unmarshalPMML(transformed) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala new file mode 100644 index 0000000000000..f4c1f0bdb32cd --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.mutable + +import org.apache.spark.SparkException +import org.apache.spark.ml.PipelineStage +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.sql.{DataFrame, SparkSession} + +class FakeLinearRegressionWriter extends MLWriterFormat { + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Fake writer doesn't writestart") + } +} + +class FakeLinearRegressionWriterWithName extends MLFormatRegister { + override def format(): String = "fakeWithName" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Fake writer doesn't writestart") + } +} + + +class DuplicateLinearRegressionWriter1 extends MLFormatRegister { + override def format(): String = "dupe" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Duplicate writer shouldn't have been called") + } +} + +class DuplicateLinearRegressionWriter2 extends MLFormatRegister { + override def format(): String = "dupe" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Duplicate writer shouldn't have been called") + } +} + +class ReadWriteSuite extends MLTest { + + import testImplicits._ + + private val seed: Int = 42 + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + dataset = sc.parallelize(LinearDataGenerator.generateLinearInput( + intercept = 0.0, weights = Array(1.0, 2.0), xMean = Array(0.0, 1.0), + xVariance = Array(2.0, 1.0), nPoints = 10, seed, eps = 0.2)).map(_.asML).toDF() + } + + test("unsupported/non existent export formats") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + // Does not exist with a long class name + val thrownDNE = intercept[SparkException] { + model.write.format("com.holdenkarau.boop").save("boop") + } + assert(thrownDNE.getMessage(). + contains("Could not load requested format")) + + // Does not exist with a short name + val thrownDNEShort = intercept[SparkException] { + model.write.format("boop").save("boop") + } + assert(thrownDNEShort.getMessage(). + contains("Could not load requested format")) + + // Check with a valid class that is not a writer format. + val thrownInvalid = intercept[SparkException] { + model.write.format("org.apache.spark.SparkContext").save("boop2") + } + assert(thrownInvalid.getMessage() + .contains("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat")) + } + + test("invalid paths fail") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("pmml").save("") + } + assert(thrown.getMessage().contains("Can not create a Path from an empty string")) + } + + test("dummy export format is called") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("org.apache.spark.ml.util.FakeLinearRegressionWriter").save("name") + } + assert(thrown.getMessage().contains("Fake writer doesn't write")) + val thrownWithName = intercept[Exception] { + model.write.format("fakeWithName").save("name") + } + assert(thrownWithName.getMessage().contains("Fake writer doesn't write")) + } + + test("duplicate format raises error") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("dupe").save("dupepanda") + } + assert(thrown.getMessage().contains("Multiple writers found for")) + } +} From a33655348c4066d9c1d8ad2055aadfbc892ba7fd Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 23 Mar 2018 15:58:48 -0700 Subject: [PATCH 518/774] [SPARK-23615][ML][PYSPARK] Add maxDF Parameter to Python CountVectorizer ## What changes were proposed in this pull request? The maxDF parameter is for filtering out frequently occurring terms. This param was recently added to the Scala CountVectorizer and needs to be added to Python also. ## How was this patch tested? add test Author: Huaxin Gao Closes #20777 from huaxingao/spark-23615. --- .../spark/ml/feature/CountVectorizer.scala | 20 +++++----- python/pyspark/ml/feature.py | 40 ++++++++++++++----- python/pyspark/ml/tests.py | 25 ++++++++++++ 3 files changed, 67 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 60a4f918790a3..9e0ed437e7bfc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -70,19 +70,21 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit def getMinDF: Double = $(minDF) /** - * Specifies the maximum number of different documents a term must appear in to be included - * in the vocabulary. - * If this is an integer greater than or equal to 1, this specifies the number of documents - * the term must appear in; if this is a double in [0,1), then this specifies the fraction of - * documents. + * Specifies the maximum number of different documents a term could appear in to be included + * in the vocabulary. A term that appears more than the threshold will be ignored. If this is an + * integer greater than or equal to 1, this specifies the maximum number of documents the term + * could appear in; if this is a double in [0,1), then this specifies the maximum fraction of + * documents the term could appear in. * - * Default: (2^64^) - 1 + * Default: (2^63^) - 1 * @group param */ val maxDF: DoubleParam = new DoubleParam(this, "maxDF", "Specifies the maximum number of" + - " different documents a term must appear in to be included in the vocabulary." + - " If this is an integer >= 1, this specifies the number of documents the term must" + - " appear in; if this is a double in [0,1), then this specifies the fraction of documents.", + " different documents a term could appear in to be included in the vocabulary." + + " A term that appears more than the threshold will be ignored. If this is an integer >= 1," + + " this specifies the maximum number of documents the term could appear in;" + + " if this is a double in [0,1), then this specifies the maximum fraction of" + + " documents the term could appear in.", ParamValidators.gtEq(0.0)) /** @group getParam */ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index a1ceb7f02da8b..fcb0dfc563720 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -422,6 +422,14 @@ class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): " If this is an integer >= 1, this specifies the number of documents the term must" + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + " Default 1.0", typeConverter=TypeConverters.toFloat) + maxDF = Param( + Params._dummy(), "maxDF", "Specifies the maximum number of" + + " different documents a term could appear in to be included in the vocabulary." + + " A term that appears more than the threshold will be ignored. If this is an" + + " integer >= 1, this specifies the maximum number of documents the term could appear in;" + + " if this is a double in [0,1), then this specifies the maximum" + + " fraction of documents the term could appear in." + + " Default (2^63) - 1", typeConverter=TypeConverters.toFloat) vocabSize = Param( Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", typeConverter=TypeConverters.toInt) @@ -433,7 +441,7 @@ class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): def __init__(self, *args): super(_CountVectorizerParams, self).__init__(*args) - self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) + self._setDefault(minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False) @since("1.6.0") def getMinTF(self): @@ -449,6 +457,13 @@ def getMinDF(self): """ return self.getOrDefault(self.minDF) + @since("2.4.0") + def getMaxDF(self): + """ + Gets the value of maxDF or its default value. + """ + return self.getOrDefault(self.maxDF) + @since("1.6.0") def getVocabSize(self): """ @@ -513,11 +528,11 @@ class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, Jav """ @keyword_only - def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, - outputCol=None): + def __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False, + inputCol=None, outputCol=None): """ - __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ - outputCol=None) + __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\ + inputCol=None,outputCol=None) """ super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", @@ -527,11 +542,11 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC @keyword_only @since("1.6.0") - def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, - outputCol=None): + def setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False, + inputCol=None, outputCol=None): """ - setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ - outputCol=None) + setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\ + inputCol=None, outputCol=None) Set the params for the CountVectorizer """ kwargs = self._input_kwargs @@ -551,6 +566,13 @@ def setMinDF(self, value): """ return self._set(minDF=value) + @since("2.4.0") + def setMaxDF(self, value): + """ + Sets the value of :py:attr:`maxDF`. + """ + return self._set(maxDF=value) + @since("1.6.0") def setVocabSize(self, value): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 080119959a4e8..cf1ffa181ecec 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -697,6 +697,31 @@ def test_count_vectorizer_with_binary(self): feature, expected = r self.assertEqual(feature, expected) + def test_count_vectorizer_with_maxDF(self): + dataset = self.spark.createDataFrame([ + (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0}),), + (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + cv = CountVectorizer(inputCol="words", outputCol="features") + model1 = cv.setMaxDF(3).fit(dataset) + self.assertEqual(model1.vocabulary, ['b', 'c', 'd']) + + transformedList1 = model1.transform(dataset).select("features", "expected").collect() + + for r in transformedList1: + feature, expected = r + self.assertEqual(feature, expected) + + model2 = cv.setMaxDF(0.75).fit(dataset) + self.assertEqual(model2.vocabulary, ['b', 'c', 'd']) + + transformedList2 = model2.transform(dataset).select("features", "expected").collect() + + for r in transformedList2: + feature, expected = r + self.assertEqual(feature, expected) + def test_count_vectorizer_from_vocab(self): model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", outputCol="features", minTF=2) From 816a5496ba4caac438f70400f72bb10bfcc02418 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Sat, 24 Mar 2018 18:21:01 -0700 Subject: [PATCH 519/774] [SPARK-23788][SS] Fix race in StreamingQuerySuite ## What changes were proposed in this pull request? The serializability test uses the same MemoryStream instance for 3 different queries. If any of those queries ask it to commit before the others have run, the rest will see empty dataframes. This can fail the test if q3 is affected. We should use one instance per query instead. ## How was this patch tested? Existing unit test. If I move q2.processAllAvailable() before starting q3, the test always fails without the fix. Author: Jose Torres Closes #20896 from jose-torres/fixrace. --- .../spark/sql/streaming/StreamingQuerySuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index ebc9a87b23f84..08749b49997e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -550,22 +550,22 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi .start() } - val input = MemoryStream[Int] - val q1 = startQuery(input.toDS, "stream_serializable_test_1") - val q2 = startQuery(input.toDS.map { i => + val input = MemoryStream[Int] :: MemoryStream[Int] :: MemoryStream[Int] :: Nil + val q1 = startQuery(input(0).toDS, "stream_serializable_test_1") + val q2 = startQuery(input(1).toDS.map { i => // Emulate that `StreamingQuery` get captured with normal usage unintentionally. // It should not fail the query. q1 i }, "stream_serializable_test_2") - val q3 = startQuery(input.toDS.map { i => + val q3 = startQuery(input(2).toDS.map { i => // Emulate that `StreamingQuery` is used in executors. We should fail the query with a clear // error message. q1.explain() i }, "stream_serializable_test_3") try { - input.addData(1) + input.foreach(_.addData(1)) // q2 should not fail since it doesn't use `q1` in the closure q2.processAllAvailable() From 5f653d4f7c84e6147cd323cd650da65e0381ebe8 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 25 Mar 2018 09:18:26 -0700 Subject: [PATCH 520/774] [SPARK-23167][SQL] Add TPCDS queries v2.7 in TPCDSQuerySuite ## What changes were proposed in this pull request? This pr added TPCDS v2.7 (latest) queries in `TPCDSQuerySuite` because the current `TPCDSQuerySuite` tests older one (v1.4) and some queries are different from v1.4 and v2.7. Since the original v2.7 queries have the syntaxes that Spark cannot parse, I changed these queries in a following way: - [date] + 14 days -> date + `INTERVAL` 14 days - [column name] as "30 days" -> [column name] as \`30 days\` - Fix some syntax errors, e.g., missing brackets ## How was this patch tested? Added tests in `TPCDSQuerySuite`. Author: Takeshi Yamamuro Closes #20343 from maropu/TPCDSV2_7. --- .../src/test/resources/tpcds-v2.7.0/q10a.sql | 69 ++++++ .../src/test/resources/tpcds-v2.7.0/q11.sql | 84 +++++++ .../src/test/resources/tpcds-v2.7.0/q12.sql | 23 ++ .../src/test/resources/tpcds-v2.7.0/q14.sql | 135 +++++++++++ .../src/test/resources/tpcds-v2.7.0/q14a.sql | 215 ++++++++++++++++++ .../src/test/resources/tpcds-v2.7.0/q18a.sql | 133 +++++++++++ .../src/test/resources/tpcds-v2.7.0/q20.sql | 19 ++ .../src/test/resources/tpcds-v2.7.0/q22.sql | 15 ++ .../src/test/resources/tpcds-v2.7.0/q22a.sql | 94 ++++++++ .../src/test/resources/tpcds-v2.7.0/q24.sql | 40 ++++ .../src/test/resources/tpcds-v2.7.0/q27a.sql | 70 ++++++ .../src/test/resources/tpcds-v2.7.0/q34.sql | 37 +++ .../src/test/resources/tpcds-v2.7.0/q35.sql | 65 ++++++ .../src/test/resources/tpcds-v2.7.0/q35a.sql | 62 +++++ .../src/test/resources/tpcds-v2.7.0/q36a.sql | 70 ++++++ .../src/test/resources/tpcds-v2.7.0/q47.sql | 64 ++++++ .../src/test/resources/tpcds-v2.7.0/q49.sql | 133 +++++++++++ .../src/test/resources/tpcds-v2.7.0/q51a.sql | 103 +++++++++ .../src/test/resources/tpcds-v2.7.0/q57.sql | 57 +++++ .../src/test/resources/tpcds-v2.7.0/q5a.sql | 158 +++++++++++++ .../src/test/resources/tpcds-v2.7.0/q6.sql | 23 ++ .../src/test/resources/tpcds-v2.7.0/q64.sql | 111 +++++++++ .../src/test/resources/tpcds-v2.7.0/q67a.sql | 208 +++++++++++++++++ .../src/test/resources/tpcds-v2.7.0/q70a.sql | 70 ++++++ .../src/test/resources/tpcds-v2.7.0/q72.sql | 40 ++++ .../src/test/resources/tpcds-v2.7.0/q74.sql | 60 +++++ .../src/test/resources/tpcds-v2.7.0/q75.sql | 78 +++++++ .../src/test/resources/tpcds-v2.7.0/q77a.sql | 121 ++++++++++ .../src/test/resources/tpcds-v2.7.0/q78.sql | 75 ++++++ .../src/test/resources/tpcds-v2.7.0/q80a.sql | 147 ++++++++++++ .../src/test/resources/tpcds-v2.7.0/q86a.sql | 61 +++++ .../src/test/resources/tpcds-v2.7.0/q98.sql | 22 ++ .../apache/spark/sql/TPCDSQuerySuite.scala | 38 +++- 33 files changed, 2691 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q11.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q12.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q14.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q20.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q22.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q24.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q34.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q35.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q47.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q49.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q57.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q6.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q64.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q72.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q74.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q75.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q78.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q98.sql diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql new file mode 100644 index 0000000000000..50e521567eb3a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql @@ -0,0 +1,69 @@ +-- This is a new query in TPCDS v2.7 +select + cd_gender, + cd_marital_status, + cd_education_status, + count(*) cnt1, + cd_purchase_estimate, + count(*) cnt2, + cd_credit_rating, + count(*) cnt3, + cd_dep_count, + count(*) cnt4, + cd_dep_employed_count, + count(*) cnt5, + cd_dep_college_count, + count(*) cnt6 +from + customer c,customer_address ca,customer_demographics +where + c.c_current_addr_sk = ca.ca_address_sk + and ca_county in ('Walker County', 'Richland County', 'Gaines County', 'Douglas County', 'Dona Ana County') + and cd_demo_sk = c.c_current_cdemo_sk + and exists ( + select * + from store_sales,date_dim + where c.c_customer_sk = ss_customer_sk + and ss_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3) + and exists ( + select * + from ( + select + ws_bill_customer_sk as customer_sk, + d_year, + d_moy + from web_sales, date_dim + where ws_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3 + union all + select + cs_ship_customer_sk as customer_sk, + d_year, + d_moy + from catalog_sales, date_dim + where cs_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3) x + where c.c_customer_sk = customer_sk) +group by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +order by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql new file mode 100755 index 0000000000000..97bed33721742 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql @@ -0,0 +1,84 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(ss_ext_list_price - ss_ext_discount_amt) year_total, + 's' sale_type + FROM customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk + AND ss_sold_date_sk = d_date_sk + GROUP BY c_customer_id + , c_first_name + , c_last_name + , d_year + , c_preferred_cust_flag + , c_birth_country + , c_login + , c_email_address + , d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(ws_ext_list_price - ws_ext_discount_amt) year_total, + 'w' sale_type + FROM customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk + AND ws_sold_date_sk = d_date_sk + GROUP BY + c_customer_id, c_first_name, c_last_name, c_preferred_cust_flag, c_birth_country, + c_login, c_email_address, d_year) +SELECT + -- select list of q11 in TPCDS v1.4 is below: + -- t_s_secyear.customer_preferred_cust_flag + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name, + t_s_secyear.customer_email_address +FROM year_total t_s_firstyear + , year_total t_s_secyear + , year_total t_w_firstyear + , year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.dyear = 2001 + AND t_s_secyear.dyear = 2001 + 1 + AND t_w_firstyear.dyear = 2001 + AND t_w_secyear.dyear = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + -- q11 in TPCDS v1.4 used NULL + -- ELSE NULL END + ELSE 0.0 END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + -- q11 in TPCDS v1.4 used NULL + -- ELSE NULL END + ELSE 0.0 END +ORDER BY + -- order-by list of q11 in TPCDS v1.4 is below: + -- t_s_secyear.customer_preferred_cust_flag + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name, + t_s_secyear.customer_email_address +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql new file mode 100755 index 0000000000000..7a6fafd22428a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql @@ -0,0 +1,23 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ws_ext_sales_price) AS itemrevenue, + sum(ws_ext_sales_price) * 100 / sum(sum(ws_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM + web_sales, item, date_dim +WHERE + ws_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND ws_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) + AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY + i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY + i_category, i_class, i_item_id, i_item_desc, revenueratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql new file mode 100644 index 0000000000000..b2ca3ddaf2baf --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql @@ -0,0 +1,135 @@ +-- This query is the alternative form of sql/core/src/test/resources/tpcds/q14a.sql +with cross_items as ( + select + i_item_sk ss_item_sk + from item, ( + select + iss.i_brand_id brand_id, + iss.i_class_id class_id, + iss.i_category_id category_id + from + store_sales, item iss, date_dim d1 + where + ss_item_sk = iss.i_item_sk + and ss_sold_date_sk = d1.d_date_sk + and d1.d_year between 1998 AND 1998 + 2 + intersect + select + ics.i_brand_id, + ics.i_class_id, + ics.i_category_id + from + catalog_sales, item ics, date_dim d2 + where + cs_item_sk = ics.i_item_sk + and cs_sold_date_sk = d2.d_date_sk + and d2.d_year between 1998 AND 1998 + 2 + intersect + select + iws.i_brand_id, + iws.i_class_id, + iws.i_category_id + from + web_sales, item iws, date_dim d3 + where + ws_item_sk = iws.i_item_sk + and ws_sold_date_sk = d3.d_date_sk + and d3.d_year between 1998 AND 1998 + 2) x + where + i_brand_id = brand_id + and i_class_id = class_id + and i_category_id = category_id), +avg_sales as ( + select + avg(quantity*list_price) average_sales + from ( + select + ss_quantity quantity, + ss_list_price list_price + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + cs_quantity quantity, + cs_list_price list_price + from + catalog_sales, date_dim + where + cs_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + ws_quantity quantity, + ws_list_price list_price + from + web_sales, date_dim + where + ws_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2) x) +select + * +from ( + select + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_week_seq = ( + select d_week_seq + from date_dim + where d_year = 1998 + 1 + and d_moy = 12 + and d_dom = 16) + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity*ss_list_price) > (select average_sales from avg_sales)) this_year, + ( + select + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_week_seq = ( + select d_week_seq + from date_dim + where d_year = 1998 + and d_moy = 12 + and d_dom = 16) + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity * ss_list_price) > (select average_sales from avg_sales)) last_year +where + this_year.i_brand_id = last_year.i_brand_id + and this_year.i_class_id = last_year.i_class_id + and this_year.i_category_id = last_year.i_category_id +order by + this_year.channel, + this_year.i_brand_id, + this_year.i_class_id, + this_year.i_category_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql new file mode 100644 index 0000000000000..bfa70fe62d8d5 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql @@ -0,0 +1,215 @@ +-- This query is the alternative form of sql/core/src/test/resources/tpcds/q14b.sql +with cross_items as ( + select + i_item_sk ss_item_sk + from item, ( + select + iss.i_brand_id brand_id, + iss.i_class_id class_id, + iss.i_category_id category_id + from + store_sales, item iss, date_dim d1 + where + ss_item_sk = iss.i_item_sk + and ss_sold_date_sk = d1.d_date_sk + and d1.d_year between 1999 AND 1999 + 2 + intersect + select + ics.i_brand_id, + ics.i_class_id, + ics.i_category_id + from + catalog_sales, item ics, date_dim d2 + where + cs_item_sk = ics.i_item_sk + and cs_sold_date_sk = d2.d_date_sk + and d2.d_year between 1999 AND 1999 + 2 + intersect + select + iws.i_brand_id, + iws.i_class_id, + iws.i_category_id + from + web_sales, item iws, date_dim d3 + where + ws_item_sk = iws.i_item_sk + and ws_sold_date_sk = d3.d_date_sk + and d3.d_year between 1999 AND 1999 + 2) x + where + i_brand_id = brand_id + and i_class_id = class_id + and i_category_id = category_id), +avg_sales as ( + select + avg(quantity*list_price) average_sales + from ( + select + ss_quantity quantity, + ss_list_price list_price + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_year between 1999 and 2001 + union all + select + cs_quantity quantity, + cs_list_price list_price + from + catalog_sales, date_dim + where + cs_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + ws_quantity quantity, + ws_list_price list_price + from + web_sales, date_dim + where + ws_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2) x), +results AS ( + select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum(sales) sum_sales, + sum(number_sales) number_sales + from ( + select + 'store' channel, + i_brand_id,i_class_id, + i_category_id, + sum(ss_quantity*ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_year = 1998 + 2 + and d_moy = 11 + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity * ss_list_price) > (select average_sales from avg_sales) + union all + select + 'catalog' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(cs_quantity*cs_list_price) sales, + count(*) number_sales + from + catalog_sales, item, date_dim + where + cs_item_sk in (select ss_item_sk from cross_items) + and cs_item_sk = i_item_sk + and cs_sold_date_sk = d_date_sk + and d_year = 1998+2 + and d_moy = 11 + group by + i_brand_id,i_class_id,i_category_id + having + sum(cs_quantity*cs_list_price) > (select average_sales from avg_sales) + union all + select + 'web' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ws_quantity*ws_list_price) sales, + count(*) number_sales + from + web_sales, item, date_dim + where + ws_item_sk in (select ss_item_sk from cross_items) + and ws_item_sk = i_item_sk + and ws_sold_date_sk = d_date_sk + and d_year = 1998 + 2 + and d_moy = 11 + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ws_quantity*ws_list_price) > (select average_sales from avg_sales)) y + group by + channel, + i_brand_id, + i_class_id, + i_category_id) +select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum_sales, + number_sales +from ( + select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum_sales, + number_sales + from + results + union + select + channel, + i_brand_id, + i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel, + i_brand_id, + i_class_id + union + select + channel, + i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel, + i_brand_id + union + select + channel, + null as i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel + union + select + null as channel, + null as i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results) z +order by + channel, + i_brand_id, + i_class_id, + i_category_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql new file mode 100644 index 0000000000000..2201a302ab352 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql @@ -0,0 +1,133 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_item_id, + ca_country, + ca_state, + ca_county, + cast(cs_quantity as decimal(12,2)) agg1, + cast(cs_list_price as decimal(12,2)) agg2, + cast(cs_coupon_amt as decimal(12,2)) agg3, + cast(cs_sales_price as decimal(12,2)) agg4, + cast(cs_net_profit as decimal(12,2)) agg5, + cast(c_birth_year as decimal(12,2)) agg6, + cast(cd1.cd_dep_count as decimal(12,2)) agg7 + from + catalog_sales, customer_demographics cd1, customer_demographics cd2, customer, + customer_address, date_dim, item + where + cs_sold_date_sk = d_date_sk + and cs_item_sk = i_item_sk + and cs_bill_cdemo_sk = cd1.cd_demo_sk + and cs_bill_customer_sk = c_customer_sk + and cd1.cd_gender = 'M' + and cd1.cd_education_status = 'College' + and c_current_cdemo_sk = cd2.cd_demo_sk + and c_current_addr_sk = ca_address_sk + and c_birth_month in (9,5,12,4,1,10) + and d_year = 2001 + and ca_state in ('ND','WI','AL','NC','OK','MS','TN')) +select + i_item_id, + ca_country, + ca_state, + ca_county, + agg1, + agg2, + agg3, + agg4, + agg5, + agg6, + agg7 +from ( + select + i_item_id, + ca_country, + ca_state, + ca_county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from + results + group by + i_item_id, + ca_country, + ca_state, + ca_county + union all + select + i_item_id, + ca_country, + ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from + results + group by + i_item_id, + ca_country, + ca_state + union all + select + i_item_id, + ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results + group by + i_item_id, + ca_country + union all + select + i_item_id, + NULL as ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results + group by + i_item_id + union all + select + NULL AS i_item_id, + NULL as ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results) foo +order by + ca_country, + ca_state, + ca_county, + i_item_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql new file mode 100755 index 0000000000000..34d46b1394d8f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql @@ -0,0 +1,19 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(cs_ext_sales_price) AS itemrevenue, + sum(cs_ext_sales_price) * 100 / sum(sum(cs_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM catalog_sales, item, date_dim +WHERE cs_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND cs_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) +AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY i_category, i_class, i_item_id, i_item_desc, revenueratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql new file mode 100755 index 0000000000000..e7bea0804f162 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql @@ -0,0 +1,15 @@ +SELECT + i_product_name, + i_brand, + i_class, + i_category, + avg(inv_quantity_on_hand) qoh +FROM inventory, date_dim, item, warehouse +WHERE inv_date_sk = d_date_sk + AND inv_item_sk = i_item_sk + -- q22 in TPCDS v1.4 had a condition below: + -- AND inv_warehouse_sk = w_warehouse_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 +GROUP BY ROLLUP (i_product_name, i_brand, i_class, i_category) +ORDER BY qoh, i_product_name, i_brand, i_class, i_category +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql new file mode 100644 index 0000000000000..c886e6271511b --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql @@ -0,0 +1,94 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_product_name, + i_brand, + i_class, + i_category, + avg(inv_quantity_on_hand) qoh + from + inventory, date_dim, item, warehouse + where + inv_date_sk = d_date_sk + and inv_item_sk = i_item_sk + and inv_warehouse_sk = w_warehouse_sk + and d_month_seq between 1212 and 1212 + 11 + group by + i_product_name, + i_brand, + i_class, + i_category), +results_rollup as ( + select + i_product_name, + i_brand, + i_class, + i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand, + i_class, + i_category + union all + select + i_product_name, + i_brand, + i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand, + i_class + union all + select + i_product_name, + i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand + union all + select + i_product_name, + null i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name + union all + select + null i_product_name, + null i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results) +select + i_product_name, + i_brand, + i_class, + i_category, + qoh +from + results_rollup +order by + qoh, + i_product_name, + i_brand, + i_class, + i_category +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql new file mode 100755 index 0000000000000..92d64bc7eba78 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql @@ -0,0 +1,40 @@ +WITH ssales AS +(SELECT + c_last_name, + c_first_name, + s_store_name, + ca_state, + s_state, + i_color, + i_current_price, + i_manager_id, + i_units, + i_size, + sum(ss_net_paid) netpaid + FROM store_sales, store_returns, store, item, customer, customer_address + WHERE ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk + AND ss_customer_sk = c_customer_sk + AND ss_item_sk = i_item_sk + AND ss_store_sk = s_store_sk + AND c_current_addr_sk = ca_address_sk -- This condition did not exist in TPCDS v1.4 + AND c_birth_country = upper(ca_country) + AND s_zip = ca_zip + AND s_market_id = 8 + GROUP BY c_last_name, c_first_name, s_store_name, ca_state, s_state, i_color, + i_current_price, i_manager_id, i_units, i_size) +SELECT + c_last_name, + c_first_name, + s_store_name, + sum(netpaid) paid +FROM ssales +WHERE i_color = 'pale' +GROUP BY c_last_name, c_first_name, s_store_name +HAVING sum(netpaid) > (SELECT 0.05 * avg(netpaid) +FROM ssales) +-- no order-by exists in q24a of TPCDS v1.4 +ORDER BY + c_last_name, + c_first_name, + s_store_name diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql new file mode 100644 index 0000000000000..c70a2420e8387 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_item_id, + s_state, 0 as g_state, + ss_quantity agg1, + ss_list_price agg2, + ss_coupon_amt agg3, + ss_sales_price agg4 + from + store_sales, customer_demographics, date_dim, store, item + where + ss_sold_date_sk = d_date_sk + and ss_item_sk = i_item_sk + and ss_store_sk = s_store_sk + and ss_cdemo_sk = cd_demo_sk + and cd_gender = 'F' + and cd_marital_status = 'W' + and cd_education_status = 'Primary' + and d_year = 1998 + and s_state in ('TN','TN', 'TN', 'TN', 'TN', 'TN')) +select + i_item_id, + s_state, + g_state, + agg1, + agg2, + agg3, + agg4 +from ( + select + i_item_id, + s_state, + 0 as g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from + results + group by + i_item_id, + s_state + union all + select + i_item_id, + NULL AS s_state, + 1 AS g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from results + group by + i_item_id + union all + select + NULL AS i_item_id, + NULL as s_state, + 1 as g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from + results) foo +order by + i_item_id, + s_state +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql new file mode 100755 index 0000000000000..bbede62acc9a7 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql @@ -0,0 +1,37 @@ +SELECT + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +FROM + (SELECT + ss_ticket_number, + ss_customer_sk, + count(*) cnt + FROM store_sales, date_dim, store, household_demographics + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND (date_dim.d_dom BETWEEN 1 AND 3 OR date_dim.d_dom BETWEEN 25 AND 28) + AND (household_demographics.hd_buy_potential = '>10000' OR + household_demographics.hd_buy_potential = 'unknown') + AND household_demographics.hd_vehicle_count > 0 + AND (CASE WHEN household_demographics.hd_vehicle_count > 0 + THEN household_demographics.hd_dep_count / household_demographics.hd_vehicle_count + ELSE NULL + END) > 1.2 + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_county IN + ('Williamson County', 'Williamson County', 'Williamson County', 'Williamson County', + 'Williamson County', 'Williamson County', 'Williamson County', 'Williamson County') + GROUP BY ss_ticket_number, ss_customer_sk) dn, customer +WHERE ss_customer_sk = c_customer_sk + AND cnt BETWEEN 15 AND 20 +ORDER BY + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag DESC, + ss_ticket_number -- This order-by condition did not exist in TPCDS v1.4 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql new file mode 100755 index 0000000000000..27116a563d5c6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql @@ -0,0 +1,65 @@ +SELECT + -- select list of q35 in TPCDS v1.4 is below: + -- ca_state, + -- cd_gender, + -- cd_marital_status, + -- count(*) cnt1, + -- min(cd_dep_count), + -- max(cd_dep_count), + -- avg(cd_dep_count), + -- cd_dep_employed_count, + -- count(*) cnt2, + -- min(cd_dep_employed_count), + -- max(cd_dep_employed_count), + -- avg(cd_dep_employed_count), + -- cd_dep_college_count, + -- count(*) cnt3, + -- min(cd_dep_college_count), + -- max(cd_dep_college_count), + -- avg(cd_dep_college_count) + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + count(*) cnt1, + avg(cd_dep_count), + max(cd_dep_count), + sum(cd_dep_count), + cd_dep_employed_count, + count(*) cnt2, + avg(cd_dep_employed_count), + max(cd_dep_employed_count), + sum(cd_dep_employed_count), + cd_dep_college_count, + count(*) cnt3, + avg(cd_dep_college_count), + max(cd_dep_college_count), + sum(cd_dep_college_count) +FROM + customer c, customer_address ca, customer_demographics +WHERE + c.c_current_addr_sk = ca.ca_address_sk AND + cd_demo_sk = c.c_current_cdemo_sk AND + exists(SELECT * + FROM store_sales, date_dim + WHERE c.c_customer_sk = ss_customer_sk AND + ss_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4) AND + (exists(SELECT * + FROM web_sales, date_dim + WHERE c.c_customer_sk = ws_bill_customer_sk AND + ws_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4) OR + exists(SELECT * + FROM catalog_sales, date_dim + WHERE c.c_customer_sk = cs_ship_customer_sk AND + cs_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4)) +GROUP BY ca_state, cd_gender, cd_marital_status, cd_dep_count, + cd_dep_employed_count, cd_dep_college_count +ORDER BY ca_state, cd_gender, cd_marital_status, cd_dep_count, + cd_dep_employed_count, cd_dep_college_count +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql new file mode 100644 index 0000000000000..1c1463e44777f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql @@ -0,0 +1,62 @@ +-- This is a new query in TPCDS v2.7 +select + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + count(*) cnt1, + avg(cd_dep_count), + max(cd_dep_count), + sum(cd_dep_count), + cd_dep_employed_count, + count(*) cnt2, + avg(cd_dep_employed_count), + max(cd_dep_employed_count), + sum(cd_dep_employed_count), + cd_dep_college_count, + count(*) cnt3, + avg(cd_dep_college_count), + max(cd_dep_college_count), + sum(cd_dep_college_count) +from + customer c, customer_address ca, customer_demographics +where + c.c_current_addr_sk = ca.ca_address_sk + and cd_demo_sk = c.c_current_cdemo_sk + and exists ( + select * + from store_sales, date_dim + where c.c_customer_sk = ss_customer_sk + and ss_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4) + and exists ( + select * + from ( + select ws_bill_customer_sk customsk + from web_sales, date_dim + where ws_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4 + union all + select cs_ship_customer_sk customsk + from catalog_sales, date_dim + where cs_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4) x + where x.customsk = c.c_customer_sk) +group by + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +order by + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql new file mode 100644 index 0000000000000..9d98f32add508 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ss_net_profit) as ss_net_profit, + sum(ss_ext_sales_price) as ss_ext_sales_price, + sum(ss_net_profit)/sum(ss_ext_sales_price) as gross_margin, + i_category, + i_class, + 0 as g_category, + 0 as g_class + from + store_sales, + date_dim d1, + item, + store + where + d1.d_year = 2001 + and d1.d_date_sk = ss_sold_date_sk + and i_item_sk = ss_item_sk + and s_store_sk = ss_store_sk + and s_state in ('TN', 'TN', 'TN', 'TN', 'TN', 'TN', 'TN', 'TN') + group by + i_category, + i_class), + results_rollup as ( + select + gross_margin, + i_category, + i_class, + 0 as t_category, + 0 as t_class, + 0 as lochierarchy + from + results + union + select + sum(ss_net_profit) / sum(ss_ext_sales_price) as gross_margin, + i_category, NULL AS i_class, + 0 as t_category, + 1 as t_class, + 1 as lochierarchy + from + results + group by + i_category + union + select + sum(ss_net_profit) / sum(ss_ext_sales_price) as gross_margin, + NULL AS i_category, + NULL AS i_class, + 1 as t_category, + 1 as t_class, + 2 as lochierarchy + from + results) +select + gross_margin, + i_category, + i_class, + lochierarchy, + rank() over ( + partition by lochierarchy, case when t_class = 0 then i_category end + order by gross_margin asc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then i_category end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql new file mode 100755 index 0000000000000..9f7ee457ea45f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql @@ -0,0 +1,64 @@ +WITH v1 AS ( + SELECT + i_category, + i_brand, + s_store_name, + s_company_name, + d_year, + d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) + OVER + (PARTITION BY i_category, i_brand, + s_store_name, s_company_name, d_year) + avg_monthly_sales, + rank() + OVER + (PARTITION BY i_category, i_brand, + s_store_name, s_company_name + ORDER BY d_year, d_moy) rn + FROM item, store_sales, date_dim, store + WHERE ss_item_sk = i_item_sk AND + ss_sold_date_sk = d_date_sk AND + ss_store_sk = s_store_sk AND + ( + d_year = 1999 OR + (d_year = 1999 - 1 AND d_moy = 12) OR + (d_year = 1999 + 1 AND d_moy = 1) + ) + GROUP BY i_category, i_brand, + s_store_name, s_company_name, + d_year, d_moy), + v2 AS ( + SELECT + v1.i_category, + -- q47 in TPCDS v1.4 had more columns below: + -- v1.i_brand, + -- v1.s_store_name, + -- v1.s_company_name, + v1.d_year, + v1.d_moy, + v1.avg_monthly_sales, + v1.sum_sales, + v1_lag.sum_sales psum, + v1_lead.sum_sales nsum + FROM v1, v1 v1_lag, v1 v1_lead + WHERE v1.i_category = v1_lag.i_category AND + v1.i_category = v1_lead.i_category AND + v1.i_brand = v1_lag.i_brand AND + v1.i_brand = v1_lead.i_brand AND + v1.s_store_name = v1_lag.s_store_name AND + v1.s_store_name = v1_lead.s_store_name AND + v1.s_company_name = v1_lag.s_company_name AND + v1.s_company_name = v1_lead.s_company_name AND + v1.rn = v1_lag.rn + 1 AND + v1.rn = v1_lead.rn - 1) +SELECT * +FROM v2 +WHERE d_year = 1999 AND + avg_monthly_sales > 0 AND + CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql new file mode 100755 index 0000000000000..e8061bde4159e --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql @@ -0,0 +1,133 @@ +-- The first SELECT query below is different from q49 of TPCDS v1.4 +SELECT + channel, + item, + return_ratio, + return_rank, + currency_rank +FROM ( + SELECT + 'web' as channel, + in_web.item, + in_web.return_ratio, + in_web.return_rank, + in_web.currency_rank + FROM + (SELECT + item, + return_ratio, + currency_ratio, + rank() over (ORDER BY return_ratio) AS return_rank, + rank() over (ORDER BY currency_ratio) AS currency_rank + FROM ( + SELECT + ws.ws_item_sk AS item, + CAST(SUM(COALESCE(wr.wr_return_quantity, 0)) AS DECIMAL(15, 4)) / + CAST(SUM(COALESCE(ws.ws_quantity, 0)) AS DECIMAL(15, 4)) AS return_ratio, + CAST(SUM(COALESCE(wr.wr_return_amt, 0)) AS DECIMAL(15, 4)) / + CAST(SUM(COALESCE(ws.ws_net_paid, 0)) AS DECIMAL(15, 4)) AS currency_ratio + FROM + web_sales ws LEFT OUTER JOIN web_returns wr + ON (ws.ws_order_number = wr.wr_order_number AND ws.ws_item_sk = wr.wr_item_sk), + date_dim + WHERE + wr.wr_return_amt > 10000 + AND ws.ws_net_profit > 1 + AND ws.ws_net_paid > 0 + AND ws.ws_quantity > 0 + AND ws_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY + ws.ws_item_sk) + ) in_web + ) web +WHERE (web.return_rank <= 10 OR web.currency_rank <= 10) +UNION +SELECT + 'catalog' AS channel, + catalog.item, + catalog.return_ratio, + catalog.return_rank, + catalog.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + cs.cs_item_sk AS item, + (cast(sum(coalesce(cr.cr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(cs.cs_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(cr.cr_return_amount, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(cs.cs_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + catalog_sales cs LEFT OUTER JOIN catalog_returns cr + ON (cs.cs_order_number = cr.cr_order_number AND + cs.cs_item_sk = cr.cr_item_sk) + , date_dim + WHERE + cr.cr_return_amount > 10000 + AND cs.cs_net_profit > 1 + AND cs.cs_net_paid > 0 + AND cs.cs_quantity > 0 + AND cs_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY cs.cs_item_sk + ) in_cat + ) catalog +WHERE (catalog.return_rank <= 10 OR catalog.currency_rank <= 10) +UNION +SELECT + 'store' AS channel, + store.item, + store.return_ratio, + store.return_rank, + store.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + sts.ss_item_sk AS item, + (cast(sum(coalesce(sr.sr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(sts.ss_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(sr.sr_return_amt, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(sts.ss_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + store_sales sts LEFT OUTER JOIN store_returns sr + ON (sts.ss_ticket_number = sr.sr_ticket_number AND sts.ss_item_sk = sr.sr_item_sk) + , date_dim + WHERE + sr.sr_return_amt > 10000 + AND sts.ss_net_profit > 1 + AND sts.ss_net_paid > 0 + AND sts.ss_quantity > 0 + AND ss_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY sts.ss_item_sk + ) in_store + ) store +WHERE (store.return_rank <= 10 OR store.currency_rank <= 10) +ORDER BY + -- order-by list of q49 in TPCDS v1.4 is below: + -- 1, 4, 5 + 1, 4, 5, 2 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql new file mode 100644 index 0000000000000..b8cbbbc8ef7d5 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql @@ -0,0 +1,103 @@ +-- This is a new query in TPCDS v2.7 +WITH web_tv as ( + select + ws_item_sk item_sk, + d_date, + sum(ws_sales_price) sumws, + row_number() over (partition by ws_item_sk order by d_date) rk + from + web_sales, date_dim + where + ws_sold_date_sk=d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ws_item_sk is not NULL + group by + ws_item_sk, d_date), +web_v1 as ( + select + v1.item_sk, + v1.d_date, + v1.sumws, + sum(v2.sumws) cume_sales + from + web_tv v1, web_tv v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.sumws), +store_tv as ( + select + ss_item_sk item_sk, + d_date, + sum(ss_sales_price) sumss, + row_number() over (partition by ss_item_sk order by d_date) rk + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ss_item_sk is not NULL + group by ss_item_sk, d_date), +store_v1 as ( + select + v1.item_sk, + v1.d_date, + v1.sumss, + sum(v2.sumss) cume_sales + from + store_tv v1, store_tv v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.sumss), +v as ( + select + item_sk, + d_date, + web_sales, + store_sales, + row_number() over (partition by item_sk order by d_date) rk + from ( + select + case when web.item_sk is not null + then web.item_sk + else store.item_sk end item_sk, + case when web.d_date is not null + then web.d_date + else store.d_date end d_date, + web.cume_sales web_sales, + store.cume_sales store_sales + from + web_v1 web full outer join store_v1 store + on (web.item_sk = store.item_sk and web.d_date = store.d_date))) +select * +from ( + select + v1.item_sk, + v1.d_date, + v1.web_sales, + v1.store_sales, + max(v2.web_sales) web_cumulative, + max(v2.store_sales) store_cumulative + from + v v1, v v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.web_sales, + v1.store_sales) x +where + web_cumulative > store_cumulative +order by + item_sk, + d_date +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql new file mode 100755 index 0000000000000..ccefaac3c12ca --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql @@ -0,0 +1,57 @@ +WITH v1 AS ( + SELECT + i_category, + i_brand, + cc_name, + d_year, + d_moy, + sum(cs_sales_price) sum_sales, + avg(sum(cs_sales_price)) + OVER + (PARTITION BY i_category, i_brand, cc_name, d_year) + avg_monthly_sales, + rank() + OVER + (PARTITION BY i_category, i_brand, cc_name + ORDER BY d_year, d_moy) rn + FROM item, catalog_sales, date_dim, call_center + WHERE cs_item_sk = i_item_sk AND + cs_sold_date_sk = d_date_sk AND + cc_call_center_sk = cs_call_center_sk AND + ( + d_year = 1999 OR + (d_year = 1999 - 1 AND d_moy = 12) OR + (d_year = 1999 + 1 AND d_moy = 1) + ) + GROUP BY i_category, i_brand, + cc_name, d_year, d_moy), + v2 AS ( + SELECT + v1.i_category, + v1.i_brand, + -- q57 in TPCDS v1.4 had a column below: + -- v1.cc_name, + v1.d_year, + v1.d_moy, + v1.avg_monthly_sales, + v1.sum_sales, + v1_lag.sum_sales psum, + v1_lead.sum_sales nsum + FROM v1, v1 v1_lag, v1 v1_lead + WHERE v1.i_category = v1_lag.i_category AND + v1.i_category = v1_lead.i_category AND + v1.i_brand = v1_lag.i_brand AND + v1.i_brand = v1_lead.i_brand AND + v1.cc_name = v1_lag.cc_name AND + v1.cc_name = v1_lead.cc_name AND + v1.rn = v1_lag.rn + 1 AND + v1.rn = v1_lead.rn - 1) +SELECT * +FROM v2 +WHERE d_year = 1999 AND + avg_monthly_sales > 0 AND + CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql new file mode 100644 index 0000000000000..42bcf59c2aeb1 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql @@ -0,0 +1,158 @@ +-- This is a new query in TPCDS v2.7 +with ssr as( + select + s_store_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + ss_store_sk as store_sk, + ss_sold_date_sk as date_sk, + ss_ext_sales_price as sales_price, + ss_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from + store_sales + union all + select + sr_store_sk as store_sk, + sr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + sr_return_amt as return_amt, + sr_net_loss as net_loss + from + store_returns) salesreturns, + date_dim, + store + where + date_sk = d_date_sk and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and store_sk = s_store_sk + group by + s_store_id), +csr as ( + select + cp_catalog_page_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + cs_catalog_page_sk as page_sk, + cs_sold_date_sk as date_sk, + cs_ext_sales_price as sales_price, + cs_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from catalog_sales + union all + select + cr_catalog_page_sk as page_sk, + cr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + cr_return_amount as return_amt, + cr_net_loss as net_loss + from catalog_returns) salesreturns, + date_dim, + catalog_page + where + date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and page_sk = cp_catalog_page_sk + group by + cp_catalog_page_id), +wsr as ( + select + web_site_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + ws_web_site_sk as wsr_web_site_sk, + ws_sold_date_sk as date_sk, + ws_ext_sales_price as sales_price, + ws_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from + web_sales + union all + select + ws_web_site_sk as wsr_web_site_sk, + wr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + wr_return_amt as return_amt, + wr_net_loss as net_loss + from + web_returns + left outer join web_sales on ( + wr_item_sk = ws_item_sk and wr_order_number = ws_order_number) + ) salesreturns, + date_dim, + web_site + where + date_sk = d_date_sk and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and wsr_web_site_sk = web_site_sk + group by + web_site_id), +results as ( + select + channel, + id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from ( + select + 'store channel' as channel, + 'store' || s_store_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + ssr + union all + select + 'catalog channel' as channel, + 'catalog_page' || cp_catalog_page_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + csr + union all + select + 'web channel' as channel, + 'web_site' || web_site_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + wsr) x + group by + channel, id) +select + channel, id, sales, returns, profit +from ( + select channel, id, sales, returns, profit + from results + union + select channel, null as id, sum(sales), sum(returns), sum(profit) + from results + group by channel + union + select null as channel, null as id, sum(sales), sum(returns), sum(profit) + from results) foo + order by channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql new file mode 100755 index 0000000000000..c0bfa40ad44a8 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql @@ -0,0 +1,23 @@ +SELECT + a.ca_state state, + count(*) cnt +FROM + customer_address a, customer c, store_sales s, date_dim d, item i +WHERE a.ca_address_sk = c.c_current_addr_sk + AND c.c_customer_sk = s.ss_customer_sk + AND s.ss_sold_date_sk = d.d_date_sk + AND s.ss_item_sk = i.i_item_sk + AND d.d_month_seq = + (SELECT DISTINCT (d_month_seq) + FROM date_dim + WHERE d_year = 2000 AND d_moy = 1) + AND i.i_current_price > 1.2 * + (SELECT avg(j.i_current_price) + FROM item j + WHERE j.i_category = i.i_category) +GROUP BY a.ca_state +HAVING count(*) >= 10 +-- order-by list of q6 in TPCDS v1.4 is below: +-- order by cnt +order by cnt, a.ca_state +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql new file mode 100755 index 0000000000000..cdcd8486b363d --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql @@ -0,0 +1,111 @@ +WITH cs_ui AS +(SELECT + cs_item_sk, + sum(cs_ext_list_price) AS sale, + sum(cr_refunded_cash + cr_reversed_charge + cr_store_credit) AS refund + FROM catalog_sales + , catalog_returns + WHERE cs_item_sk = cr_item_sk + AND cs_order_number = cr_order_number + GROUP BY cs_item_sk + HAVING sum(cs_ext_list_price) > 2 * sum(cr_refunded_cash + cr_reversed_charge + cr_store_credit)), + cross_sales AS + (SELECT + i_product_name product_name, + i_item_sk item_sk, + s_store_name store_name, + s_zip store_zip, + ad1.ca_street_number b_street_number, + ad1.ca_street_name b_streen_name, + ad1.ca_city b_city, + ad1.ca_zip b_zip, + ad2.ca_street_number c_street_number, + ad2.ca_street_name c_street_name, + ad2.ca_city c_city, + ad2.ca_zip c_zip, + d1.d_year AS syear, + d2.d_year AS fsyear, + d3.d_year s2year, + count(*) cnt, + sum(ss_wholesale_cost) s1, + sum(ss_list_price) s2, + sum(ss_coupon_amt) s3 + FROM store_sales, store_returns, cs_ui, date_dim d1, date_dim d2, date_dim d3, + store, customer, customer_demographics cd1, customer_demographics cd2, + promotion, household_demographics hd1, household_demographics hd2, + customer_address ad1, customer_address ad2, income_band ib1, income_band ib2, item + WHERE ss_store_sk = s_store_sk AND + ss_sold_date_sk = d1.d_date_sk AND + ss_customer_sk = c_customer_sk AND + ss_cdemo_sk = cd1.cd_demo_sk AND + ss_hdemo_sk = hd1.hd_demo_sk AND + ss_addr_sk = ad1.ca_address_sk AND + ss_item_sk = i_item_sk AND + ss_item_sk = sr_item_sk AND + ss_ticket_number = sr_ticket_number AND + ss_item_sk = cs_ui.cs_item_sk AND + c_current_cdemo_sk = cd2.cd_demo_sk AND + c_current_hdemo_sk = hd2.hd_demo_sk AND + c_current_addr_sk = ad2.ca_address_sk AND + c_first_sales_date_sk = d2.d_date_sk AND + c_first_shipto_date_sk = d3.d_date_sk AND + ss_promo_sk = p_promo_sk AND + hd1.hd_income_band_sk = ib1.ib_income_band_sk AND + hd2.hd_income_band_sk = ib2.ib_income_band_sk AND + cd1.cd_marital_status <> cd2.cd_marital_status AND + i_color IN ('purple', 'burlywood', 'indian', 'spring', 'floral', 'medium') AND + i_current_price BETWEEN 64 AND 64 + 10 AND + i_current_price BETWEEN 64 + 1 AND 64 + 15 + GROUP BY + i_product_name, + i_item_sk, + s_store_name, + s_zip, + ad1.ca_street_number, + ad1.ca_street_name, + ad1.ca_city, + ad1.ca_zip, + ad2.ca_street_number, + ad2.ca_street_name, + ad2.ca_city, + ad2.ca_zip, + d1.d_year, + d2.d_year, + d3.d_year + ) +SELECT + cs1.product_name, + cs1.store_name, + cs1.store_zip, + cs1.b_street_number, + cs1.b_streen_name, + cs1.b_city, + cs1.b_zip, + cs1.c_street_number, + cs1.c_street_name, + cs1.c_city, + cs1.c_zip, + cs1.syear, + cs1.cnt, + cs1.s1, + cs1.s2, + cs1.s3, + cs2.s1, + cs2.s2, + cs2.s3, + cs2.syear, + cs2.cnt +FROM cross_sales cs1, cross_sales cs2 +WHERE cs1.item_sk = cs2.item_sk AND + cs1.syear = 1999 AND + cs2.syear = 1999 + 1 AND + cs2.cnt <= cs1.cnt AND + cs1.store_name = cs2.store_name AND + cs1.store_zip = cs2.store_zip +ORDER BY + cs1.product_name, + cs1.store_name, + cs2.cnt, + -- The two columns below are newly added in TPCDS v2.7 + cs1.s1, + cs2.s1 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql new file mode 100644 index 0000000000000..70a14043bbb3d --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql @@ -0,0 +1,208 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sum(coalesce(ss_sales_price * ss_quantity, 0)) sumsales + from + store_sales, date_dim, store, item + where + ss_sold_date_sk=d_date_sk + and ss_item_sk=i_item_sk + and ss_store_sk = s_store_sk + and d_month_seq between 1212 and 1212 + 11 + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id), +results_rollup as ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales + from + results + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year + union all + select + i_category, + i_class, + i_brand, + i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name + union all + select + i_category, + i_class, + i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand + union all + select + i_category, + i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class + union all + select + i_category, + null i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from results + group by + i_category + union all + select + null i_category, + null i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results) +select + * +from ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales, + rank() over (partition by i_category order by sumsales desc) rk + from results_rollup) dw2 +where + rk <= 100 +order by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales, + rk +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql new file mode 100644 index 0000000000000..4aec9c7fd1fd6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ss_net_profit) as total_sum, + s_state ,s_county, + 0 as gstate, + 0 as g_county + from + store_sales, date_dim d1, store + where + d1.d_month_seq between 1212 and 1212 + 11 + and d1.d_date_sk = ss_sold_date_sk + and s_store_sk = ss_store_sk + and s_state in ( + select s_state + from ( + select + s_state as s_state, + rank() over (partition by s_state order by sum(ss_net_profit) desc) as ranking + from store_sales, store, date_dim + where d_month_seq between 1212 and 1212 + 11 + and d_date_sk = ss_sold_date_sk + and s_store_sk = ss_store_sk + group by s_state) tmp1 + where ranking <= 5) + group by + s_state, s_county), +results_rollup as ( + select + total_sum, + s_state, + s_county, + 0 as g_state, + 0 as g_county, + 0 as lochierarchy + from results + union + select + sum(total_sum) as total_sum,s_state, + NULL as s_county, + 0 as g_state, + 1 as g_county, + 1 as lochierarchy + from results + group by s_state + union + select + sum(total_sum) as total_sum, + NULL as s_state, + NULL as s_county, + 1 as g_state, + 1 as g_county, + 2 as lochierarchy + from results) +select + total_sum, + s_state, + s_county, + lochierarchy, + rank() over ( + partition by lochierarchy, + case when g_county = 0 then s_state end + order by total_sum desc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then s_state end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql new file mode 100755 index 0000000000000..066d6a587e917 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql @@ -0,0 +1,40 @@ +SELECT + i_item_desc, + w_warehouse_name, + d1.d_week_seq, + count(CASE WHEN p_promo_sk IS NULL + THEN 1 + ELSE 0 END) no_promo, + count(CASE WHEN p_promo_sk IS NOT NULL + THEN 1 + ELSE 0 END) promo, + count(*) total_cnt +FROM catalog_sales + JOIN inventory ON (cs_item_sk = inv_item_sk) + JOIN warehouse ON (w_warehouse_sk = inv_warehouse_sk) + JOIN item ON (i_item_sk = cs_item_sk) + JOIN customer_demographics ON (cs_bill_cdemo_sk = cd_demo_sk) + JOIN household_demographics ON (cs_bill_hdemo_sk = hd_demo_sk) + JOIN date_dim d1 ON (cs_sold_date_sk = d1.d_date_sk) + JOIN date_dim d2 ON (inv_date_sk = d2.d_date_sk) + JOIN date_dim d3 ON (cs_ship_date_sk = d3.d_date_sk) + LEFT OUTER JOIN promotion ON (cs_promo_sk = p_promo_sk) + LEFT OUTER JOIN catalog_returns ON (cr_item_sk = cs_item_sk AND cr_order_number = cs_order_number) +-- q72 in TPCDS v1.4 had conditions below: +-- WHERE d1.d_week_seq = d2.d_week_seq +-- AND inv_quantity_on_hand < cs_quantity +-- AND d3.d_date > (cast(d1.d_date AS DATE) + interval 5 days) +-- AND hd_buy_potential = '>10000' +-- AND d1.d_year = 1999 +-- AND hd_buy_potential = '>10000' +-- AND cd_marital_status = 'D' +-- AND d1.d_year = 1999 +WHERE d1.d_week_seq = d2.d_week_seq + AND inv_quantity_on_hand < cs_quantity + AND d3.d_date > d1.d_date + INTERVAL 5 days + AND hd_buy_potential = '1001-5000' + AND d1.d_year = 2001 + AND cd_marital_status = 'M' +GROUP BY i_item_desc, w_warehouse_name, d1.d_week_seq +ORDER BY total_cnt DESC, i_item_desc, w_warehouse_name, d_week_seq +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql new file mode 100755 index 0000000000000..94a0063b36c0c --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql @@ -0,0 +1,60 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + d_year AS year, + sum(ss_net_paid) year_total, + 's' sale_type + FROM + customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk + AND ss_sold_date_sk = d_date_sk + AND d_year IN (2001, 2001 + 1) + GROUP BY + c_customer_id, c_first_name, c_last_name, d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + d_year AS year, + sum(ws_net_paid) year_total, + 'w' sale_type + FROM + customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk + AND ws_sold_date_sk = d_date_sk + AND d_year IN (2001, 2001 + 1) + GROUP BY + c_customer_id, c_first_name, c_last_name, d_year) +SELECT + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name +FROM + year_total t_s_firstyear, year_total t_s_secyear, + year_total t_w_firstyear, year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.year = 2001 + AND t_s_secyear.year = 2001 + 1 + AND t_w_firstyear.year = 2001 + AND t_w_secyear.year = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + ELSE NULL END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + ELSE NULL END +-- order-by list of q74 in TPCDS v1.4 is below: +-- ORDER BY 1, 1, 1 +ORDER BY 2, 1, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql new file mode 100755 index 0000000000000..ae5dc97ef2317 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql @@ -0,0 +1,78 @@ +WITH all_sales AS ( + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + SUM(sales_cnt) AS sales_cnt, + SUM(sales_amt) AS sales_amt + FROM ( + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + cs_quantity - COALESCE(cr_return_quantity, 0) AS sales_cnt, + cs_ext_sales_price - COALESCE(cr_return_amount, 0.0) AS sales_amt + FROM catalog_sales + JOIN item ON i_item_sk = cs_item_sk + JOIN date_dim ON d_date_sk = cs_sold_date_sk + LEFT JOIN catalog_returns ON (cs_order_number = cr_order_number + AND cs_item_sk = cr_item_sk) + WHERE i_category = 'Books' + UNION + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + ss_quantity - COALESCE(sr_return_quantity, 0) AS sales_cnt, + ss_ext_sales_price - COALESCE(sr_return_amt, 0.0) AS sales_amt + FROM store_sales + JOIN item ON i_item_sk = ss_item_sk + JOIN date_dim ON d_date_sk = ss_sold_date_sk + LEFT JOIN store_returns ON (ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk) + WHERE i_category = 'Books' + UNION + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + ws_quantity - COALESCE(wr_return_quantity, 0) AS sales_cnt, + ws_ext_sales_price - COALESCE(wr_return_amt, 0.0) AS sales_amt + FROM web_sales + JOIN item ON i_item_sk = ws_item_sk + JOIN date_dim ON d_date_sk = ws_sold_date_sk + LEFT JOIN web_returns ON (ws_order_number = wr_order_number + AND ws_item_sk = wr_item_sk) + WHERE i_category = 'Books') sales_detail + GROUP BY d_year, i_brand_id, i_class_id, i_category_id, i_manufact_id) +SELECT + prev_yr.d_year AS prev_year, + curr_yr.d_year AS year, + curr_yr.i_brand_id, + curr_yr.i_class_id, + curr_yr.i_category_id, + curr_yr.i_manufact_id, + prev_yr.sales_cnt AS prev_yr_cnt, + curr_yr.sales_cnt AS curr_yr_cnt, + curr_yr.sales_cnt - prev_yr.sales_cnt AS sales_cnt_diff, + curr_yr.sales_amt - prev_yr.sales_amt AS sales_amt_diff +FROM all_sales curr_yr, all_sales prev_yr +WHERE curr_yr.i_brand_id = prev_yr.i_brand_id + AND curr_yr.i_class_id = prev_yr.i_class_id + AND curr_yr.i_category_id = prev_yr.i_category_id + AND curr_yr.i_manufact_id = prev_yr.i_manufact_id + AND curr_yr.d_year = 2002 + AND prev_yr.d_year = 2002 - 1 + AND CAST(curr_yr.sales_cnt AS DECIMAL(17, 2)) / CAST(prev_yr.sales_cnt AS DECIMAL(17, 2)) < 0.9 +ORDER BY + sales_cnt_diff, + sales_amt_diff -- This order-by condition did not exist in TPCDS v1.4 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql new file mode 100644 index 0000000000000..fc69c43470f1e --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql @@ -0,0 +1,121 @@ +-- This is a new query in TPCDS v2.7 +with ss as ( + select + s_store_sk, + sum(ss_ext_sales_price) as sales, + sum(ss_net_profit) as profit + from + store_sales, date_dim, store + where + ss_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ss_store_sk = s_store_sk + group by + s_store_sk), +sr as ( + select + s_store_sk, + sum(sr_return_amt) as returns, + sum(sr_net_loss) as profit_loss + from + store_returns, date_dim, store + where + sr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and sr_store_sk = s_store_sk + group by + s_store_sk), +cs as ( + select + cs_call_center_sk, + sum(cs_ext_sales_price) as sales, + sum(cs_net_profit) as profit + from + catalog_sales, + date_dim + where + cs_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + group by + cs_call_center_sk), + cr as ( + select + sum(cr_return_amount) as returns, + sum(cr_net_loss) as profit_loss + from catalog_returns, + date_dim + where + cr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days)), +ws as ( select wp_web_page_sk, + sum(ws_ext_sales_price) as sales, + sum(ws_net_profit) as profit + from web_sales, + date_dim, + web_page + where ws_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ws_web_page_sk = wp_web_page_sk + group by wp_web_page_sk), + wr as + (select wp_web_page_sk, + sum(wr_return_amt) as returns, + sum(wr_net_loss) as profit_loss + from web_returns, + date_dim, + web_page + where wr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and wr_web_page_sk = wp_web_page_sk + group by wp_web_page_sk) + , + results as + (select channel + , id + , sum(sales) as sales + , sum(returns) as returns + , sum(profit) as profit + from + (select 'store channel' as channel + , ss.s_store_sk as id + , sales + , coalesce(returns, 0) as returns + , (profit - coalesce(profit_loss,0)) as profit + from ss left join sr + on ss.s_store_sk = sr.s_store_sk + union all + select 'catalog channel' as channel + , cs_call_center_sk as id + , sales + , returns + , (profit - profit_loss) as profit + from cs + , cr + union all + select 'web channel' as channel + , ws.wp_web_page_sk as id + , sales + , coalesce(returns, 0) returns + , (profit - coalesce(profit_loss,0)) as profit + from ws left join wr + on ws.wp_web_page_sk = wr.wp_web_page_sk + ) x + group by channel, id ) + + select * + from ( + select channel, id, sales, returns, profit from results + union + select channel, NULL AS id, sum(sales) as sales, sum(returns) as returns, sum(profit) as profit from results group by channel + union + select NULL AS channel, NULL AS id, sum(sales) as sales, sum(returns) as returns, sum(profit) as profit from results +) foo +order by + channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql new file mode 100755 index 0000000000000..d03d8af77174c --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql @@ -0,0 +1,75 @@ +WITH ws AS +(SELECT + d_year AS ws_sold_year, + ws_item_sk, + ws_bill_customer_sk ws_customer_sk, + sum(ws_quantity) ws_qty, + sum(ws_wholesale_cost) ws_wc, + sum(ws_sales_price) ws_sp + FROM web_sales + LEFT JOIN web_returns ON wr_order_number = ws_order_number AND ws_item_sk = wr_item_sk + JOIN date_dim ON ws_sold_date_sk = d_date_sk + WHERE wr_order_number IS NULL + GROUP BY d_year, ws_item_sk, ws_bill_customer_sk +), + cs AS + (SELECT + d_year AS cs_sold_year, + cs_item_sk, + cs_bill_customer_sk cs_customer_sk, + sum(cs_quantity) cs_qty, + sum(cs_wholesale_cost) cs_wc, + sum(cs_sales_price) cs_sp + FROM catalog_sales + LEFT JOIN catalog_returns ON cr_order_number = cs_order_number AND cs_item_sk = cr_item_sk + JOIN date_dim ON cs_sold_date_sk = d_date_sk + WHERE cr_order_number IS NULL + GROUP BY d_year, cs_item_sk, cs_bill_customer_sk + ), + ss AS + (SELECT + d_year AS ss_sold_year, + ss_item_sk, + ss_customer_sk, + sum(ss_quantity) ss_qty, + sum(ss_wholesale_cost) ss_wc, + sum(ss_sales_price) ss_sp + FROM store_sales + LEFT JOIN store_returns ON sr_ticket_number = ss_ticket_number AND ss_item_sk = sr_item_sk + JOIN date_dim ON ss_sold_date_sk = d_date_sk + WHERE sr_ticket_number IS NULL + GROUP BY d_year, ss_item_sk, ss_customer_sk + ) +SELECT + round(ss_qty / (coalesce(ws_qty + cs_qty, 1)), 2) ratio, + ss_qty store_qty, + ss_wc store_wholesale_cost, + ss_sp store_sales_price, + coalesce(ws_qty, 0) + coalesce(cs_qty, 0) other_chan_qty, + coalesce(ws_wc, 0) + coalesce(cs_wc, 0) other_chan_wholesale_cost, + coalesce(ws_sp, 0) + coalesce(cs_sp, 0) other_chan_sales_price +FROM ss + LEFT JOIN ws + ON (ws_sold_year = ss_sold_year AND ws_item_sk = ss_item_sk AND ws_customer_sk = ss_customer_sk) + LEFT JOIN cs + ON (cs_sold_year = ss_sold_year AND cs_item_sk = ss_item_sk AND cs_customer_sk = ss_customer_sk) +WHERE coalesce(ws_qty, 0) > 0 AND coalesce(cs_qty, 0) > 0 AND ss_sold_year = 2000 +ORDER BY + -- order-by list of q78 in TPCDS v1.4 is below: + -- ratio, + -- ss_qty DESC, ss_wc DESC, ss_sp DESC, + -- other_chan_qty, + -- other_chan_wholesale_cost, + -- other_chan_sales_price, + -- round(ss_qty / (coalesce(ws_qty + cs_qty, 1)), 2) + ss_sold_year, + ss_item_sk, + ss_customer_sk, + ss_qty desc, + ss_wc desc, + ss_sp desc, + other_chan_qty, + other_chan_wholesale_cost, + other_chan_sales_price, + ratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql new file mode 100644 index 0000000000000..686e03ba2a6d0 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql @@ -0,0 +1,147 @@ +-- This is a new query in TPCDS v2.7 +with ssr as ( + select + s_store_id as store_id, + sum(ss_ext_sales_price) as sales, + sum(coalesce(sr_return_amt, 0)) as returns, + sum(ss_net_profit - coalesce(sr_net_loss, 0)) as profit + from + store_sales left outer join store_returns on ( + ss_item_sk = sr_item_sk and ss_ticket_number = sr_ticket_number), + date_dim, + store, + item, + promotion + where + ss_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ss_store_sk = s_store_sk + and ss_item_sk = i_item_sk + and i_current_price > 50 + and ss_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + s_store_id), +csr as ( + select + cp_catalog_page_id as catalog_page_id, + sum(cs_ext_sales_price) as sales, + sum(coalesce(cr_return_amount, 0)) as returns, + sum(cs_net_profit - coalesce(cr_net_loss, 0)) as profit + from + catalog_sales left outer join catalog_returns on + (cs_item_sk = cr_item_sk and cs_order_number = cr_order_number), + date_dim, + catalog_page, + item, + promotion + where + cs_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and cs_catalog_page_sk = cp_catalog_page_sk + and cs_item_sk = i_item_sk + and i_current_price > 50 + and cs_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + cp_catalog_page_id), +wsr as ( + select + web_site_id, + sum(ws_ext_sales_price) as sales, + sum(coalesce(wr_return_amt, 0)) as returns, + sum(ws_net_profit - coalesce(wr_net_loss, 0)) as profit + from + web_sales left outer join web_returns on ( + ws_item_sk = wr_item_sk and ws_order_number = wr_order_number), + date_dim, + web_site, + item, + promotion + where + ws_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ws_web_site_sk = web_site_sk + and ws_item_sk = i_item_sk + and i_current_price > 50 + and ws_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + web_site_id), +results as ( + select + channel, + id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from ( + select + 'store channel' as channel, + 'store' || store_id as id, + sales, + returns, + profit + from + ssr + union all + select + 'catalog channel' as channel, + 'catalog_page' || catalog_page_id as id, + sales, + returns, + profit + from + csr + union all + select + 'web channel' as channel, + 'web_site' || web_site_id as id, + sales, + returns, + profit + from + wsr) x + group by + channel, id) +select + channel, + id, + sales, + returns, + profit +from ( + select + channel, + id, + sales, + returns, + profit + from + results + union + select + channel, + NULL AS id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from + results + group by + channel + union + select + NULL AS channel, + NULL AS id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from + results) foo +order by + channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql new file mode 100644 index 0000000000000..fff76b08d4ba0 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql @@ -0,0 +1,61 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ws_net_paid) as total_sum, + i_category, i_class, + 0 as g_category, + 0 as g_class + from + web_sales, date_dim d1, item + where + d1.d_month_seq between 1212 and 1212 + 11 + and d1.d_date_sk = ws_sold_date_sk + and i_item_sk = ws_item_sk + group by + i_category, i_class), +results_rollup as( + select + total_sum, + i_category, + i_class, + g_category, + g_class, + 0 as lochierarchy + from + results + union + select + sum(total_sum) as total_sum, + i_category, + NULL as i_class, + 0 as g_category, + 1 as g_class, + 1 as lochierarchy + from + results + group by + i_category + union + select + sum(total_sum) as total_sum, + NULL as i_category, + NULL as i_class, + 1 as g_category, + 1 as g_class, + 2 as lochierarchy + from + results) +select + total_sum, + i_category ,i_class, lochierarchy, + rank() over ( + partition by lochierarchy, + case when g_class = 0 then i_category end + order by total_sum desc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then i_category end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql new file mode 100755 index 0000000000000..771117add2ed2 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql @@ -0,0 +1,22 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ss_ext_sales_price) AS itemrevenue, + sum(ss_ext_sales_price) * 100 / sum(sum(ss_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM + store_sales, item, date_dim +WHERE + ss_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND ss_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) + AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY + i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY + i_category, i_class, i_item_id, i_item_desc, revenueratio diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index 1a584187a06e5..bc95b4696190d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -62,7 +62,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`c_first_sales_date_sk` INT, `c_salutation` STRING, `c_first_name` STRING, |`c_last_name` STRING, `c_preferred_cust_flag` STRING, `c_birth_day` INT, |`c_birth_month` INT, `c_birth_year` INT, `c_birth_country` STRING, `c_login` STRING, - |`c_email_address` STRING, `c_last_review_date` STRING) + |`c_email_address` STRING, `c_last_review_date` INT) |USING parquet """.stripMargin) @@ -88,7 +88,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ |CREATE TABLE `date_dim` ( - |`d_date_sk` INT, `d_date_id` STRING, `d_date` STRING, + |`d_date_sk` INT, `d_date_id` STRING, `d_date` DATE, |`d_month_seq` INT, `d_week_seq` INT, `d_quarter_seq` INT, `d_year` INT, `d_dow` INT, |`d_moy` INT, `d_dom` INT, `d_qoy` INT, `d_fy_year` INT, `d_fy_quarter_seq` INT, |`d_fy_week_seq` INT, `d_day_name` STRING, `d_quarter_name` STRING, `d_holiday` STRING, @@ -115,8 +115,8 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ - |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` STRING, - |`i_rec_end_date` STRING, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), + |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` DATE, + |`i_rec_end_date` DATE, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), |`i_wholesale_cost` DECIMAL(7,2), `i_brand_id` INT, `i_brand` STRING, `i_class_id` INT, |`i_class` STRING, `i_category_id` INT, `i_category` STRING, `i_manufact_id` INT, |`i_manufact` STRING, `i_size` STRING, `i_formulation` STRING, `i_color` STRING, @@ -139,8 +139,8 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ |CREATE TABLE `store` ( - |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` STRING, - |`s_rec_end_date` STRING, `s_closed_date_sk` INT, `s_store_name` STRING, + |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` DATE, + |`s_rec_end_date` DATE, `s_closed_date_sk` INT, `s_store_name` STRING, |`s_number_employees` INT, `s_floor_space` INT, `s_hours` STRING, `s_manager` STRING, |`s_market_id` INT, `s_geography_class` STRING, `s_market_desc` STRING, |`s_market_manager` STRING, `s_division_id` INT, `s_division_name` STRING, @@ -157,7 +157,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`sr_returned_date_sk` BIGINT, `sr_return_time_sk` BIGINT, `sr_item_sk` BIGINT, |`sr_customer_sk` BIGINT, `sr_cdemo_sk` BIGINT, `sr_hdemo_sk` BIGINT, `sr_addr_sk` BIGINT, |`sr_store_sk` BIGINT, `sr_reason_sk` BIGINT, `sr_ticket_number` BIGINT, - |`sr_return_quantity` BIGINT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), + |`sr_return_quantity` INT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), |`sr_return_amt_inc_tax` DECIMAL(7,2), `sr_fee` DECIMAL(7,2), |`sr_return_ship_cost` DECIMAL(7,2), `sr_refunded_cash` DECIMAL(7,2), |`sr_reversed_charge` DECIMAL(7,2), `sr_store_credit` DECIMAL(7,2), @@ -225,7 +225,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`wr_refunded_hdemo_sk` BIGINT, `wr_refunded_addr_sk` BIGINT, |`wr_returning_customer_sk` BIGINT, `wr_returning_cdemo_sk` BIGINT, |`wr_returning_hdemo_sk` BIGINT, `wr_returning_addr_sk` BIGINT, `wr_web_page_sk` BIGINT, - |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` BIGINT, + |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` INT, |`wr_return_amt` DECIMAL(7,2), `wr_return_tax` DECIMAL(7,2), |`wr_return_amt_inc_tax` DECIMAL(7,2), `wr_fee` DECIMAL(7,2), |`wr_return_ship_cost` DECIMAL(7,2), `wr_refunded_cash` DECIMAL(7,2), @@ -244,7 +244,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`web_company_id` INT, `web_company_name` STRING, `web_street_number` STRING, |`web_street_name` STRING, `web_street_type` STRING, `web_suite_number` STRING, |`web_city` STRING, `web_county` STRING, `web_state` STRING, `web_zip` STRING, - |`web_country` STRING, `web_gmt_offset` STRING, `web_tax_percentage` DECIMAL(5,2)) + |`web_country` STRING, `web_gmt_offset` DECIMAL(5,2), `web_tax_percentage` DECIMAL(5,2)) |USING parquet """.stripMargin) @@ -315,6 +315,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { """.stripMargin) } + // The TPCDS queries below are based on v1.4 val tpcdsQueries = Seq( "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", @@ -339,6 +340,25 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { } } + // This list only includes TPCDS v2.7 queries that are different from v1.4 ones + val tpcdsQueriesV2_7_0 = Seq( + "q5a", "q6", "q10a", "q11", "q12", "q14", "q14a", "q18a", + "q20", "q22", "q22a", "q24", "q27a", "q34", "q35", "q35a", "q36a", "q47", "q49", + "q51a", "q57", "q64", "q67a", "q70a", "q72", "q74", "q75", "q77a", "q78", + "q80a", "q86a", "q98") + + tpcdsQueriesV2_7_0.foreach { name => + val queryString = resourceToString(s"tpcds-v2.7.0/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(s"$name-v2.7") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // check the plans can be properly generated + val plan = sql(queryString).queryExecution.executedPlan + checkGeneratedCode(plan) + } + } + } + // These queries are from https://github.com/cloudera/impala-tpcds-kit/tree/master/queries val modifiedTPCDSQueries = Seq( "q3", "q7", "q10", "q19", "q27", "q34", "q42", "q43", "q46", "q52", "q53", "q55", "q59", From e4bec7cb88b9ee63f8497e3f9e0ab0bfa5d5a77c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 25 Mar 2018 16:38:49 -0700 Subject: [PATCH 521/774] [SPARK-23549][SQL] Cast to timestamp when comparing timestamp with date ## What changes were proposed in this pull request? This PR fixes an incorrect comparison in SQL between timestamp and date. This is because both of them are casted to `string` and then are compared lexicographically. This implementation shows `false` regarding this query `spark.sql("select cast('2017-03-01 00:00:00' as timestamp) between cast('2017-02-28' as date) and cast('2017-03-01' as date)").show`. This PR shows `true` for this query by casting `date("2017-03-01")` to `timestamp("2017-03-01 00:00:00")`. (Please fill in changes proposed in this fix) ## How was this patch tested? Added new UTs to `TypeCoercionSuite`. Author: Kazuaki Ishizaki Closes #20774 from kiszk/SPARK-23549. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/analysis/TypeCoercion.scala | 29 ++++++++----- .../apache/spark/sql/internal/SQLConf.scala | 13 ++++++ .../catalyst/analysis/TypeCoercionSuite.scala | 34 ++++++++++++--- .../sql-tests/inputs/predicate-functions.sql | 7 ++++ .../results/predicate-functions.sql.out | 42 ++++++++++++++++++- 6 files changed, 108 insertions(+), 18 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 421e2eaf62bfb..2b393f30d1435 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1808,6 +1808,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. + - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e8669c4637d06..ec7e7761dc4c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -47,9 +47,9 @@ import org.apache.spark.sql.types._ object TypeCoercion { def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = - InConversion :: + InConversion(conf) :: WidenSetOperationTypes :: - PromoteStrings :: + PromoteStrings(conf) :: DecimalPrecision :: BooleanEquality :: FunctionArgumentConversion :: @@ -127,7 +127,8 @@ object TypeCoercion { * is a String and the other is not. It also handles when one op is a Date and the * other is a Timestamp by making the target type to be String. */ - val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = { + private def findCommonTypeForBinaryComparison( + dt1: DataType, dt2: DataType, conf: SQLConf): Option[DataType] = (dt1, dt2) match { // We should cast all relative timestamp/date/string comparison into string comparisons // This behaves as a user would expect because timestamp strings sort lexicographically. // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true @@ -135,11 +136,17 @@ object TypeCoercion { case (DateType, StringType) => Some(StringType) case (StringType, TimestampType) => Some(StringType) case (TimestampType, StringType) => Some(StringType) - case (TimestampType, DateType) => Some(StringType) - case (DateType, TimestampType) => Some(StringType) case (StringType, NullType) => Some(StringType) case (NullType, StringType) => Some(StringType) + // Cast to TimestampType when we compare DateType with TimestampType + // if conf.compareDateTimestampInTimestamp is true + // i.e. TimeStamp('2017-03-01 00:00:00') eq Date('2017-03-01') = true + case (TimestampType, DateType) + => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType) + case (DateType, TimestampType) + => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType) + // There is no proper decimal type we can pick, // using double type is the best we can do. // See SPARK-22469 for details. @@ -147,7 +154,7 @@ object TypeCoercion { case (s: StringType, n: DecimalType) => Some(DoubleType) case (l: StringType, r: AtomicType) if r != StringType => Some(r) - case (l: AtomicType, r: StringType) if (l != StringType) => Some(l) + case (l: AtomicType, r: StringType) if l != StringType => Some(l) case (l, r) => None } @@ -313,7 +320,7 @@ object TypeCoercion { /** * Promotes strings that appear in arithmetic expressions. */ - object PromoteStrings extends TypeCoercionRule { + case class PromoteStrings(conf: SQLConf) extends TypeCoercionRule { private def castExpr(expr: Expression, targetType: DataType): Expression = { (expr.dataType, targetType) match { case (NullType, dt) => Literal.create(null, targetType) @@ -342,8 +349,8 @@ object TypeCoercion { p.makeCopy(Array(left, Cast(right, TimestampType))) case p @ BinaryComparison(left, right) - if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined => - val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get + if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined => + val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) case Abs(e @ StringType()) => Abs(Cast(e, DoubleType)) @@ -374,7 +381,7 @@ object TypeCoercion { * operator type is found the original expression will be returned and an * Analysis Exception will be raised at the type checking phase. */ - object InConversion extends TypeCoercionRule { + case class InConversion(conf: SQLConf) extends TypeCoercionRule { private def flattenExpr(expr: Expression): Seq[Expression] = { expr match { // Multi columns in IN clause is represented as a CreateNamedStruct. @@ -400,7 +407,7 @@ object TypeCoercion { val rhs = sub.output val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => - findCommonTypeForBinaryComparison(l.dataType, r.dataType) + findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) .orElse(findTightestCommonType(l.dataType, r.dataType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 11864bd1b1847..9cb03b5bb6152 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -479,6 +479,16 @@ object SQLConf { .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString)) .createWithDefault(HiveCaseSensitiveInferenceMode.INFER_AND_SAVE.toString) + val TYPECOERCION_COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP = + buildConf("spark.sql.typeCoercion.compareDateTimestampInTimestamp") + .internal() + .doc("When true (default), compare Date with Timestamp after converting both sides to " + + "Timestamp. This behavior is compatible with Hive 2.2 or later. See HIVE-15236. " + + "When false, restore the behavior prior to Spark 2.4. Compare Date with Timestamp after " + + "converting both sides to string. This config will be removed in spark 3.0") + .booleanConf + .createWithDefault(true) + val OPTIMIZER_METADATA_ONLY = buildConf("spark.sql.optimizer.metadataOnly") .doc("When true, enable the metadata-only query optimization that use the table's metadata " + "to produce the partition columns instead of table scans. It applies when all the columns " + @@ -1332,6 +1342,9 @@ class SQLConf extends Serializable with Logging { def caseSensitiveInferenceMode: HiveCaseSensitiveInferenceMode.Value = HiveCaseSensitiveInferenceMode.withName(getConf(HIVE_CASE_SENSITIVE_INFERENCE)) + def compareDateTimestampInTimestamp : Boolean = + getConf(TYPECOERCION_COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP) + def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT) def optimizerMetadataOnly: Boolean = getConf(OPTIMIZER_METADATA_ONLY) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 52a7ebdafd7c7..8ac49dc05e3cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1207,7 +1207,7 @@ class TypeCoercionSuite extends AnalysisTest { */ test("make sure rules do not fire early") { // InConversion - val inConversion = TypeCoercion.InConversion + val inConversion = TypeCoercion.InConversion(conf) ruleTest(inConversion, In(UnresolvedAttribute("a"), Seq(Literal(1))), In(UnresolvedAttribute("a"), Seq(Literal(1))) @@ -1251,18 +1251,40 @@ class TypeCoercionSuite extends AnalysisTest { } test("binary comparison with string promotion") { - ruleTest(PromoteStrings, + val rule = TypeCoercion.PromoteStrings(conf) + ruleTest(rule, GreaterThan(Literal("123"), Literal(1)), GreaterThan(Cast(Literal("123"), IntegerType), Literal(1))) - ruleTest(PromoteStrings, + ruleTest(rule, LessThan(Literal(true), Literal("123")), LessThan(Literal(true), Cast(Literal("123"), BooleanType))) - ruleTest(PromoteStrings, + ruleTest(rule, EqualTo(Literal(Array(1, 2)), Literal("123")), EqualTo(Literal(Array(1, 2)), Literal("123"))) - ruleTest(PromoteStrings, + ruleTest(rule, GreaterThan(Literal("1.5"), Literal(BigDecimal("0.5"))), - GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")), DoubleType))) + GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")), + DoubleType))) + Seq(true, false).foreach { convertToTS => + withSQLConf( + "spark.sql.typeCoercion.compareDateTimestampInTimestamp" -> convertToTS.toString) { + val date0301 = Literal(java.sql.Date.valueOf("2017-03-01")) + val timestamp0301000000 = Literal(Timestamp.valueOf("2017-03-01 00:00:00")) + val timestamp0301000001 = Literal(Timestamp.valueOf("2017-03-01 00:00:01")) + if (convertToTS) { + // `Date` should be treated as timestamp at 00:00:00 See SPARK-23549 + ruleTest(rule, EqualTo(date0301, timestamp0301000000), + EqualTo(Cast(date0301, TimestampType), timestamp0301000000)) + ruleTest(rule, LessThan(date0301, timestamp0301000001), + LessThan(Cast(date0301, TimestampType), timestamp0301000001)) + } else { + ruleTest(rule, LessThan(date0301, timestamp0301000000), + LessThan(Cast(date0301, StringType), Cast(timestamp0301000000, StringType))) + ruleTest(rule, LessThan(date0301, timestamp0301000001), + LessThan(Cast(date0301, StringType), Cast(timestamp0301000001, StringType))) + } + } + } } test("cast WindowFrame boundaries to the type they operate upon") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql index e99d5cef81f64..fadb4bb27fa13 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql @@ -39,3 +39,10 @@ select 2.0 <= '2.2'; select 0.5 <= '1.5'; select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52'); select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52'; + +-- SPARK-23549: Cast to timestamp when comparing timestamp with date +select to_date('2017-03-01') = to_timestamp('2017-03-01 00:00:00'); +select to_timestamp('2017-03-01 00:00:01') > to_date('2017-03-01'); +select to_timestamp('2017-03-01 00:00:01') >= to_date('2017-03-01'); +select to_date('2017-03-01') < to_timestamp('2017-03-01 00:00:01'); +select to_date('2017-03-01') <= to_timestamp('2017-03-01 00:00:01'); diff --git a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out index d51f6d37e4b41..cf828c69af62a 100644 --- a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 37 -- !query 0 @@ -256,3 +256,43 @@ select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52' struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean> -- !query 31 output true + + +-- !query 32 +select to_date('2017-03-01') = to_timestamp('2017-03-01 00:00:00') +-- !query 32 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) = to_timestamp('2017-03-01 00:00:00')):boolean> +-- !query 32 output +true + + +-- !query 33 +select to_timestamp('2017-03-01 00:00:01') > to_date('2017-03-01') +-- !query 33 schema +struct<(to_timestamp('2017-03-01 00:00:01') > CAST(to_date('2017-03-01') AS TIMESTAMP)):boolean> +-- !query 33 output +true + + +-- !query 34 +select to_timestamp('2017-03-01 00:00:01') >= to_date('2017-03-01') +-- !query 34 schema +struct<(to_timestamp('2017-03-01 00:00:01') >= CAST(to_date('2017-03-01') AS TIMESTAMP)):boolean> +-- !query 34 output +true + + +-- !query 35 +select to_date('2017-03-01') < to_timestamp('2017-03-01 00:00:01') +-- !query 35 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) < to_timestamp('2017-03-01 00:00:01')):boolean> +-- !query 35 output +true + + +-- !query 36 +select to_date('2017-03-01') <= to_timestamp('2017-03-01 00:00:01') +-- !query 36 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) <= to_timestamp('2017-03-01 00:00:01')):boolean> +-- !query 36 output +true From a9350d7095b79c8374fb4a06fd3f1a1a67615f6f Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 26 Mar 2018 12:42:32 +0900 Subject: [PATCH 522/774] [SPARK-23700][PYTHON] Cleanup imports in pyspark.sql ## What changes were proposed in this pull request? This cleans up unused imports, mainly from pyspark.sql module. Added a note in function.py that imports `UserDefinedFunction` only to maintain backwards compatibility for using `from pyspark.sql.function import UserDefinedFunction`. ## How was this patch tested? Existing tests and built docs. Author: Bryan Cutler Closes #20892 from BryanCutler/pyspark-cleanup-imports-SPARK-23700. --- python/pyspark/sql/column.py | 1 - python/pyspark/sql/conf.py | 1 - python/pyspark/sql/functions.py | 3 +-- python/pyspark/sql/group.py | 3 +-- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/streaming.py | 2 -- python/pyspark/sql/types.py | 1 - python/pyspark/sql/udf.py | 6 ++---- python/pyspark/util.py | 2 -- 9 files changed, 5 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index e05a7b33c11a7..922c7cf288f8f 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -16,7 +16,6 @@ # import sys -import warnings import json if sys.version >= '3': diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index b82224b6194ed..db49040e17b63 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -67,7 +67,6 @@ def _checkType(self, obj, identifier): def _test(): import os import doctest - from pyspark.context import SparkContext from pyspark.sql.session import SparkSession import pyspark.sql.conf diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dff590983b4d9..a4edb1e27b599 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -18,7 +18,6 @@ """ A collections of builtin functions """ -import math import sys import functools import warnings @@ -28,10 +27,10 @@ from pyspark import since, SparkContext from pyspark.rdd import ignore_unicode_prefix, PythonEvalType -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import StringType, DataType +# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 from pyspark.sql.udf import UserDefinedFunction, _create_udf diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 35cac406e0965..3505065b648f2 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -19,9 +19,8 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix, PythonEvalType -from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal +from pyspark.sql.column import Column, _to_seq from pyspark.sql.dataframe import DataFrame -from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index e5288636c596e..4f9b9383a5ef4 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -22,7 +22,7 @@ from py4j.java_gateway import JavaClass -from pyspark import RDD, since, keyword_only +from pyspark import RDD, since from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 07f9ac1b5aa9e..c7907aaaf1f7b 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -24,8 +24,6 @@ else: intlike = (int, long) -from abc import ABCMeta, abstractmethod - from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 5d5919e451b46..1f6534836d64a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -35,7 +35,6 @@ from pyspark import SparkContext from pyspark.serializers import CloudPickleSerializer -from pyspark.util import _exception_message __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 24dd06c26089c..9dbe49b831cef 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -17,16 +17,14 @@ """ User-defined function related classes and functions """ -import sys -import inspect import functools import sys from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq -from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ - _parse_datatype_string, to_arrow_type, to_arrow_schema +from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ + to_arrow_type, to_arrow_schema from pyspark.util import _get_argspec __all__ = ["UDFRegistration"] diff --git a/python/pyspark/util.py b/python/pyspark/util.py index ed1bdd0e4be83..49afc13640332 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -22,8 +22,6 @@ __all__ = [] -import sys - def _exception_message(excp): """Return the message from an exception as either a str or unicode object. Supports both From 087fb3142028d679524e22596b0ad4f74ff47e8d Mon Sep 17 00:00:00 2001 From: "Michael (Stu) Stewart" Date: Mon, 26 Mar 2018 12:45:45 +0900 Subject: [PATCH 523/774] [SPARK-23645][MINOR][DOCS][PYTHON] Add docs RE `pandas_udf` with keyword args ## What changes were proposed in this pull request? Add documentation about the limitations of `pandas_udf` with keyword arguments and related concepts, like `functools.partial` fn objects. NOTE: intermediate commits on this PR show some of the steps that can be taken to fix some (but not all) of these pain points. ### Survey of problems we face today: (Initialize) Note: python 3.6 and spark 2.4snapshot. ``` from pyspark.sql import SparkSession import inspect, functools from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, udf spark = SparkSession.builder.getOrCreate() print(spark.version) df = spark.range(1,6).withColumn('b', col('id') * 2) def ok(a,b): return a+b ``` Using a keyword argument at the call site `b=...` (and yes, *full* stack trace below, haha): ``` ---> 14 df.withColumn('ok', pandas_udf(f=ok, returnType='bigint')('id', b='id')).show() # no kwargs TypeError: wrapper() got an unexpected keyword argument 'b' ``` Using partial with a keyword argument where the kw-arg is the first argument of the fn: *(Aside: kind of interesting that lines 15,16 work great and then 17 explodes)* ``` --------------------------------------------------------------------------- ValueError Traceback (most recent call last) in () 15 df.withColumn('ok', pandas_udf(f=functools.partial(ok, 7), returnType='bigint')('id')).show() 16 df.withColumn('ok', pandas_udf(f=functools.partial(ok, b=7), returnType='bigint')('id')).show() ---> 17 df.withColumn('ok', pandas_udf(f=functools.partial(ok, a=7), returnType='bigint')('id')).show() /Users/stu/ZZ/spark/python/pyspark/sql/functions.py in pandas_udf(f, returnType, functionType) 2378 return functools.partial(_create_udf, returnType=return_type, evalType=eval_type) 2379 else: -> 2380 return _create_udf(f=f, returnType=return_type, evalType=eval_type) 2381 2382 /Users/stu/ZZ/spark/python/pyspark/sql/udf.py in _create_udf(f, returnType, evalType) 54 argspec.varargs is None: 55 raise ValueError( ---> 56 "Invalid function: 0-arg pandas_udfs are not supported. " 57 "Instead, create a 1-arg pandas_udf and ignore the arg in your function." 58 ) ValueError: Invalid function: 0-arg pandas_udfs are not supported. Instead, create a 1-arg pandas_udf and ignore the arg in your function. ``` Author: Michael (Stu) Stewart Closes #20900 from mstewart141/udfkw2. --- python/pyspark/sql/functions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a4edb1e27b599..ad3e37c872628 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2154,6 +2154,8 @@ def udf(f=None, returnType=StringType()): in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. + .. note:: The user-defined functions do not take keyword arguments on the calling side. + :param f: python function if used as a standalone function :param returnType: the return type of the user-defined function. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. @@ -2337,6 +2339,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. + + .. note:: The user-defined functions do not take keyword arguments on the calling side. """ # decorator @pandas_udf(returnType, functionType) is_decorator = f is None or isinstance(f, (str, DataType)) From eb48edf9ca4f4b42c63f145718696472cb6a31ba Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 26 Mar 2018 14:01:04 +0800 Subject: [PATCH 524/774] [SPARK-23787][TESTS] Fix file download test in SparkSubmitSuite for Hadoop 2.9. This particular test assumed that Hadoop libraries did not support http as a file system. Hadoop 2.9 does, so the test failed. The test now forces a non-existent implementation for the http fs, which forces the expected error. There were also a couple of other issues in the same test: SparkSubmit arguments in the wrong order, and the wrong check later when asserting, which was being masked by the previous issues. Author: Marcelo Vanzin Closes #20895 from vanzin/SPARK-23787. --- .../spark/deploy/SparkSubmitSuite.scala | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 2d0c192db4915..d86ef907b4492 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -959,25 +959,28 @@ class SparkSubmitSuite } test("download remote resource if it is not supported by yarn service") { - testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = false) + testRemoteResources(enableHttpFs = false, blacklistHttpFs = false) } test("avoid downloading remote resource if it is supported by yarn service") { - testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistHttpFs = false) } test("force download from blacklisted schemes") { - testRemoteResources(isHttpSchemeBlacklisted = true, supportMockHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistHttpFs = true) } - private def testRemoteResources(isHttpSchemeBlacklisted: Boolean, - supportMockHttpFs: Boolean): Unit = { + private def testRemoteResources( + enableHttpFs: Boolean, + blacklistHttpFs: Boolean): Unit = { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) - if (supportMockHttpFs) { + if (enableHttpFs) { hadoopConf.set("fs.http.impl", classOf[TestFileSystem].getCanonicalName) - hadoopConf.set("fs.http.impl.disable.cache", "true") + } else { + hadoopConf.set("fs.http.impl", getClass().getName() + ".DoesNotExist") } + hadoopConf.set("fs.http.impl.disable.cache", "true") val tmpDir = Utils.createTempDir() val mainResource = File.createTempFile("tmpPy", ".py", tmpDir) @@ -986,20 +989,19 @@ class SparkSubmitSuite val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" + val forceDownloadArgs = if (blacklistHttpFs) { + Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http") + } else { + Nil + } + val args = Seq( "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), "--name", "testApp", "--master", "yarn", "--deploy-mode", "client", - "--jars", s"$tmpS3JarPath,$tmpHttpJarPath", - s"s3a://$mainResource" - ) ++ ( - if (isHttpSchemeBlacklisted) { - Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http,https") - } else { - Nil - } - ) + "--jars", s"$tmpS3JarPath,$tmpHttpJarPath" + ) ++ forceDownloadArgs ++ Seq(s"s3a://$mainResource") val appArgs = new SparkSubmitArguments(args) val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) @@ -1009,7 +1011,7 @@ class SparkSubmitSuite // The URI of remote S3 resource should still be remote. assert(jars.contains(tmpS3JarPath)) - if (supportMockHttpFs) { + if (enableHttpFs && !blacklistHttpFs) { // If Http FS is supported by yarn service, the URI of remote http resource should // still be remote. assert(jars.contains(tmpHttpJarPath)) From b30a7d28b399950953d4b112c57d4c9b9ab223e9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 26 Mar 2018 12:45:45 -0700 Subject: [PATCH 525/774] [SPARK-23572][DOCS] Bring "security.md" up to date. This change basically rewrites the security documentation so that it's up to date with new features, more correct, and more complete. Because security is such an important feature, I chose to move all the relevant configuration documentation to the security page, instead of having them peppered all over the place in the configuration page. This allows an almost one-stop shop for security configuration in Spark. The only exceptions are some YARN-specific minor features which I left in the YARN page. I also re-organized the page's topics, since they didn't make a lot of sense. You had kerberos features described inside paragraphs talking about UI access control, and other oddities. It should be easier now to find information about specific Spark security features. I also enabled TOCs for both the Security and YARN pages, since that makes it easier to see what is covered. I removed most of the comments from the SecurityManager javadoc since they just replicated information in the security doc, with different levels of out-of-dateness. Author: Marcelo Vanzin Closes #20742 from vanzin/SPARK-23572. --- .gitignore | 1 + .../org/apache/spark/SecurityManager.scala | 144 +--- docs/configuration.md | 359 +--------- docs/monitoring.md | 40 +- docs/running-on-yarn.md | 203 +++--- docs/security.md | 629 +++++++++++++++--- 6 files changed, 673 insertions(+), 703 deletions(-) diff --git a/.gitignore b/.gitignore index 39085904e324c..e4c44d0590d59 100644 --- a/.gitignore +++ b/.gitignore @@ -76,6 +76,7 @@ streaming-tests.log target/ unit-tests.log work/ +docs/.jekyll-metadata # For Hive TempStatsStore/ diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index da1c89cd78901..09ec8932353a0 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -42,148 +42,10 @@ import org.apache.spark.util.Utils * should access it from that. There are some cases where the SparkEnv hasn't been * initialized yet and this class must be instantiated directly. * - * Spark currently supports authentication via a shared secret. - * Authentication can be configured to be on via the 'spark.authenticate' configuration - * parameter. This parameter controls whether the Spark communication protocols do - * authentication using the shared secret. This authentication is a basic handshake to - * make sure both sides have the same shared secret and are allowed to communicate. - * If the shared secret is not identical they will not be allowed to communicate. - * - * The Spark UI can also be secured by using javax servlet filters. A user may want to - * secure the UI if it has data that other users should not be allowed to see. The javax - * servlet filter specified by the user can authenticate the user and then once the user - * is logged in, Spark can compare that user versus the view acls to make sure they are - * authorized to view the UI. The configs 'spark.acls.enable', 'spark.ui.view.acls' and - * 'spark.ui.view.acls.groups' control the behavior of the acls. Note that the person who - * started the application always has view access to the UI. - * - * Spark has a set of individual and group modify acls (`spark.modify.acls`) and - * (`spark.modify.acls.groups`) that controls which users and groups have permission to - * modify a single application. This would include things like killing the application. - * By default the person who started the application has modify access. For modify access - * through the UI, you must have a filter that does authentication in place for the modify - * acls to work properly. - * - * Spark also has a set of individual and group admin acls (`spark.admin.acls`) and - * (`spark.admin.acls.groups`) which is a set of users/administrators and admin groups - * who always have permission to view or modify the Spark application. - * - * Starting from version 1.3, Spark has partial support for encrypted connections with SSL. - * - * At this point spark has multiple communication protocols that need to be secured and - * different underlying mechanisms are used depending on the protocol: - * - * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty - * for the HttpServer. Jetty supports multiple authentication mechanisms - - * Basic, Digest, Form, Spnego, etc. It also supports multiple different login - * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService - * to authenticate using DIGEST-MD5 via a single user and the shared secret. - * Since we are using DIGEST-MD5, the shared secret is not passed on the wire - * in plaintext. - * - * We currently support SSL (https) for this communication protocol (see the details - * below). - * - * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5. - * Any clients must specify the user and password. There is a default - * Authenticator installed in the SecurityManager to how it does the authentication - * and in this case gets the user name and password from the request. - * - * - BlockTransferService -> The Spark BlockTransferServices uses java nio to asynchronously - * exchange messages. For this we use the Java SASL - * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 - * as the authentication mechanism. This means the shared secret is not passed - * over the wire in plaintext. - * Note that SASL is pluggable as to what mechanism it uses. We currently use - * DIGEST-MD5 but this could be changed to use Kerberos or other in the future. - * Spark currently supports "auth" for the quality of protection, which means - * the connection does not support integrity or privacy protection (encryption) - * after authentication. SASL also supports "auth-int" and "auth-conf" which - * SPARK could support in the future to allow the user to specify the quality - * of protection they want. If we support those, the messages will also have to - * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. - * - * Since the NioBlockTransferService does asynchronous messages passing, the SASL - * authentication is a bit more complex. A ConnectionManager can be both a client - * and a Server, so for a particular connection it has to determine what to do. - * A ConnectionId was added to be able to track connections and is used to - * match up incoming messages with connections waiting for authentication. - * The ConnectionManager tracks all the sendingConnections using the ConnectionId, - * waits for the response from the server, and does the handshake before sending - * the real message. - * - * The NettyBlockTransferService ensures that SASL authentication is performed - * synchronously prior to any other communication on a connection. This is done in - * SaslClientBootstrap on the client side and SaslRpcHandler on the server side. - * - * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters - * can be used. Yarn requires a specific AmIpFilter be installed for security to work - * properly. For non-Yarn deployments, users can write a filter to go through their - * organization's normal login service. If an authentication filter is in place then the - * SparkUI can be configured to check the logged in user against the list of users who - * have view acls to see if that user is authorized. - * The filters can also be used for many different purposes. For instance filters - * could be used for logging, encryption, or compression. - * - * The exact mechanisms used to generate/distribute the shared secret are deployment-specific. - * - * For YARN deployments, the secret is automatically generated. The secret is placed in the Hadoop - * UGI which gets passed around via the Hadoop RPC mechanism. Hadoop RPC can be configured to - * support different levels of protection. See the Hadoop documentation for more details. Each - * Spark application on YARN gets a different shared secret. - * - * On YARN, the Spark UI gets configured to use the Hadoop YARN AmIpFilter which requires the user - * to go through the ResourceManager Proxy. That proxy is there to reduce the possibility of web - * based attacks through YARN. Hadoop can be configured to use filters to do authentication. That - * authentication then happens via the ResourceManager Proxy and Spark will use that to do - * authorization against the view acls. - * - * For other Spark deployments, the shared secret must be specified via the - * spark.authenticate.secret config. - * All the nodes (Master and Workers) and the applications need to have the same shared secret. - * This again is not ideal as one user could potentially affect another users application. - * This should be enhanced in the future to provide better protection. - * If the UI needs to be secure, the user needs to install a javax servlet filter to do the - * authentication. Spark will then use that user to compare against the view acls to do - * authorization. If not filter is in place the user is generally null and no authorization - * can take place. - * - * When authentication is being used, encryption can also be enabled by setting the option - * spark.authenticate.enableSaslEncryption to true. This is only supported by communication - * channels that use the network-common library, and can be used as an alternative to SSL in those - * cases. - * - * SSL can be used for encryption for certain communication channels. The user can configure the - * default SSL settings which will be used for all the supported communication protocols unless - * they are overwritten by protocol specific settings. This way the user can easily provide the - * common settings for all the protocols without disabling the ability to configure each one - * individually. - * - * All the SSL settings like `spark.ssl.xxx` where `xxx` is a particular configuration property, - * denote the global configuration for all the supported protocols. In order to override the global - * configuration for the particular protocol, the properties must be overwritten in the - * protocol-specific namespace. Use `spark.ssl.yyy.xxx` settings to overwrite the global - * configuration for particular protocol denoted by `yyy`. Currently `yyy` can be only`fs` for - * broadcast and file server. - * - * Refer to [[org.apache.spark.SSLOptions]] documentation for the list of - * options that can be specified. - * - * SecurityManager initializes SSLOptions objects for different protocols separately. SSLOptions - * object parses Spark configuration at a given namespace and builds the common representation - * of SSL settings. SSLOptions is then used to provide protocol-specific SSLContextFactory for - * Jetty. - * - * SSL must be configured on each node and configured for each component involved in - * communication using the particular protocol. In YARN clusters, the key-store can be prepared on - * the client side then distributed and used by the executors as the part of the application - * (YARN allows the user to deploy files before the application is started). - * In standalone deployment, the user needs to provide key-stores and configuration - * options for master and workers. In this mode, the user may allow the executors to use the SSL - * settings inherited from the worker which spawned that executor. It can be accomplished by - * setting `spark.ssl.useNodeLocalConf` to `true`. + * This class implements all of the configuration related to security features described + * in the "Security" document. Please refer to that document for specific features implemented + * here. */ - private[spark] class SecurityManager( sparkConf: SparkConf, val ioEncryptionKey: Option[Array[Byte]] = None) diff --git a/docs/configuration.md b/docs/configuration.md index e7f2419cc2fa4..2eb6a77434ea6 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -712,30 +712,6 @@ Apart from these, the following properties are also available, and may be useful When we fail to register to the external shuffle service, we will retry for maxAttempts times. - - spark.io.encryption.enabled - false - - Enable IO encryption. Currently supported by all modes except Mesos. It's recommended that RPC encryption - be enabled when using this feature. - - - - spark.io.encryption.keySizeBits - 128 - - IO encryption key size in bits. Supported values are 128, 192 and 256. - - - - spark.io.encryption.keygen.algorithm - HmacSHA1 - - The algorithm to use when generating the IO encryption key. The supported algorithms are - described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm - Name Documentation. - - ### Spark UI @@ -893,6 +869,23 @@ Apart from these, the following properties are also available, and may be useful How many dead executors the Spark UI and status APIs remember before garbage collecting. + + spark.ui.filters + None + + Comma separated list of filter class names to apply to the Spark Web UI. The filter should be a + standard + javax servlet Filter. + +
    Filter parameters can also be specified in the configuration, by setting config entries + of the form spark.<class name of filter>.param.<param name>=<value> + +
    For example: +
    spark.ui.filters=com.test.filter1 +
    spark.com.test.filter1.param.name1=foo +
    spark.com.test.filter1.param.name2=bar + + ### Compression and Serialization @@ -1446,6 +1439,15 @@ Apart from these, the following properties are also available, and may be useful Duration for an RPC remote endpoint lookup operation to wait before timing out. + + spark.core.connection.ack.wait.timeout + spark.network.timeout + + How long for the connection to wait for ack to occur before timing + out and giving up. To avoid unwilling timeout caused by long pause like GC, + you can set larger value. + + ### Scheduling @@ -1817,313 +1819,8 @@ Apart from these, the following properties are also available, and may be useful ### Security - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    Property NameDefaultMeaning
    spark.acls.enablefalse - Whether Spark acls should be enabled. If enabled, this checks to see if the user has - access permissions to view or modify the job. Note this requires the user to be known, - so if the user comes across as null no checks are done. Filters can be used with the UI - to authenticate and set the user. -
    spark.admin.aclsEmpty - Comma separated list of users/administrators that have view and modify access to all Spark jobs. - This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things do not work. Putting a "*" in the list means any user can have the - privilege of admin. -
    spark.admin.acls.groupsEmpty - Comma separated list of groups that have view and modify access to all Spark jobs. - This can be used if you have a set of administrators or developers who help maintain and debug - the underlying infrastructure. Putting a "*" in the list means any user in any group can have - the privilege of admin. The user groups are obtained from the instance of the groups mapping - provider specified by spark.user.groups.mapping. Check the entry - spark.user.groups.mapping for more details. -
    spark.user.groups.mappingorg.apache.spark.security.ShellBasedGroupsMappingProvider - The list of groups for a user is determined by a group mapping service defined by the trait - org.apache.spark.security.GroupMappingServiceProvider which can be configured by this property. - A default unix shell based implementation is provided org.apache.spark.security.ShellBasedGroupsMappingProvider - which can be specified to resolve a list of groups for a user. - Note: This implementation supports only a Unix/Linux based environment. Windows environment is - currently not supported. However, a new platform/protocol can be supported by implementing - the trait org.apache.spark.security.GroupMappingServiceProvider. -
    spark.authenticatefalse - Whether Spark authenticates its internal connections. See - spark.authenticate.secret if not running on YARN. -
    spark.authenticate.secretNone - Set the secret key used for Spark to authenticate between components. This needs to be set if - not running on YARN and authentication is enabled. -
    spark.network.crypto.enabledfalse - Enable encryption using the commons-crypto library for RPC and block transfer service. - Requires spark.authenticate to be enabled. -
    spark.network.crypto.keyLength128 - The length in bits of the encryption key to generate. Valid values are 128, 192 and 256. -
    spark.network.crypto.keyFactoryAlgorithmPBKDF2WithHmacSHA1 - The key factory algorithm to use when generating encryption keys. Should be one of the - algorithms supported by the javax.crypto.SecretKeyFactory class in the JRE being used. -
    spark.network.crypto.saslFallbacktrue - Whether to fall back to SASL authentication if authentication fails using Spark's internal - mechanism. This is useful when the application is connecting to old shuffle services that - do not support the internal Spark authentication protocol. On the server side, this can be - used to block older clients from authenticating against a new shuffle service. -
    spark.network.crypto.config.*None - Configuration values for the commons-crypto library, such as which cipher implementations to - use. The config name should be the name of commons-crypto configuration without the - "commons.crypto" prefix. -
    spark.authenticate.enableSaslEncryptionfalse - Enable encrypted communication when authentication is - enabled. This is supported by the block transfer service and the - RPC endpoints. -
    spark.network.sasl.serverAlwaysEncryptfalse - Disable unencrypted connections for services that support SASL authentication. -
    spark.core.connection.ack.wait.timeoutspark.network.timeout - How long for the connection to wait for ack to occur before timing - out and giving up. To avoid unwilling timeout caused by long pause like GC, - you can set larger value. -
    spark.modify.aclsEmpty - Comma separated list of users that have modify access to the Spark job. By default only the - user that started the Spark job has access to modify it (kill it for example). Putting a "*" in - the list means any user can have access to modify it. -
    spark.modify.acls.groupsEmpty - Comma separated list of groups that have modify access to the Spark job. This can be used if you - have a set of administrators or developers from the same team to have access to control the job. - Putting a "*" in the list means any user in any group has the access to modify the Spark job. - The user groups are obtained from the instance of the groups mapping provider specified by - spark.user.groups.mapping. Check the entry spark.user.groups.mapping - for more details. -
    spark.ui.filtersNone - Comma separated list of filter class names to apply to the Spark web UI. The filter should be a - standard - javax servlet Filter. Parameters to each filter can also be specified by setting a - java system property of:
    - spark.<class name of filter>.params='param1=value1,param2=value2'
    - For example:
    - -Dspark.ui.filters=com.test.filter1
    - -Dspark.com.test.filter1.params='param1=foo,param2=testing' -
    spark.ui.view.aclsEmpty - Comma separated list of users that have view access to the Spark web ui. By default only the - user that started the Spark job has view access. Putting a "*" in the list means any user can - have view access to this Spark job. -
    spark.ui.view.acls.groupsEmpty - Comma separated list of groups that have view access to the Spark web ui to view the Spark Job - details. This can be used if you have a set of administrators or developers or users who can - monitor the Spark job submitted. Putting a "*" in the list means any user in any group can view - the Spark job details on the Spark web ui. The user groups are obtained from the instance of the - groups mapping provider specified by spark.user.groups.mapping. Check the entry - spark.user.groups.mapping for more details. -
    - -### TLS / SSL - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    Property NameDefaultMeaning
    spark.ssl.enabledfalse - Whether to enable SSL connections on all supported protocols. - -
    When spark.ssl.enabled is configured, spark.ssl.protocol - is required. - -
    All the SSL settings like spark.ssl.xxx where xxx is a - particular configuration property, denote the global configuration for all the supported - protocols. In order to override the global configuration for the particular protocol, - the properties must be overwritten in the protocol-specific namespace. - -
    Use spark.ssl.YYY.XXX settings to overwrite the global configuration for - particular protocol denoted by YYY. Example values for YYY - include fs, ui, standalone, and - historyServer. See SSL - Configuration for details on hierarchical SSL configuration for services. -
    spark.ssl.[namespace].portNone - The port where the SSL service will listen on. - -
    The port must be defined within a namespace configuration; see - SSL Configuration for the available - namespaces. - -
    When not set, the SSL port will be derived from the non-SSL port for the - same service. A value of "0" will make the service bind to an ephemeral port. -
    spark.ssl.enabledAlgorithmsEmpty - A comma separated list of ciphers. The specified ciphers must be supported by JVM. - The reference list of protocols one can find on - this - page. - Note: If not set, it will use the default cipher suites of JVM. -
    spark.ssl.keyPasswordNone - A password to the private key in key-store. -
    spark.ssl.keyStoreNone - A path to a key-store file. The path can be absolute or relative to the directory where - the component is started in. -
    spark.ssl.keyStorePasswordNone - A password to the key-store. -
    spark.ssl.keyStoreTypeJKS - The type of the key-store. -
    spark.ssl.protocolNone - A protocol name. The protocol must be supported by JVM. The reference list of protocols - one can find on this - page. -
    spark.ssl.needClientAuthfalse - Set true if SSL needs client authentication. -
    spark.ssl.trustStoreNone - A path to a trust-store file. The path can be absolute or relative to the directory - where the component is started in. -
    spark.ssl.trustStorePasswordNone - A password to the trust-store. -
    spark.ssl.trustStoreTypeJKS - The type of the trust-store. -
    - +Please refer to the [Security](security.html) page for available options on how to secure different +Spark subsystems. ### Spark SQL diff --git a/docs/monitoring.md b/docs/monitoring.md index d5f7ffcc260a1..01736c77b0979 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -80,7 +80,10 @@ The history server can be configured as follows: -### Spark configuration options +### Spark History Server Configuration Options + +Security options for the Spark History Server are covered more detail in the +[Security](security.html#web-ui) page. @@ -160,41 +163,6 @@ The history server can be configured as follows: Location of the kerberos keytab file for the History Server. - - - - - - - - - - - - - - - diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index c010af35f8d2e..e07759a4dba87 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -2,6 +2,8 @@ layout: global title: Running Spark on YARN --- +* This will become a table of contents (this text will be scraped). +{:toc} Support for running on [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) @@ -217,8 +219,8 @@ To use a custom metrics.properties for the application master and executors, upd @@ -265,19 +267,6 @@ To use a custom metrics.properties for the application master and executors, upd distribution. - - - - - @@ -373,31 +362,6 @@ To use a custom metrics.properties for the application master and executors, upd in YARN ApplicationReports, which can be used for filtering when querying YARN apps. - - - - - - - - - - - - - - - @@ -424,17 +388,6 @@ To use a custom metrics.properties for the application master and executors, upd See spark.yarn.config.gatewayPath. - - - - - @@ -468,48 +421,104 @@ To use a custom metrics.properties for the application master and executors, upd - The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. -# Running in a Secure Cluster +# Kerberos + +Standard Kerberos support in Spark is covered in the [Security](security.html#kerberos) page. + +In YARN mode, when accessing Hadoop file systems, aside from the service hosting the user's home +directory, Spark will also automatically obtain delegation tokens for the service hosting the +staging directory of the Spark application. + +If an application needs to interact with other secure Hadoop filesystems, their URIs need to be +explicitly provided to Spark at launch time. This is done by listing them in the +`spark.yarn.access.hadoopFileSystems` property, described in the configuration section below. -As covered in [security](security.html), Kerberos is used in a secure Hadoop cluster to -authenticate principals associated with services and clients. This allows clients to -make requests of these authenticated services; the services to grant rights -to the authenticated principals. +The YARN integration also supports custom delegation token providers using the Java Services +mechanism (see `java.util.ServiceLoader`). Implementations of +`org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` can be made available to Spark +by listing their names in the corresponding file in the jar's `META-INF/services` directory. These +providers can be disabled individually by setting `spark.security.credentials.{service}.enabled` to +`false`, where `{service}` is the name of the credential provider. + +## YARN-specific Kerberos Configuration + +
    Property NameDefaultMeaning
    spark.history.ui.acls.enablefalse - Specifies whether acls should be checked to authorize users viewing the applications. - If enabled, access control checks are made regardless of what the individual application had - set for spark.ui.acls.enable when the application was run. The application owner - will always have authorization to view their own application and any users specified via - spark.ui.view.acls and groups specified via spark.ui.view.acls.groups - when the application was run will also have authorization to view that application. - If disabled, no access control checks are made. -
    spark.history.ui.admin.aclsempty - Comma separated list of users/administrators that have view access to all the Spark applications in - history server. By default only the users permitted to view the application at run-time could - access the related application history, with this, configured users/administrators could also - have the permission to access it. - Putting a "*" in the list means any user can have the privilege of admin. -
    spark.history.ui.admin.acls.groupsempty - Comma separated list of groups that have view access to all the Spark applications in - history server. By default only the groups permitted to view the application at run-time could - access the related application history, with this, configured groups could also - have the permission to access it. - Putting a "*" in the list means any group can have the privilege of admin. -
    spark.history.fs.cleaner.enabled falsespark.yarn.dist.forceDownloadSchemes (none) - Comma-separated list of schemes for which files will be downloaded to the local disk prior to - being added to YARN's distributed cache. For use in cases where the YARN service does not + Comma-separated list of schemes for which files will be downloaded to the local disk prior to + being added to YARN's distributed cache. For use in cases where the YARN service does not support schemes that are supported by Spark, like http, https and ftp.
    spark.yarn.access.hadoopFileSystems(none) - A comma-separated list of secure Hadoop filesystems your Spark application is going to access. For - example, spark.yarn.access.hadoopFileSystems=hdfs://nn1.com:8032,hdfs://nn2.com:8032, - webhdfs://nn3.com:50070. The Spark application must have access to the filesystems listed - and Kerberos must be properly configured to be able to access them (either in the same realm - or in a trusted realm). Spark acquires security tokens for each of the filesystems so that - the Spark application can access those remote Hadoop filesystems. spark.yarn.access.namenodes - is deprecated, please use this instead. -
    spark.yarn.appMasterEnv.[EnvironmentVariableName] (none)
    spark.yarn.keytab(none) - The full path to the file that contains the keytab for the principal specified above. - This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, - for renewing the login tickets and the delegation tokens periodically. (Works also with the "local" master) -
    spark.yarn.principal(none) - Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master) -
    spark.yarn.kerberos.relogin.period1m - How often to check whether the kerberos TGT should be renewed. This should be set to a value - that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled). - The default value should be enough for most deployments. -
    spark.yarn.config.gatewayPath (none)
    spark.security.credentials.${service}.enabledtrue - Controls whether to obtain credentials for services when security is enabled. - By default, credentials for all supported services are retrieved when those services are - configured, but it's possible to disable that behavior if it somehow conflicts with the - application being run. For further details please see - [Running in a Secure Cluster](running-on-yarn.html#running-in-a-secure-cluster) -
    spark.yarn.rolledLog.includePattern (none)
    + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.yarn.keytab(none) + The full path to the file that contains the keytab for the principal specified above. This keytab + will be copied to the node running the YARN Application Master via the YARN Distributed Cache, and + will be used for renewing the login tickets and the delegation tokens periodically. Equivalent to + the --keytab command line argument. + +
    (Works also with the "local" master.) +
    spark.yarn.principal(none) + Principal to be used to login to KDC, while running on secure clusters. Equivalent to the + --principal command line argument. + +
    (Works also with the "local" master.) +
    spark.yarn.access.hadoopFileSystems(none) + A comma-separated list of secure Hadoop filesystems your Spark application is going to access. For + example, spark.yarn.access.hadoopFileSystems=hdfs://nn1.com:8032,hdfs://nn2.com:8032, + webhdfs://nn3.com:50070. The Spark application must have access to the filesystems listed + and Kerberos must be properly configured to be able to access them (either in the same realm + or in a trusted realm). Spark acquires security tokens for each of the filesystems so that + the Spark application can access those remote Hadoop filesystems. +
    spark.yarn.kerberos.relogin.period1m + How often to check whether the kerberos TGT should be renewed. This should be set to a value + that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled). + The default value should be enough for most deployments. +
    -Hadoop services issue *hadoop tokens* to grant access to the services and data. -Clients must first acquire tokens for the services they will access and pass them along with their -application as it is launched in the YARN cluster. +## Troubleshooting Kerberos -For a Spark application to interact with any of the Hadoop filesystem (for example hdfs, webhdfs, etc), HBase and Hive, it must acquire the relevant tokens -using the Kerberos credentials of the user launching the application -—that is, the principal whose identity will become that of the launched Spark application. +Debugging Hadoop/Kerberos problems can be "difficult". One useful technique is to +enable extra logging of Kerberos operations in Hadoop by setting the `HADOOP_JAAS_DEBUG` +environment variable. -This is normally done at launch time: in a secure cluster Spark will automatically obtain a -token for the cluster's default Hadoop filesystem, and potentially for HBase and Hive. +```bash +export HADOOP_JAAS_DEBUG=true +``` -An HBase token will be obtained if HBase is in on classpath, the HBase configuration declares -the application is secure (i.e. `hbase-site.xml` sets `hbase.security.authentication` to `kerberos`), -and `spark.security.credentials.hbase.enabled` is not set to `false`. +The JDK classes can be configured to enable extra logging of their Kerberos and +SPNEGO/REST authentication via the system properties `sun.security.krb5.debug` +and `sun.security.spnego.debug=true` -Similarly, a Hive token will be obtained if Hive is on the classpath, its configuration -includes a URI of the metadata store in `"hive.metastore.uris`, and -`spark.security.credentials.hive.enabled` is not set to `false`. +``` +-Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true +``` -If an application needs to interact with other secure Hadoop filesystems, then -the tokens needed to access these clusters must be explicitly requested at -launch time. This is done by listing them in the `spark.yarn.access.hadoopFileSystems` property. +All these options can be enabled in the Application Master: ``` -spark.yarn.access.hadoopFileSystems hdfs://ireland.example.org:8020/,webhdfs://frankfurt.example.org:50070/ +spark.yarn.appMasterEnv.HADOOP_JAAS_DEBUG true +spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true ``` -Spark supports integrating with other security-aware services through Java Services mechanism (see -`java.util.ServiceLoader`). To do that, implementations of `org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` -should be available to Spark by listing their names in the corresponding file in the jar's -`META-INF/services` directory. These plug-ins can be disabled by setting -`spark.security.credentials.{service}.enabled` to `false`, where `{service}` is the name of -credential provider. +Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log +will include a list of all tokens obtained, and their expiry details -## Configuring the External Shuffle Service + +# Configuring the External Shuffle Service To start the Spark Shuffle Service on each `NodeManager` in your YARN cluster, follow these instructions: @@ -542,7 +551,7 @@ The following extra configuration options are available when the shuffle service -## Launching your application with Apache Oozie +# Launching your application with Apache Oozie Apache Oozie can launch Spark applications as part of a workflow. In a secure cluster, the launched application will need the relevant tokens to access the cluster's @@ -576,35 +585,7 @@ spark.security.credentials.hbase.enabled false The configuration option `spark.yarn.access.hadoopFileSystems` must be unset. -## Troubleshooting Kerberos - -Debugging Hadoop/Kerberos problems can be "difficult". One useful technique is to -enable extra logging of Kerberos operations in Hadoop by setting the `HADOOP_JAAS_DEBUG` -environment variable. - -```bash -export HADOOP_JAAS_DEBUG=true -``` - -The JDK classes can be configured to enable extra logging of their Kerberos and -SPNEGO/REST authentication via the system properties `sun.security.krb5.debug` -and `sun.security.spnego.debug=true` - -``` --Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true -``` - -All these options can be enabled in the Application Master: - -``` -spark.yarn.appMasterEnv.HADOOP_JAAS_DEBUG true -spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true -``` - -Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log -will include a list of all tokens obtained, and their expiry details - -## Using the Spark History Server to replace the Spark Web UI +# Using the Spark History Server to replace the Spark Web UI It is possible to use the Spark History Server application page as the tracking URL for running applications when the application UI is disabled. This may be desirable on secure clusters, or to diff --git a/docs/security.md b/docs/security.md index 913d9df50eb1c..3e5607a9a0d67 100644 --- a/docs/security.md +++ b/docs/security.md @@ -3,47 +3,336 @@ layout: global displayTitle: Spark Security title: Security --- +* This will become a table of contents (this text will be scraped). +{:toc} -Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: +# Spark RPC -* For Spark on [YARN](running-on-yarn.html) and local deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. -* For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. +## Authentication -## Web UI +Spark currently supports authentication for RPC channels using a shared secret. Authentication can +be turned on by setting the `spark.authenticate` configuration parameter. -The Spark UI can be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting -and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via [SSL settings](security.html#ssl-configuration). +The exact mechanism used to generate and distribute the shared secret is deployment-specific. -### Authentication +For Spark on [YARN](running-on-yarn.html) and local deployments, Spark will automatically handle +generating and distributing the shared secret. Each application will use a unique shared secret. In +the case of YARN, this feature relies on YARN RPC encryption being enabled for the distribution of +secrets to be secure. -A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable`, `spark.ui.view.acls` and `spark.ui.view.acls.groups` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. +For other resource managers, `spark.authenticate.secret` must be configured on each of the nodes. +This secret will be shared by all the daemons and applications, so this deployment configuration is +not as secure as the above, especially when considering multi-tenant clusters. -Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable`, `spark.modify.acls` and `spark.modify.acls.groups`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces. -Spark allows for a set of administrators to be specified in the acls who always have view and modify permissions to all the applications. is controlled by the configs `spark.admin.acls` and `spark.admin.acls.groups`. This is useful on a shared cluster where you might have administrators or support staff who help users debug applications. + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.authenticatefalseWhether Spark authenticates its internal connections.
    spark.authenticate.secretNone + The secret key used authentication. See above for when this configuration should be set. +
    + +## Encryption -## Event Logging +Spark supports AES-based encryption for RPC connections. For encryption to be enabled, RPC +authentication must also be enabled and properly configured. AES encryption uses the +[Apache Commons Crypto](http://commons.apache.org/proper/commons-crypto/) library, and Spark's +configuration system allows access to that library's configuration for advanced users. -If your applications are using event logging, the directory where the event logs go (`spark.eventLog.dir`) should be manually created and have the proper permissions set on it. If you want those log files secured, the permissions should be set to `drwxrwxrwxt` for that directory. The owner of the directory should be the super user who is running the history server and the group permissions should be restricted to super user group. This will allow all users to write to the directory but will prevent unprivileged users from removing or renaming a file unless they own the file or directory. The event log files will be created by Spark with permissions such that only the user and group have read and write access. +There is also support for SASL-based encryption, although it should be considered deprecated. It +is still required when talking to shuffle services from Spark versions older than 2.2.0. -## Encryption +The following table describes the different options available for configuring this feature. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.network.crypto.enabledfalse + Enable AES-based RPC encryption, including the new authentication protocol added in 2.2.0. +
    spark.network.crypto.keyLength128 + The length in bits of the encryption key to generate. Valid values are 128, 192 and 256. +
    spark.network.crypto.keyFactoryAlgorithmPBKDF2WithHmacSHA1 + The key factory algorithm to use when generating encryption keys. Should be one of the + algorithms supported by the javax.crypto.SecretKeyFactory class in the JRE being used. +
    spark.network.crypto.config.*None + Configuration values for the commons-crypto library, such as which cipher implementations to + use. The config name should be the name of commons-crypto configuration without the + commons.crypto prefix. +
    spark.network.crypto.saslFallbacktrue + Whether to fall back to SASL authentication if authentication fails using Spark's internal + mechanism. This is useful when the application is connecting to old shuffle services that + do not support the internal Spark authentication protocol. On the shuffle service side, + disabling this feature will block older clients from authenticating. +
    spark.authenticate.enableSaslEncryptionfalse + Enable SASL-based encrypted communication. +
    spark.network.sasl.serverAlwaysEncryptfalse + Disable unencrypted connections for ports using SASL authentication. This will deny connections + from clients that have authentication enabled, but do not request SASL-based encryption. +
    + + +# Local Storage Encryption + +Spark supports encrypting temporary data written to local disks. This covers shuffle files, shuffle +spills and data blocks stored on disk (for both caching and broadcast variables). It does not cover +encrypting output data generated by applications with APIs such as `saveAsHadoopFile` or +`saveAsTable`. + +The following settings cover enabling encryption for data written to disk: + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.io.encryption.enabledfalse + Enable local disk I/O encryption. Currently supported by all modes except Mesos. It's strongly + recommended that RPC encryption be enabled when using this feature. +
    spark.io.encryption.keySizeBits128 + IO encryption key size in bits. Supported values are 128, 192 and 256. +
    spark.io.encryption.keygen.algorithmHmacSHA1 + The algorithm to use when generating the IO encryption key. The supported algorithms are + described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm + Name Documentation. +
    spark.io.encryption.commons.config.*None + Configuration values for the commons-crypto library, such as which cipher implementations to + use. The config name should be the name of commons-crypto configuration without the + commons.crypto prefix. +
    + + +# Web UI + +## Authentication and Authorization + +Enabling authentication for the Web UIs is done using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html). +You will need a filter that implements the authentication method you want to deploy. Spark does not +provide any built-in authentication filters. + +Spark also supports access control to the UI when an authentication filter is present. Each +application can be configured with its own separate access control lists (ACLs). Spark +differentiates between "view" permissions (who is allowed to see the application's UI), and "modify" +permissions (who can do things like kill jobs in a running application). + +ACLs can be configured for either users or groups. Configuration entries accept comma-separated +lists as input, meaning multiple users or groups can be given the desired privileges. This can be +used if you run on a shared cluster and have a set of administrators or developers who need to +monitor applications they may not have started themselves. A wildcard (`*`) added to specific ACL +means that all users will have the respective pivilege. By default, only the user submitting the +application is added to the ACLs. + +Group membership is established by using a configurable group mapping provider. The mapper is +configured using the spark.user.groups.mapping config option, described in the table +below. + +The following options control the authentication of Web UIs: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.ui.filtersNone + See the Spark UI configuration for how to configure + filters. +
    spark.acls.enablefalse + Whether UI ACLs should be enabled. If enabled, this checks to see if the user has access + permissions to view or modify the application. Note this requires the user to be authenticated, + so if no authentication filter is installed, this option does not do anything. +
    spark.admin.aclsNone + Comma-separated list of users that have view and modify access to the Spark application. +
    spark.admin.acls.groupsNone + Comma-separated list of groups that have view and modify access to the Spark application. +
    spark.modify.aclsNone + Comma-separated list of users that have modify access to the Spark application. +
    spark.modify.acls.groupsNone + Comma-separated list of groups that have modify access to the Spark application. +
    spark.ui.view.aclsNone + Comma-separated list of users that have view access to the Spark application. +
    spark.ui.view.acls.groupsNone + Comma-separated list of groups that have view access to the Spark application. +
    spark.user.groups.mappingorg.apache.spark.security.ShellBasedGroupsMappingProvider + The list of groups for a user is determined by a group mapping service defined by the trait + org.apache.spark.security.GroupMappingServiceProvider, which can be configured by + this property. + +
    By default, a Unix shell-based implementation is used, which collects this information + from the host OS. + +
    Note: This implementation supports only Unix/Linux-based environments. + Windows environment is currently not supported. However, a new platform/protocol can + be supported by implementing the trait mentioned above. +
    + +On YARN, the view and modify ACLs are provided to the YARN service when submitting applications, and +control who has the respective privileges via YARN interfaces. + +## Spark History Server ACLs -Spark supports SSL for HTTP protocols. SASL encryption is supported for the block transfer service -and the RPC endpoints. Shuffle files can also be encrypted if desired. +Authentication for the SHS Web UI is enabled the same way as for regular applications, using +servlet filters. -### SSL Configuration +To enable authorization in the SHS, a few extra options are used: + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.history.ui.acls.enablefalse + Specifies whether ACLs should be checked to authorize users viewing the applications in + the history server. If enabled, access control checks are performed regardless of what the + individual applications had set for spark.ui.acls.enable. The application owner + will always have authorization to view their own application and any users specified via + spark.ui.view.acls and groups specified via spark.ui.view.acls.groups + when the application was run will also have authorization to view that application. + If disabled, no access control checks are made for any application UIs available through + the history server. +
    spark.history.ui.admin.aclsNone + Comma separated list of users that have view access to all the Spark applications in history + server. +
    spark.history.ui.admin.acls.groupsNone + Comma separated list of groups that have view access to all the Spark applications in history + server. +
    + +The SHS uses the same options to configure the group mapping provider as regular applications. +In this case, the group mapping provider will apply to all UIs server by the SHS, and individual +application configurations will be ignored. + +## SSL Configuration Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the -protocols without disabling the ability to configure each one individually. The common SSL settings -are at `spark.ssl` namespace in Spark configuration. The following table describes the -component-specific configuration namespaces used to override the default settings: +protocols without disabling the ability to configure each one individually. The following table +describes the the SSL configuration namespaces: + + + + @@ -58,49 +347,205 @@ component-specific configuration namespaces used to override the default setting
    Config Namespace Component
    spark.ssl + The default SSL configuration. These values will apply to all namespaces below, unless + explicitly overridden at the namespace level. +
    spark.ssl.ui Spark application Web UI
    -The full breakdown of available SSL options can be found on the [configuration page](configuration.html). -SSL must be configured on each node and configured for each component involved in communication using the particular protocol. +The full breakdown of available SSL options can be found below. The `${ns}` placeholder should be +replaced with one of the above namespaces. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    ${ns}.enabledfalseEnables SSL. When enabled, ${ns}.ssl.protocol is required.
    ${ns}.portNone + The port where the SSL service will listen on. + +
    The port must be defined within a specific namespace configuration. The default + namespace is ignored when reading this configuration. + +
    When not set, the SSL port will be derived from the non-SSL port for the + same service. A value of "0" will make the service bind to an ephemeral port. +
    ${ns}.enabledAlgorithmsNone + A comma separated list of ciphers. The specified ciphers must be supported by JVM. + +
    The reference list of protocols can be found in the "JSSE Cipher Suite Names" section + of the Java security guide. The list for Java 8 can be found at + this + page. + +
    Note: If not set, the default cipher suite for the JRE will be used. +
    ${ns}.keyPasswordNone + The password to the private key in the key store. +
    ${ns}.keyStoreNone + Path to the key store file. The path can be absolute or relative to the directory in which the + process is started. +
    ${ns}.keyStorePasswordNonePassword to the key store.
    ${ns}.keyStoreTypeJKSThe type of the key store.
    ${ns}.protocolNone + TLS protocol to use. The protocol must be supported by JVM. + +
    The reference list of protocols can be found in the "Additional JSSE Standard Names" + section of the Java security guide. For Java 8, the list can be found at + this + page. +
    ${ns}.needClientAuthfalseWhether to require client authentication.
    ${ns}.trustStoreNone + Path to the trust store file. The path can be absolute or relative to the directory in which + the process is started. +
    ${ns}.trustStorePasswordNonePassword for the trust store.
    ${ns}.trustStoreTypeJKSThe type of the trust store.
    + +## Preparing the key stores + +Key stores can be generated by `keytool` program. The reference documentation for this tool for +Java 8 is [here](https://docs.oracle.com/javase/8/docs/technotes/tools/unix/keytool.html). +The most basic steps to configure the key stores and the trust store for a Spark Standalone +deployment mode is as follows: + +* Generate a key pair for each node +* Export the public key of the key pair to a file on each node +* Import all exported public keys into a single trust store +* Distribute the trust store to the cluster nodes ### YARN mode -The key-store can be prepared on the client side and then distributed and used by the executors as the part of the application. It is possible because the user is able to deploy files before the application is started in YARN by using `spark.yarn.dist.files` or `spark.yarn.dist.archives` configuration settings. The responsibility for encryption of transferring these files is on YARN side and has nothing to do with Spark. -For long-running apps like Spark Streaming apps to be able to write to HDFS, it is possible to pass a principal and keytab to `spark-submit` via the `--principal` and `--keytab` parameters respectively. The keytab passed in will be copied over to the machine running the Application Master via the Hadoop Distributed Cache (securely - if YARN is configured with SSL and HDFS encryption is enabled). The Kerberos login will be periodically renewed using this principal and keytab and the delegation tokens required for HDFS will be generated periodically so the application can continue writing to HDFS. +To provide a local trust store or key store file to drivers running in cluster mode, they can be +distributed with the application using the `--files` command line argument (or the equivalent +`spark.files` configuration). The files will be placed on the driver's working directory, so the TLS +configuration should just reference the file name with no absolute path. + +Distributing local key stores this way may require the files to be staged in HDFS (or other similar +distributed file system used by the cluster), so it's recommended that the undelying file system be +configured with security in mind (e.g. by enabling authentication and wire encryption). ### Standalone mode -The user needs to provide key-stores and configuration options for master and workers. They have to be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in `SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. In this mode, the user may allow the executors to use the SSL settings inherited from the worker which spawned that executor. It can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. If that parameter is set, the settings provided by user on the client side, are not used by the executors. + +The user needs to provide key stores and configuration options for master and workers. They have to +be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in +`SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. + +The user may allow the executors to use the SSL settings inherited from the worker process. That +can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. In that case, the settings +provided by the user on the client side are not used. ### Mesos mode -Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based secrets. Spark allows the specification of file-based and environment variable based secrets with the `spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively. Depending on the secret store backend secrets can be passed by reference or by value with the `spark.mesos.driver.secret.names` and `spark.mesos.driver.secret.values` configuration properties, respectively. Reference type secrets are served by the secret store and referred to by name, for example `/mysecret`. Value type secrets are passed on the command line and translated into their appropriate files or environment variables. +Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based +secrets. Spark allows the specification of file-based and environment variable based secrets with +`spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively. -### Preparing the key-stores -Key-stores can be generated by `keytool` program. The reference documentation for this tool is -[here](https://docs.oracle.com/javase/7/docs/technotes/tools/solaris/keytool.html). The most basic -steps to configure the key-stores and the trust-store for the standalone deployment mode is as -follows: +Depending on the secret store backend secrets can be passed by reference or by value with the +`spark.mesos.driver.secret.names` and `spark.mesos.driver.secret.values` configuration properties, +respectively. -* Generate a keys pair for each node -* Export the public key of the key pair to a file on each node -* Import all exported public keys into a single trust-store -* Distribute the trust-store over the nodes +Reference type secrets are served by the secret store and referred to by name, for example +`/mysecret`. Value type secrets are passed on the command line and translated into their +appropriate files or environment variables. -### Configuring SASL Encryption +## HTTP Security Headers -SASL encryption is currently supported for the block transfer service when authentication -(`spark.authenticate`) is enabled. To enable SASL encryption for an application, set -`spark.authenticate.enableSaslEncryption` to `true` in the application's configuration. +Apache Spark can be configured to include HTTP headers to aid in preventing Cross Site Scripting +(XSS), Cross-Frame Scripting (XFS), MIME-Sniffing, and also to enforce HTTP Strict Transport +Security. -When using an external shuffle service, it's possible to disable unencrypted connections by setting -`spark.network.sasl.serverAlwaysEncrypt` to `true` in the shuffle service's configuration. If that -option is enabled, applications that are not set up to use SASL encryption will fail to connect to -the shuffle service. + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.ui.xXssProtection1; mode=block + Value for HTTP X-XSS-Protection response header. You can choose appropriate value + from below: +
      +
    • 0 (Disables XSS filtering)
    • +
    • 1 (Enables XSS filtering. If a cross-site scripting attack is detected, + the browser will sanitize the page.)
    • +
    • 1; mode=block (Enables XSS filtering. The browser will prevent rendering + of the page if an attack is detected.)
    • +
    +
    spark.ui.xContentTypeOptions.enabledtrue + When enabled, X-Content-Type-Options HTTP response header will be set to "nosniff". +
    spark.ui.strictTransportSecurityNone + Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate + value from below and set expire-time accordingly. This option is only used when + SSL/TLS is enabled. +
      +
    • max-age=<expire-time>
    • +
    • max-age=<expire-time>; includeSubDomains
    • +
    • max-age=<expire-time>; preload
    • +
    +
    -## Configuring Ports for Network Security + +# Configuring Ports for Network Security Spark makes heavy use of the network, and some environments have strict requirements for using tight firewall settings. Below are the primary ports that Spark uses for its communication and how to configure those ports. -### Standalone mode only +## Standalone mode only @@ -141,7 +586,7 @@ configure those ports.
    -### All cluster managers +## All cluster managers @@ -182,54 +627,70 @@ configure those ports.
    -### HTTP Security Headers -Apache Spark can be configured to include HTTP Headers which aids in preventing Cross -Site Scripting (XSS), Cross-Frame Scripting (XFS), MIME-Sniffing and also enforces HTTP -Strict Transport Security. +# Kerberos + +Spark supports submitting applications in environments that use Kerberos for authentication. +In most cases, Spark relies on the credentials of the current logged in user when authenticating +to Kerberos-aware services. Such credentials can be obtained by logging in to the configured KDC +with tools like `kinit`. + +When talking to Hadoop-based services, Spark needs to obtain delegation tokens so that non-local +processes can authenticate. Spark ships with support for HDFS and other Hadoop file systems, Hive +and HBase. + +When using a Hadoop filesystem (such HDFS or WebHDFS), Spark will acquire the relevant tokens +for the service hosting the user's home directory. + +An HBase token will be obtained if HBase is in the application's classpath, and the HBase +configuration has Kerberos authentication turned (`hbase.security.authentication=kerberos`). + +Similarly, a Hive token will be obtained if Hive is in the classpath, and the configuration includes +URIs for remote metastore services (`hive.metastore.uris` is not empty). + +Delegation token support is currently only supported in YARN and Mesos modes. Consult the +deployment-specific page for more information. + +The following options provides finer-grained control for this feature: - - - - - - + - - - - -
    Property NameDefaultMeaning
    spark.ui.xXssProtection1; mode=block - Value for HTTP X-XSS-Protection response header. You can choose appropriate value - from below: -
      -
    • 0 (Disables XSS filtering)
    • -
    • 1 (Enables XSS filtering. If a cross-site scripting attack is detected, - the browser will sanitize the page.)
    • -
    • 1; mode=block (Enables XSS filtering. The browser will prevent rendering - of the page if an attack is detected.)
    • -
    -
    spark.ui.xContentTypeOptions.enabledspark.security.credentials.${service}.enabled true - When value is set to "true", X-Content-Type-Options HTTP response header will be set - to "nosniff". Set "false" to disable. -
    spark.ui.strictTransportSecurityNone - Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate - value from below and set expire-time accordingly, when Spark is SSL/TLS enabled. -
      -
    • max-age=<expire-time>
    • -
    • max-age=<expire-time>; includeSubDomains
    • -
    • max-age=<expire-time>; preload
    • -
    + Controls whether to obtain credentials for services when security is enabled. + By default, credentials for all supported services are retrieved when those services are + configured, but it's possible to disable that behavior if it somehow conflicts with the + application being run.
    - -See the [configuration page](configuration.html) for more details on the security configuration -parameters, and -org.apache.spark.SecurityManager for implementation details about security. +## Long-Running Applications + +Long-running applications may run into issues if their run time exceeds the maximum delegation +token lifetime configured in services it needs to access. + +Spark supports automatically creating new tokens for these applications when running in YARN mode. +Kerberos credentials need to be provided to the Spark application via the `spark-submit` command, +using the `--principal` and `--keytab` parameters. + +The provided keytab will be copied over to the machine running the Application Master via the Hadoop +Distributed Cache. For this reason, it's strongly recommended that both YARN and HDFS be secured +with encryption, at least. + +The Kerberos login will be periodically renewed using the provided credentials, and new delegation +tokens for supported will be created. + + +# Event Logging + +If your applications are using event logging, the directory where the event logs go +(`spark.eventLog.dir`) should be manually created with proper permissions. To secure the log files, +the directory permissions should be set to `drwxrwxrwxt`. The owner and group of the directory +should correspond to the super user who is running the Spark History Server. +This will allow all users to write to the directory but will prevent unprivileged users from +reading, removing or renaming a file unless they own it. The event log files will be created by +Spark with permissions such that only the user and group have read and write access. From 3e778f5a91b0553b09fe0e0ee84d771a71504960 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Mon, 26 Mar 2018 15:45:27 -0700 Subject: [PATCH 526/774] [SPARK-23162][PYSPARK][ML] Add r2adj into Python API in LinearRegressionSummary ## What changes were proposed in this pull request? Adding r2adj in LinearRegressionSummary for Python API. ## How was this patch tested? Added unit tests to exercise the api calls for the summary classes in tests.py. Author: Kevin Yu Closes #20842 from kevinyu98/spark-23162. --- python/pyspark/ml/regression.py | 18 ++++++++++++++++-- python/pyspark/ml/tests.py | 1 + 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index de0a0fa9f3bf8..9a66d87d7f211 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -336,10 +336,10 @@ def rootMeanSquaredError(self): @since("2.0.0") def r2(self): """ - Returns R^2^, the coefficient of determination. + Returns R^2, the coefficient of determination. .. seealso:: `Wikipedia coefficient of determination \ - ` + `_ .. note:: This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. This will change in later Spark @@ -347,6 +347,20 @@ def r2(self): """ return self._call_java("r2") + @property + @since("2.4.0") + def r2adj(self): + """ + Returns Adjusted R^2, the adjusted coefficient of determination. + + .. seealso:: `Wikipedia coefficient of determination, Adjusted R^2 \ + `_ + + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark versions. + """ + return self._call_java("r2adj") + @property @since("2.0.0") def residuals(self): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index cf1ffa181ecec..6b4376cbf14e8 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1559,6 +1559,7 @@ def test_linear_regression_summary(self): self.assertAlmostEqual(s.meanSquaredError, 0.0) self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) self.assertAlmostEqual(s.r2, 1.0, 2) + self.assertAlmostEqual(s.r2adj, 1.0, 2) self.assertTrue(isinstance(s.residuals, DataFrame)) self.assertEqual(s.numInstances, 2) self.assertEqual(s.degreesOfFreedom, 1) From 35997b59f3116830af06b3d40a7675ef0dbf7091 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 27 Mar 2018 14:49:50 +0200 Subject: [PATCH 527/774] [SPARK-23794][SQL] Make UUID as stateful expression ## What changes were proposed in this pull request? The UUID() expression is stateful and should implement the `Stateful` trait instead of the `Nondeterministic` trait. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20912 from viirya/SPARK-23794. --- .../org/apache/spark/sql/catalyst/expressions/misc.scala | 4 +++- .../sql/catalyst/expressions/MiscExpressionsSuite.scala | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index ec93620038cff..a390f8ef7fd9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -123,7 +123,7 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { 46707d92-02f4-4817-8116-a4c3b23e6266 """) // scalastyle:on line.size.limit -case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Nondeterministic { +case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful { def this() = this(None) @@ -152,4 +152,6 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", isNull = "false") } + + override def freshCopy(): Uuid = Uuid(randomSeed) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index 3383d421f5616..b6c269348b002 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -59,6 +59,12 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { evaluateWithGeneratedMutableProjection(Uuid(seed2))) assert(evaluateWithUnsafeProjection(Uuid(seed1)) !== evaluateWithUnsafeProjection(Uuid(seed2))) + + val uuid = Uuid(seed1) + assert(uuid.fastEquals(uuid)) + assert(!uuid.fastEquals(Uuid(seed1))) + assert(!uuid.fastEquals(uuid.freshCopy())) + assert(!uuid.fastEquals(Uuid(seed2))) } test("PrintToStderr") { From c68ec4e6a1ed9ea13345c7705ea60ff4df7aec7b Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 27 Mar 2018 14:39:05 -0700 Subject: [PATCH 528/774] [SPARK-23096][SS] Migrate rate source to V2 ## What changes were proposed in this pull request? This PR migrate micro batch rate source to V2 API and rewrite UTs to suite V2 test. ## How was this patch tested? UTs. Author: jerryshao Closes #20688 from jerryshao/SPARK-23096. --- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../execution/datasources/DataSource.scala | 6 +- .../streaming/RateSourceProvider.scala | 262 ------------- .../ContinuousRateStreamSource.scala | 25 +- .../sources/RateStreamMicroBatchReader.scala | 222 +++++++++++ .../sources/RateStreamProvider.scala | 125 +++++++ .../sources/RateStreamSourceV2.scala | 187 ---------- .../execution/streaming/RateSourceSuite.scala | 194 ---------- .../streaming/RateSourceV2Suite.scala | 191 ---------- .../sources/RateStreamProviderSuite.scala | 344 ++++++++++++++++++ 10 files changed, 715 insertions(+), 844 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1fe9c093af99f..1b37905543b4e 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,6 +5,5 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateStreamProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider -org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 31fa89b4570a6..b84ea769808f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider +import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -566,6 +566,7 @@ object DataSource extends Logging { val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName val socket = classOf[TextSocketSourceProvider].getCanonicalName + val rate = classOf[RateStreamProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -587,7 +588,8 @@ object DataSource extends Logging { "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, "com.databricks.spark.csv" -> csv, - "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, + "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala deleted file mode 100644 index 649fbbfa184ec..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.Optional -import java.util.concurrent.TimeUnit - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader -import org.apache.spark.sql.types._ -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * A source that generates increment long values with timestamps. Each generated row has two - * columns: a timestamp column for the generated time and an auto increment long column starting - * with 0L. - * - * This source supports the following options: - * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. - * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed - * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer - * seconds. - * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the - * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may - * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. - */ -class RateSourceProvider extends StreamSourceProvider with DataSourceRegister - with DataSourceV2 with ContinuousReadSupport { - - override def sourceSchema( - sqlContext: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { - if (schema.nonEmpty) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - (shortName(), RateSourceProvider.SCHEMA) - } - - override def createSource( - sqlContext: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - val params = CaseInsensitiveMap(parameters) - - val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) - if (rowsPerSecond <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + - "must be positive") - } - - val rampUpTimeSeconds = - params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) - if (rampUpTimeSeconds < 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + - "must not be negative") - } - - val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( - sqlContext.sparkContext.defaultParallelism) - if (numPartitions <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + - "must be positive") - } - - new RateStreamSource( - sqlContext, - metadataPath, - rowsPerSecond, - rampUpTimeSeconds, - numPartitions, - params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing - ) - } - - override def createContinuousReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { - new RateStreamContinuousReader(options) - } - - override def shortName(): String = "rate" -} - -object RateSourceProvider { - val SCHEMA = - StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) - - val VERSION = 1 -} - -class RateStreamSource( - sqlContext: SQLContext, - metadataPath: String, - rowsPerSecond: Long, - rampUpTimeSeconds: Long, - numPartitions: Int, - useManualClock: Boolean) extends Source with Logging { - - import RateSourceProvider._ - import RateStreamSource._ - - val clock = if (useManualClock) new ManualClock else new SystemClock - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private val startTimeMs = { - val metadataLog = - new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. - assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ - @volatile private var lastTimeMs = startTimeMs - - override def schema: StructType = RateSourceProvider.SCHEMA - - override def getOffset: Option[Offset] = { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) - } - - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) - } - - val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - - val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => - val relative = math.round((v - rangeStart) * relativeMsPerValue) - InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) - } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) - } - - override def stop(): Unit = {} - - override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" -} - -object RateStreamSource { - - /** Calculate the end value we will emit at the time `seconds`. */ - def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { - // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 - // Then speedDeltaPerSecond = 2 - // - // seconds = 0 1 2 3 4 5 6 - // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) - // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 - val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) - if (seconds <= rampUpTimeSeconds) { - // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to - // avoid overflow - if (seconds % 2 == 1) { - (seconds + 1) / 2 * speedDeltaPerSecond * seconds - } else { - seconds / 2 * speedDeltaPerSecond * (seconds + 1) - } - } else { - // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds - val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) - rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 20d90069163a6..2f0de2612c150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} @@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions) val creationTime = System.currentTimeMillis() - val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt + val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { @@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateSourceProvider.SCHEMA + override def readSchema(): StructType = RateStreamProvider.SCHEMA private var offset: Offset = _ override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) + this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) } override def getStartOffset(): Offset = offset @@ -98,6 +98,19 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def commit(end: Offset): Unit = {} override def stop(): Unit = {} + private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset( + Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value + // by the increment that will later be applied. The first row output in each + // partition will have a value equal to the partition index. + (i, + ValueRunTimeMsPair( + (i - numPartitions).toLong, + creationTimeMs)) + }.toMap) + } + } case class RateStreamContinuousDataReaderFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala new file mode 100644 index 0000000000000..6cf8520fc544f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{ManualClock, SystemClock} + +class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReader with Logging { + import RateStreamProvider._ + + private[sources] val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock + } + + private val rowsPerSecond = + options.get(ROWS_PER_SECOND).orElse("1").toLong + + private val rampUpTimeSeconds = + Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) + .map(JavaUtils.timeStringAsSec(_)) + .getOrElse(0L) + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private[sources] val creationTimeMs = { + val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) + require(session.isDefined) + + val metadataLog = + new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + @volatile private var lastTimeMs: Long = creationTimeMs + + private var start: LongOffset = _ + private var end: LongOffset = _ + + override def readSchema(): StructType = SCHEMA + + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { + this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] + this.end = end.orElse { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) + }.asInstanceOf[LongOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } + + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return List.empty.asJava + } + + val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + val numPartitions = { + val activeSession = SparkSession.getActiveSession + require(activeSession.isDefined) + Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) + .map(_.toInt) + .getOrElse(activeSession.get.sparkContext.defaultParallelism) + } + + (0 until numPartitions).map { p => + new RateStreamMicroBatchDataReaderFactory( + p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + : DataReaderFactory[Row] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = {} + + override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" +} + +class RateStreamMicroBatchDataReaderFactory( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReaderFactory[Row] { + + override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( + partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) +} + +class RateStreamMicroBatchDataReader( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReader[Row] { + private var count = 0 + + override def next(): Boolean = { + rangeStart + partitionId + numPartitions * count < rangeEnd + } + + override def get(): Row = { + val currValue = rangeStart + partitionId + numPartitions * count + count += 1 + val relative = math.round((currValue - rangeStart) * relativeMsPerValue) + Row( + DateTimeUtils.toJavaTimestamp( + DateTimeUtils.fromMillis(relative + localStartTimeMs)), + currValue + ) + } + + override def close(): Unit = {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala new file mode 100644 index 0000000000000..6bdd492f0cb35 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util.Optional + +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.types._ + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateStreamProvider extends DataSourceV2 + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + import RateStreamProvider._ + + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + if (options.get(ROWS_PER_SECOND).isPresent) { + val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") + } + } + + if (options.get(RAMP_UP_TIME).isPresent) { + val rampUpTimeSeconds = + JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") + } + } + + if (options.get(NUM_PARTITIONS).isPresent) { + val numPartitions = options.get(NUM_PARTITIONS).get().toInt + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") + } + } + + if (schema.isPresent) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + new RateStreamMicroBatchReader(options, checkpointLocation) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + + override def shortName(): String = "rate" +} + +object RateStreamProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 + + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + val RAMP_UP_TIME = "rampUpTime" + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala deleted file mode 100644 index 4e2459bb05bd6..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.util.Optional - -import scala.collection.JavaConverters._ -import scala.collection.mutable - -import org.json4s.DefaultFormats -import org.json4s.jackson.Serialization - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * This is a temporary register as we build out v2 migration. Microbatch read support should - * be implemented in the same register as v1. - */ -class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { - override def createMicroBatchReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { - new RateStreamMicroBatchReader(options) - } - - override def shortName(): String = "ratev2" -} - -class RateStreamMicroBatchReader(options: DataSourceOptions) - extends MicroBatchReader { - implicit val defaultFormats: DefaultFormats = DefaultFormats - - val clock = { - // The option to use a manual clock is provided only for unit testing purposes. - if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock - else new SystemClock - } - - private val numPartitions = - options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - private val rowsPerSecond = - options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong - - // The interval (in milliseconds) between rows in each partition. - // e.g. if there are 4 global rows per second, and 2 partitions, each partition - // should output rows every (1000 * 2 / 4) = 500 ms. - private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond - - override def readSchema(): StructType = { - StructType( - StructField("timestamp", TimestampType, false) :: - StructField("value", LongType, false) :: Nil) - } - - val creationTimeMs = clock.getTimeMillis() - - private var start: RateStreamOffset = _ - private var end: RateStreamOffset = _ - - override def setOffsetRange( - start: Optional[Offset], - end: Optional[Offset]): Unit = { - this.start = start.orElse( - RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) - .asInstanceOf[RateStreamOffset] - - this.end = end.orElse { - val currentTime = clock.getTimeMillis() - RateStreamOffset( - this.start.partitionToValueAndRunTimeMs.map { - case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - // Calculate the number of rows we should advance in this partition (based on the - // current time), and output a corresponding offset. - val readInterval = currentTime - currentReadTime - val numNewRows = readInterval / msPerPartitionBetweenRows - if (numNewRows <= 0) { - startOffset - } else { - (part, ValueRunTimeMsPair( - currentVal + (numNewRows * numPartitions), - currentReadTime + (numNewRows * msPerPartitionBetweenRows))) - } - } - ) - }.asInstanceOf[RateStreamOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end - } - - override def deserializeOffset(json: String): Offset = { - RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) - } - - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { - val startMap = start.partitionToValueAndRunTimeMs - val endMap = end.partitionToValueAndRunTimeMs - endMap.keys.toSeq.map { part => - val ValueRunTimeMsPair(endVal, _) = endMap(part) - val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) - - val packedRows = mutable.ListBuffer[(Long, Long)]() - var outVal = startVal + numPartitions - var outTimeMs = startTimeMs - while (outVal <= endVal) { - packedRows.append((outTimeMs, outVal)) - outVal += numPartitions - outTimeMs += msPerPartitionBetweenRows - } - - RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] - }.toList.asJava - } - - override def commit(end: Offset): Unit = {} - override def stop(): Unit = {} -} - -case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { - override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) -} - -class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { - private var currentIndex = -1 - - override def next(): Boolean = { - // Return true as long as the new index is in the seq. - currentIndex += 1 - currentIndex < vals.size - } - - override def get(): Row = { - Row( - DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)), - vals(currentIndex)._2) - } - - override def close(): Unit = {} -} - -object RateStreamSourceV2 { - val NUM_PARTITIONS = "numPartitions" - val ROWS_PER_SECOND = "rowsPerSecond" - - private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { - RateStreamOffset( - Range(0, numPartitions).map { i => - // Note that the starting offset is exclusive, so we have to decrement the starting value - // by the increment that will later be applied. The first row output in each - // partition will have a value equal to the partition index. - (i, - ValueRunTimeMsPair( - (i - numPartitions).toLong, - creationTimeMs)) - }.toMap) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala deleted file mode 100644 index 03d0f63fa4d7f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.util.concurrent.TimeUnit - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} -import org.apache.spark.util.ManualClock - -class RateSourceSuite extends StreamTest { - - import testImplicits._ - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => - source.asInstanceOf[RateStreamSource] - }.head - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - (rateSource, rateSource.getOffset.get) - } - } - - test("basic") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("uniform distribution of event timestamps") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "1500") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - val expectedAnswer = (0 until 1500).map { v => - (math.round(v * (1000.0 / 1500)), v) - } - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch(expectedAnswer: _*) - ) - } - - test("valueAtSecond") { - import RateStreamSource._ - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) - assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) - assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) - } - - test("rampUpTime") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("rampUpTime", "4s") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch({ - Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) - }: _*), // speed = 6 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 - AdvanceRateManualClock(seconds = 1), - // Now we should reach full speed - CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 - ) - } - - test("numPartitions") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("numPartitions", "6") - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(1), - CheckLastBatch((0 until 6): _*) - ) - } - - testQuietly("overflow") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", Long.MaxValue.toString) - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(2), - ExpectFailure[ArithmeticException](t => { - Seq("overflow", "rowsPerSecond").foreach { msg => - assert(t.getMessage.contains(msg)) - } - }) - ) - } - - testQuietly("illegal option values") { - def testIllegalOptionValue( - option: String, - value: String, - expectedMessages: Seq[String]): Unit = { - val e = intercept[StreamingQueryException] { - spark.readStream - .format("rate") - .option(option, value) - .load() - .writeStream - .format("console") - .start() - .awaitTermination() - } - assert(e.getCause.isInstanceOf[IllegalArgumentException]) - for (msg <- expectedMessages) { - assert(e.getCause.getMessage.contains(msg)) - } - } - - testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) - testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) - } - - test("user-specified schema given") { - val exception = intercept[AnalysisException] { - spark.readStream - .format("rate") - .schema(spark.range(1).schema) - .load() - } - assert(exception.getMessage.contains( - "rate source does not support a user-specified schema")) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala deleted file mode 100644 index 983ba1668f58f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.util.ManualClock - -class RateSourceV2Suite extends StreamTest { - import testImplicits._ - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source - }.head - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - rateSource.setOffsetRange(Optional.empty(), Optional.empty()) - (rateSource, rateSource.getEndOffset()) - } - } - - test("microbatch in registry") { - DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("basic microbatch execution") { - val input = spark.readStream - .format("rateV2") - .option("numPartitions", "1") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input, useV2Sink = true)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("microbatch - numPartitions propagated") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - } - - test("microbatch - set offset") { - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - - test("microbatch - infer offsets") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) - reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: RateStreamOffset => - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: RateStreamOffset => - // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted - // longer than 100ms. It should never be early. - assert(r.partitionToValueAndRunTimeMs(0).value >= 9) - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) - - case _ => throw new IllegalStateException("unexpected offset type") - } - } - - test("microbatch - predetermined batch size") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 1) - assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) - } - - test("microbatch - data read") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) - val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { - case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) - }.toMap) - - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createDataReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } - - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) - } - - test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("continuous data") { - val reader = new RateStreamContinuousReader( - new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 2) - - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => - val startTimeMs = reader.getStartOffset() - .asInstanceOf[RateStreamOffset] - .partitionToValueAndRunTimeMs(t.partitionIndex) - .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] - for (rowIndex <- 0 to 9) { - r.next() - data.append(r.get()) - assert(r.getOffset() == - RateStreamPartitionOffset( - t.partitionIndex, - t.partitionIndex + rowIndex * 2, - startTimeMs + (rowIndex + 1) * 100)) - } - assert(System.currentTimeMillis() >= startTimeMs + 1000) - - case _ => throw new IllegalStateException("Unexpected task type") - } - - assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala new file mode 100644 index 0000000000000..9149e50962255 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.nio.file.Files +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.util.ManualClock + +class RateSourceSuite extends StreamTest { + + import testImplicits._ + + protected override def beforeAll(): Unit = { + super.beforeAll() + SparkSession.setActiveSession(spark) + } + + override def afterAll(): Unit = { + SparkSession.clearActiveSession() + super.afterAll() + } + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + }.head + + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds( + rateSource.clock.getTimeMillis() - rateSource.creationTimeMs)) + (rateSource, offset) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } + } + + test("compatible with old path in registry") { + DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", + spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + assert(ds.isInstanceOf[RateStreamProvider]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } + } + + test("microbatch - basic") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("microbatch - uniform distribution of event timestamps") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "1500") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + val expectedAnswer = (0 until 1500).map { v => + (math.round(v * (1000.0 / 1500)), v) + } + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch(expectedAnswer: _*) + ) + } + + test("microbatch - set offset") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + + test("microbatch - infer offsets") { + val tempFolder = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions( + Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), + tempFolder) + reader.clock.asInstanceOf[ManualClock].advance(100000) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: LongOffset => assert(r.offset === 0L) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: LongOffset => assert(r.offset >= 100) + case _ => throw new IllegalStateException("unexpected offset type") + } + } + + test("microbatch - predetermined batch size") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 1) + val dataReader = tasks.get(0).createDataReader() + val data = ArrayBuffer[Row]() + while (dataReader.next()) { + data.append(dataReader.get()) + } + assert(data.size === 20) + } + + test("microbatch - data read") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createDataReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } + + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } + + test("valueAtSecond") { + import RateStreamProvider._ + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) + } + + test("rampUpTime") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("rampUpTime", "4s") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch({ + Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) + }: _*), // speed = 6 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 + AdvanceRateManualClock(seconds = 1), + // Now we should reach full speed + CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 + ) + } + + test("numPartitions") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("numPartitions", "6") + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(1), + CheckLastBatch((0 until 6): _*) + ) + } + + testQuietly("overflow") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", Long.MaxValue.toString) + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(2), + ExpectFailure[TreeNodeException[_]](t => { + Seq("overflow", "rowsPerSecond").foreach { msg => + assert(t.getCause.getMessage.contains(msg)) + } + }) + ) + } + + testQuietly("illegal option values") { + def testIllegalOptionValue( + option: String, + value: String, + expectedMessages: Seq[String]): Unit = { + val e = intercept[IllegalArgumentException] { + spark.readStream + .format("rate") + .option(option, value) + .load() + .writeStream + .format("console") + .start() + .awaitTermination() + } + for (msg <- expectedMessages) { + assert(e.getMessage.contains(msg)) + } + } + + testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) + testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) + } + + test("user-specified schema given") { + val exception = intercept[AnalysisException] { + spark.readStream + .format("rate") + .schema(spark.range(1).schema) + .load() + } + assert(exception.getMessage.contains( + "rate source does not support a user-specified schema")) + } + + test("continuous in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) + case _ => + throw new IllegalStateException("Could not find read support for continuous rate") + } + } + + test("continuous data") { + val reader = new RateStreamContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 2) + + val data = scala.collection.mutable.ListBuffer[Row]() + tasks.asScala.foreach { + case t: RateStreamContinuousDataReaderFactory => + val startTimeMs = reader.getStartOffset() + .asInstanceOf[RateStreamOffset] + .partitionToValueAndRunTimeMs(t.partitionIndex) + .runTimeMs + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + for (rowIndex <- 0 to 9) { + r.next() + data.append(r.get()) + assert(r.getOffset() == + RateStreamPartitionOffset( + t.partitionIndex, + t.partitionIndex + rowIndex * 2, + startTimeMs + (rowIndex + 1) * 100)) + } + assert(System.currentTimeMillis() >= startTimeMs + 1000) + + case _ => throw new IllegalStateException("Unexpected task type") + } + + assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) + } +} From ed72badb04a56d8046bbd185245abf5ae265ccfd Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 27 Mar 2018 20:06:12 -0700 Subject: [PATCH 529/774] [SPARK-23699][PYTHON][SQL] Raise same type of error caught with Arrow enabled ## What changes were proposed in this pull request? When using Arrow for createDataFrame or toPandas and an error is encountered with fallback disabled, this will raise the same type of error instead of a RuntimeError. This change also allows for the traceback of the error to be retained and prevents the accidental chaining of exceptions with Python 3. ## How was this patch tested? Updated existing tests to verify error type. Author: Bryan Cutler Closes #20839 from BryanCutler/arrow-raise-same-error-SPARK-23699. --- python/pyspark/sql/dataframe.py | 25 +++++++++++++------------ python/pyspark/sql/session.py | 13 +++++++------ python/pyspark/sql/tests.py | 10 +++++----- python/pyspark/sql/utils.py | 6 ++++++ 4 files changed, 31 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3fc194d8ec1d1..16f8e52dead7b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2007,7 +2007,7 @@ def toPandas(self): "toPandas attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " "failed by the reason below:\n %s\n" - "Attempts non-optimization as " + "Attempting non-optimization as " "'spark.sql.execution.arrow.fallback.enabled' is set to " "true." % _exception_message(e)) warnings.warn(msg) @@ -2015,11 +2015,12 @@ def toPandas(self): else: msg = ( "toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true; however, " - "failed by the reason below:\n %s\n" - "For fallback to non-optimization automatically, please set true to " - "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) - raise RuntimeError(msg) + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and will not continue because automatic fallback " + "with 'spark.sql.execution.arrow.fallback.enabled' has been set to " + "false.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise # Try to use Arrow optimization when the schema is supported and the required version # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. @@ -2042,12 +2043,12 @@ def toPandas(self): # be executed. So, simply fail in this case for now. msg = ( "toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true; however, " - "failed unexpectedly:\n %s\n" - "Note that 'spark.sql.execution.arrow.fallback.enabled' does " - "not have an effect in such failure in the middle of " - "computation." % _exception_message(e)) - raise RuntimeError(msg) + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and can not continue. Note that " + "'spark.sql.execution.arrow.fallback.enabled' does not have an effect " + "on failures in the middle of computation.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise # Below is toPandas without Arrow optimization. pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e82a9750a0014..13d6e2e53dbd0 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -674,18 +674,19 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr "createDataFrame attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " "failed by the reason below:\n %s\n" - "Attempts non-optimization as " + "Attempting non-optimization as " "'spark.sql.execution.arrow.fallback.enabled' is set to " "true." % _exception_message(e)) warnings.warn(msg) else: msg = ( "createDataFrame attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true; however, " - "failed by the reason below:\n %s\n" - "For fallback to non-optimization automatically, please set true to " - "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) - raise RuntimeError(msg) + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and will not continue because automatic fallback " + "with 'spark.sql.execution.arrow.fallback.enabled' has been set to " + "false.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise data = self._convert_from_pandas(data, schema, timezone) if isinstance(schema, StructType): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 967cc83166f3f..01c5dd6ff8c3f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3559,7 +3559,7 @@ def test_toPandas_fallback_enabled(self): warn.message for warn in warns if isinstance(warn.message, UserWarning)] self.assertTrue(len(user_warns) > 0) self.assertTrue( - "Attempts non-optimization" in _exception_message(user_warns[-1])) + "Attempting non-optimization" in _exception_message(user_warns[-1])) self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) def test_toPandas_fallback_disabled(self): @@ -3682,7 +3682,7 @@ def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() wrong_schema = StructType(list(reversed(self.schema))) with QuietTest(self.sc): - with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"): + with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"): self.spark.createDataFrame(pdf, schema=wrong_schema) def test_createDataFrame_with_names(self): @@ -3707,7 +3707,7 @@ def test_createDataFrame_column_name_encoding(self): def test_createDataFrame_with_single_data_type(self): import pandas as pd with QuietTest(self.sc): - with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"): + with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"): self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") def test_createDataFrame_does_not_modify_input(self): @@ -3775,14 +3775,14 @@ def test_createDataFrame_fallback_enabled(self): warn.message for warn in warns if isinstance(warn.message, UserWarning)] self.assertTrue(len(user_warns) > 0) self.assertTrue( - "Attempts non-optimization" in _exception_message(user_warns[-1])) + "Attempting non-optimization" in _exception_message(user_warns[-1])) self.assertEqual(df.collect(), [Row(a={u'a': 1})]) def test_createDataFrame_fallback_disabled(self): import pandas as pd with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported type'): + with self.assertRaisesRegexp(TypeError, 'Unsupported type'): self.spark.createDataFrame( pd.DataFrame([[{u'a': 1}]]), "a: map") diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 578298632dd4c..45363f089a73d 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -121,7 +121,10 @@ def require_minimum_pandas_version(): from distutils.version import LooseVersion try: import pandas + have_pandas = True except ImportError: + have_pandas = False + if not have_pandas: raise ImportError("Pandas >= %s must be installed; however, " "it was not found." % minimum_pandas_version) if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version): @@ -138,7 +141,10 @@ def require_minimum_pyarrow_version(): from distutils.version import LooseVersion try: import pyarrow + have_arrow = True except ImportError: + have_arrow = False + if not have_arrow: raise ImportError("PyArrow >= %s must be installed; however, " "it was not found." % minimum_pyarrow_version) if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): From 34c4b9c57e114cdb390e4dbc7383284d82fea317 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 28 Mar 2018 19:49:27 +0800 Subject: [PATCH 530/774] [SPARK-23765][SQL] Supports custom line separator for json datasource ## What changes were proposed in this pull request? This PR proposes to add lineSep option for a configurable line separator in text datasource. It supports this option by using `LineRecordReader`'s functionality with passing it to the constructor. The approach is similar with https://github.com/apache/spark/pull/20727; however, one main difference is, it uses text datasource's `lineSep` option to parse line by line in JSON's schema inference. ## How was this patch tested? Manually tested and unit tests were added. Author: hyukjinkwon Author: hyukjinkwon Closes #20877 from HyukjinKwon/linesep-json. --- python/pyspark/sql/readwriter.py | 14 ++-- python/pyspark/sql/streaming.py | 6 +- python/pyspark/sql/tests.py | 17 +++++ .../spark/sql/catalyst/json/JSONOptions.scala | 11 ++++ .../sql/catalyst/json/JacksonGenerator.scala | 8 ++- .../apache/spark/sql/DataFrameReader.scala | 2 + .../apache/spark/sql/DataFrameWriter.scala | 2 + .../datasources/json/JsonDataSource.scala | 17 +++-- .../datasources/text/TextOptions.scala | 2 +- .../sql/streaming/DataStreamReader.scala | 2 + .../datasources/json/JsonSuite.scala | 66 ++++++++++++++++++- .../datasources/text/TextSuite.scala | 4 +- 12 files changed, 136 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 4f9b9383a5ef4..6bd79bc2f43e5 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -237,6 +237,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -254,7 +256,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -746,7 +748,8 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) self._jwrite.saveAsTable(name) @since(1.4) - def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None): + def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, + lineSep=None): """Saves the content of the :class:`DataFrame` in JSON format (`JSON Lines text format or newline-delimited JSON `_) at the specified path. @@ -770,12 +773,15 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts( - compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat) + compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, + lineSep=lineSep) self._jwrite.json(path) @since(1.4) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index c7907aaaf1f7b..15f9407389864 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -405,7 +405,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. @@ -468,6 +468,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -482,7 +484,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 01c5dd6ff8c3f..5181053a0d318 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -676,6 +676,23 @@ def test_multiline_json(self): multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_linesep_json(self): + df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") + expected = [Row(_corrupt_record=None, name=u'Michael'), + Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None), + Row(_corrupt_record=u' "age":19}\n', name=None)] + self.assertEqual(df.collect(), expected) + + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + df = self.spark.read.json("python/test_support/sql/people.json") + df.write.json(tpath, lineSep="!!") + readback = self.spark.read.json(tpath, lineSep="!!") + self.assertEqual(readback.collect(), df.collect()) + finally: + shutil.rmtree(tpath) + def test_multiline_csv(self): ages_newlines = self.spark.read.csv( "python/test_support/sql/ages_newlines.csv", multiLine=True) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 652412b34478a..5c9adc3332bc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.json +import java.nio.charset.StandardCharsets import java.util.{Locale, TimeZone} import com.fasterxml.jackson.core.{JsonFactory, JsonParser} @@ -85,6 +86,16 @@ private[sql] class JSONOptions( val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) + val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => + require(sep.nonEmpty, "'lineSep' cannot be an empty string.") + sep + } + // Note that the option 'lineSep' uses a different default value in read and write. + val lineSeparatorInRead: Option[Array[Byte]] = + lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) + // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n") + /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index eb06e4f304f0a..9c413de752a8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.json import java.io.Writer +import java.nio.charset.StandardCharsets import com.fasterxml.jackson.core._ @@ -74,6 +75,8 @@ private[sql] class JacksonGenerator( private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + private val lineSeparator: String = options.lineSeparatorInWrite + private def makeWriter(dataType: DataType): ValueWriter = dataType match { case NullType => (row: SpecializedGetters, ordinal: Int) => @@ -251,5 +254,8 @@ private[sql] class JacksonGenerator( mapType = dataType.asInstanceOf[MapType])) } - def writeLineEnding(): Unit = gen.writeRaw('\n') + def writeLineEnding(): Unit = { + // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + gen.writeRaw(lineSeparator) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 1a5e47508c070..ae3ba1690f696 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -366,6 +366,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • + *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index bb93889dc55e9..bbc063148a72c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -518,6 +518,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • + *
  • `lineSep` (default `\n`): defines the line separator that should + * be used for writing.
  • * * * @since 1.4.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 77e7edc8e7a20..5769c09c9a1d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -92,7 +92,8 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths) + val json: Dataset[String] = createBaseDataset( + sparkSession, inputPaths, parsedOptions.lineSeparator) inferFromDataset(json, parsedOptions) } @@ -104,13 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource { private def createBaseDataset( sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): Dataset[String] = { + inputPaths: Seq[FileStatus], + lineSeparator: Option[String]): Dataset[String] = { + val textOptions = lineSeparator.map { lineSep => + Map(TextOptions.LINE_SEPARATOR -> lineSep) + }.getOrElse(Map.empty[String, String]) + val paths = inputPaths.map(_.getPath.toString) sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, paths = paths, - className = classOf[TextFileFormat].getName + className = classOf[TextFileFormat].getName, + options = textOptions ).resolveRelation(checkFilesExist = false)) .select("value").as(Encoders.STRING) } @@ -120,7 +127,7 @@ object TextInputJsonDataSource extends JsonDataSource { file: PartitionedFile, parser: JacksonParser, schema: StructType): Iterator[InternalRow] = { - val linesReader = new HadoopFileLinesReader(file, conf) + val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) val safeParser = new FailureSafeParser[Text]( input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 18698df9fd8e5..5c1a35434f7b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -52,7 +52,7 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } -private[text] object TextOptions { +private[datasources] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" val LINE_SEPARATOR = "lineSep" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 9b17406a816b5..ae93965bc50ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -268,6 +268,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * `java.text.SimpleDateFormat`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • + *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 8c8d41ebf115a..10bac0554484a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.nio.charset.StandardCharsets +import java.nio.file.Files import java.sql.{Date, Timestamp} import java.util.Locale @@ -27,7 +28,7 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} @@ -2063,4 +2064,67 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + def testLineSeparator(lineSep: String): Unit = { + test(s"SPARK-21289: Support line separator - lineSep: '$lineSep'") { + // Read + val data = + s""" + | {"f": + |"a", "f0": 1}$lineSep{"f": + | + |"c", "f0": 2}$lineSep{"f": "d", "f0": 3} + """.stripMargin + val dataWithTrailingLineSep = s"$data$lineSep" + + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("f", StringType) :: StructField("f0", LongType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } + + // Write + withTempPath { path => + Seq("a", "b", "c").toDF("value").coalesce(1) + .write.option("lineSep", lineSep).json(path.getAbsolutePath) + val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8) + assert( + readBack === s"""{"value":"a"}$lineSep{"value":"b"}$lineSep{"value":"c"}$lineSep""") + } + + // Roundtrip + withTempPath { path => + val df = Seq("a", "b", "c").toDF() + df.write.option("lineSep", lineSep).json(path.getAbsolutePath) + val readBack = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + // scalastyle:off nonascii + Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString, "아").foreach { lineSep => + testLineSeparator(lineSep) + } + // scalastyle:on nonascii + + test("""SPARK-21289: Support line separator - default value \r, \r\n and \n""") { + val data = + "{\"f\": \"a\", \"f0\": 1}\r{\"f\": \"c\", \"f0\": 2}\r\n{\"f\": \"d\", \"f0\": 3}\n" + + withTempPath { path => + Files.write(path.toPath, data.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.json(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("f", StringType) :: StructField("f0", LongType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index e8a5299d6ba9d..0e7f3afa9c3ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -208,9 +208,11 @@ class TextSuite extends QueryTest with SharedSQLContext { } } - Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString).foreach { lineSep => + // scalastyle:off nonascii + Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString, "아").foreach { lineSep => testLineSeparator(lineSep) } + // scalastyle:on nonascii private def testFile: String = { Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString From 761565a3ccbf7f083e587fee14a27b61867a3886 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 28 Mar 2018 09:11:52 -0700 Subject: [PATCH 531/774] Revert "[SPARK-23096][SS] Migrate rate source to V2" This reverts commit c68ec4e6a1ed9ea13345c7705ea60ff4df7aec7b. --- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../execution/datasources/DataSource.scala | 6 +- .../streaming/RateSourceProvider.scala | 262 +++++++++++++ .../ContinuousRateStreamSource.scala | 25 +- .../sources/RateStreamMicroBatchReader.scala | 222 ----------- .../sources/RateStreamProvider.scala | 125 ------- .../sources/RateStreamSourceV2.scala | 187 ++++++++++ .../execution/streaming/RateSourceSuite.scala | 194 ++++++++++ .../streaming/RateSourceV2Suite.scala | 191 ++++++++++ .../sources/RateStreamProviderSuite.scala | 344 ------------------ 10 files changed, 844 insertions(+), 715 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1b37905543b4e..1fe9c093af99f 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,5 +5,6 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.sources.RateStreamProvider +org.apache.spark.sql.execution.streaming.RateSourceProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b84ea769808f9..31fa89b4570a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} +import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -566,7 +566,6 @@ object DataSource extends Logging { val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName val socket = classOf[TextSocketSourceProvider].getCanonicalName - val rate = classOf[RateStreamProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -588,8 +587,7 @@ object DataSource extends Logging { "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, "com.databricks.spark.csv" -> csv, - "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, - "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala new file mode 100644 index 0000000000000..649fbbfa184ec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader +import org.apache.spark.sql.types._ +import org.apache.spark.util.{ManualClock, SystemClock} + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateSourceProvider extends StreamSourceProvider with DataSourceRegister + with DataSourceV2 with ContinuousReadSupport { + + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + if (schema.nonEmpty) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + (shortName(), RateSourceProvider.SCHEMA) + } + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + val params = CaseInsensitiveMap(parameters) + + val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + + "must be positive") + } + + val rampUpTimeSeconds = + params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + + "must not be negative") + } + + val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( + sqlContext.sparkContext.defaultParallelism) + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + + "must be positive") + } + + new RateStreamSource( + sqlContext, + metadataPath, + rowsPerSecond, + rampUpTimeSeconds, + numPartitions, + params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing + ) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = { + new RateStreamContinuousReader(options) + } + + override def shortName(): String = "rate" +} + +object RateSourceProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 +} + +class RateStreamSource( + sqlContext: SQLContext, + metadataPath: String, + rowsPerSecond: Long, + rampUpTimeSeconds: Long, + numPartitions: Int, + useManualClock: Boolean) extends Source with Logging { + + import RateSourceProvider._ + import RateStreamSource._ + + val clock = if (useManualClock) new ManualClock else new SystemClock + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private val startTimeMs = { + val metadataLog = + new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ + @volatile private var lastTimeMs = startTimeMs + + override def schema: StructType = RateSourceProvider.SCHEMA + + override def getOffset: Option[Offset] = { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return sqlContext.internalCreateDataFrame( + sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + } + + val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + + val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => + val relative = math.round((v - rangeStart) * relativeMsPerValue) + InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) + } + sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) + } + + override def stop(): Unit = {} + + override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" +} + +object RateStreamSource { + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 2f0de2612c150..20d90069163a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider +import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} @@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions) val creationTime = System.currentTimeMillis() - val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt - val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong + val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt + val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { @@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateStreamProvider.SCHEMA + override def readSchema(): StructType = RateSourceProvider.SCHEMA private var offset: Offset = _ override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) + this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) } override def getStartOffset(): Offset = offset @@ -98,19 +98,6 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def commit(end: Offset): Unit = {} override def stop(): Unit = {} - private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { - RateStreamOffset( - Range(0, numPartitions).map { i => - // Note that the starting offset is exclusive, so we have to decrement the starting value - // by the increment that will later be applied. The first row output in each - // partition will have a value equal to the partition index. - (i, - ValueRunTimeMsPair( - (i - numPartitions).toLong, - creationTimeMs)) - }.toMap) - } - } case class RateStreamContinuousDataReaderFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala deleted file mode 100644 index 6cf8520fc544f..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.{ManualClock, SystemClock} - -class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReader with Logging { - import RateStreamProvider._ - - private[sources] val clock = { - // The option to use a manual clock is provided only for unit testing purposes. - if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock - } - - private val rowsPerSecond = - options.get(ROWS_PER_SECOND).orElse("1").toLong - - private val rampUpTimeSeconds = - Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) - .map(JavaUtils.timeStringAsSec(_)) - .getOrElse(0L) - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private[sources] val creationTimeMs = { - val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) - require(session.isDefined) - - val metadataLog = - new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. - assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - @volatile private var lastTimeMs: Long = creationTimeMs - - private var start: LongOffset = _ - private var end: LongOffset = _ - - override def readSchema(): StructType = SCHEMA - - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { - this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] - this.end = end.orElse { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) - }.asInstanceOf[LongOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end - } - - override def deserializeOffset(json: String): Offset = { - LongOffset(json.toLong) - } - - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { - val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return List.empty.asJava - } - - val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - val numPartitions = { - val activeSession = SparkSession.getActiveSession - require(activeSession.isDefined) - Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) - .map(_.toInt) - .getOrElse(activeSession.get.sparkContext.defaultParallelism) - } - - (0 until numPartitions).map { p => - new RateStreamMicroBatchDataReaderFactory( - p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : DataReaderFactory[Row] - }.toList.asJava - } - - override def commit(end: Offset): Unit = {} - - override def stop(): Unit = {} - - override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, " + - s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" -} - -class RateStreamMicroBatchDataReaderFactory( - partitionId: Int, - numPartitions: Int, - rangeStart: Long, - rangeEnd: Long, - localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReaderFactory[Row] { - - override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( - partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) -} - -class RateStreamMicroBatchDataReader( - partitionId: Int, - numPartitions: Int, - rangeStart: Long, - rangeEnd: Long, - localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReader[Row] { - private var count = 0 - - override def next(): Boolean = { - rangeStart + partitionId + numPartitions * count < rangeEnd - } - - override def get(): Row = { - val currValue = rangeStart + partitionId + numPartitions * count - count += 1 - val relative = math.round((currValue - rangeStart) * relativeMsPerValue) - Row( - DateTimeUtils.toJavaTimestamp( - DateTimeUtils.fromMillis(relative + localStartTimeMs)), - currValue - ) - } - - override def close(): Unit = {} -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala deleted file mode 100644 index 6bdd492f0cb35..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.util.Optional - -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} -import org.apache.spark.sql.types._ - -/** - * A source that generates increment long values with timestamps. Each generated row has two - * columns: a timestamp column for the generated time and an auto increment long column starting - * with 0L. - * - * This source supports the following options: - * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. - * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed - * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer - * seconds. - * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the - * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may - * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. - */ -class RateStreamProvider extends DataSourceV2 - with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { - import RateStreamProvider._ - - override def createMicroBatchReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { - if (options.get(ROWS_PER_SECOND).isPresent) { - val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong - if (rowsPerSecond <= 0) { - throw new IllegalArgumentException( - s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") - } - } - - if (options.get(RAMP_UP_TIME).isPresent) { - val rampUpTimeSeconds = - JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) - if (rampUpTimeSeconds < 0) { - throw new IllegalArgumentException( - s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") - } - } - - if (options.get(NUM_PARTITIONS).isPresent) { - val numPartitions = options.get(NUM_PARTITIONS).get().toInt - if (numPartitions <= 0) { - throw new IllegalArgumentException( - s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") - } - } - - if (schema.isPresent) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - new RateStreamMicroBatchReader(options, checkpointLocation) - } - - override def createContinuousReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) - - override def shortName(): String = "rate" -} - -object RateStreamProvider { - val SCHEMA = - StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) - - val VERSION = 1 - - val NUM_PARTITIONS = "numPartitions" - val ROWS_PER_SECOND = "rowsPerSecond" - val RAMP_UP_TIME = "rampUpTime" - - /** Calculate the end value we will emit at the time `seconds`. */ - def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { - // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 - // Then speedDeltaPerSecond = 2 - // - // seconds = 0 1 2 3 4 5 6 - // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) - // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 - val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) - if (seconds <= rampUpTimeSeconds) { - // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to - // avoid overflow - if (seconds % 2 == 1) { - (seconds + 1) / 2 * speedDeltaPerSecond * seconds - } else { - seconds / 2 * speedDeltaPerSecond * (seconds + 1) - } - } else { - // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds - val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) - rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala new file mode 100644 index 0000000000000..4e2459bb05bd6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util.Optional + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.json4s.DefaultFormats +import org.json4s.jackson.Serialization + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} +import org.apache.spark.util.{ManualClock, SystemClock} + +/** + * This is a temporary register as we build out v2 migration. Microbatch read support should + * be implemented in the same register as v1. + */ +class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + new RateStreamMicroBatchReader(options) + } + + override def shortName(): String = "ratev2" +} + +class RateStreamMicroBatchReader(options: DataSourceOptions) + extends MicroBatchReader { + implicit val defaultFormats: DefaultFormats = DefaultFormats + + val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock + else new SystemClock + } + + private val numPartitions = + options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt + private val rowsPerSecond = + options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + + // The interval (in milliseconds) between rows in each partition. + // e.g. if there are 4 global rows per second, and 2 partitions, each partition + // should output rows every (1000 * 2 / 4) = 500 ms. + private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond + + override def readSchema(): StructType = { + StructType( + StructField("timestamp", TimestampType, false) :: + StructField("value", LongType, false) :: Nil) + } + + val creationTimeMs = clock.getTimeMillis() + + private var start: RateStreamOffset = _ + private var end: RateStreamOffset = _ + + override def setOffsetRange( + start: Optional[Offset], + end: Optional[Offset]): Unit = { + this.start = start.orElse( + RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) + .asInstanceOf[RateStreamOffset] + + this.end = end.orElse { + val currentTime = clock.getTimeMillis() + RateStreamOffset( + this.start.partitionToValueAndRunTimeMs.map { + case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => + // Calculate the number of rows we should advance in this partition (based on the + // current time), and output a corresponding offset. + val readInterval = currentTime - currentReadTime + val numNewRows = readInterval / msPerPartitionBetweenRows + if (numNewRows <= 0) { + startOffset + } else { + (part, ValueRunTimeMsPair( + currentVal + (numNewRows * numPartitions), + currentReadTime + (numNewRows * msPerPartitionBetweenRows))) + } + } + ) + }.asInstanceOf[RateStreamOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) + } + + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + val startMap = start.partitionToValueAndRunTimeMs + val endMap = end.partitionToValueAndRunTimeMs + endMap.keys.toSeq.map { part => + val ValueRunTimeMsPair(endVal, _) = endMap(part) + val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) + + val packedRows = mutable.ListBuffer[(Long, Long)]() + var outVal = startVal + numPartitions + var outTimeMs = startTimeMs + while (outVal <= endVal) { + packedRows.append((outTimeMs, outVal)) + outVal += numPartitions + outTimeMs += msPerPartitionBetweenRows + } + + RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + override def stop(): Unit = {} +} + +case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { + override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) +} + +class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { + private var currentIndex = -1 + + override def next(): Boolean = { + // Return true as long as the new index is in the seq. + currentIndex += 1 + currentIndex < vals.size + } + + override def get(): Row = { + Row( + DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)), + vals(currentIndex)._2) + } + + override def close(): Unit = {} +} + +object RateStreamSourceV2 { + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + + private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset( + Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value + // by the increment that will later be applied. The first row output in each + // partition will have a value equal to the partition index. + (i, + ValueRunTimeMsPair( + (i - numPartitions).toLong, + creationTimeMs)) + }.toMap) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala new file mode 100644 index 0000000000000..03d0f63fa4d7f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} +import org.apache.spark.util.ManualClock + +class RateSourceSuite extends StreamTest { + + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => + source.asInstanceOf[RateStreamSource] + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + (rateSource, rateSource.getOffset.get) + } + } + + test("basic") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("uniform distribution of event timestamps") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "1500") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + val expectedAnswer = (0 until 1500).map { v => + (math.round(v * (1000.0 / 1500)), v) + } + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch(expectedAnswer: _*) + ) + } + + test("valueAtSecond") { + import RateStreamSource._ + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) + } + + test("rampUpTime") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("rampUpTime", "4s") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch({ + Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) + }: _*), // speed = 6 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 + AdvanceRateManualClock(seconds = 1), + // Now we should reach full speed + CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 + ) + } + + test("numPartitions") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("numPartitions", "6") + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(1), + CheckLastBatch((0 until 6): _*) + ) + } + + testQuietly("overflow") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", Long.MaxValue.toString) + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(2), + ExpectFailure[ArithmeticException](t => { + Seq("overflow", "rowsPerSecond").foreach { msg => + assert(t.getMessage.contains(msg)) + } + }) + ) + } + + testQuietly("illegal option values") { + def testIllegalOptionValue( + option: String, + value: String, + expectedMessages: Seq[String]): Unit = { + val e = intercept[StreamingQueryException] { + spark.readStream + .format("rate") + .option(option, value) + .load() + .writeStream + .format("console") + .start() + .awaitTermination() + } + assert(e.getCause.isInstanceOf[IllegalArgumentException]) + for (msg <- expectedMessages) { + assert(e.getCause.getMessage.contains(msg)) + } + } + + testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) + testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) + } + + test("user-specified schema given") { + val exception = intercept[AnalysisException] { + spark.readStream + .format("rate") + .schema(spark.range(1).schema) + .load() + } + assert(exception.getMessage.contains( + "rate source does not support a user-specified schema")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala new file mode 100644 index 0000000000000..983ba1668f58f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.util.ManualClock + +class RateSourceV2Suite extends StreamTest { + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + rateSource.setOffsetRange(Optional.empty(), Optional.empty()) + (rateSource, rateSource.getEndOffset()) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find v2 read support for rate") + } + } + + test("basic microbatch execution") { + val input = spark.readStream + .format("rateV2") + .option("numPartitions", "1") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input, useV2Sink = true)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("microbatch - numPartitions propagated") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + } + + test("microbatch - set offset") { + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) + val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) + val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + + test("microbatch - infer offsets") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) + reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: RateStreamOffset => + assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: RateStreamOffset => + // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted + // longer than 100ms. It should never be early. + assert(r.partitionToValueAndRunTimeMs(0).value >= 9) + assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) + + case _ => throw new IllegalStateException("unexpected offset type") + } + } + + test("microbatch - predetermined batch size") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) + val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) + val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 1) + assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) + } + + test("microbatch - data read") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) + val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { + case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => + (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) + }.toMap) + + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createDataReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } + + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } + + test("continuous in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) + case _ => + throw new IllegalStateException("Could not find v2 read support for rate") + } + } + + test("continuous data") { + val reader = new RateStreamContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 2) + + val data = scala.collection.mutable.ListBuffer[Row]() + tasks.asScala.foreach { + case t: RateStreamContinuousDataReaderFactory => + val startTimeMs = reader.getStartOffset() + .asInstanceOf[RateStreamOffset] + .partitionToValueAndRunTimeMs(t.partitionIndex) + .runTimeMs + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + for (rowIndex <- 0 to 9) { + r.next() + data.append(r.get()) + assert(r.getOffset() == + RateStreamPartitionOffset( + t.partitionIndex, + t.partitionIndex + rowIndex * 2, + startTimeMs + (rowIndex + 1) * 100)) + } + assert(System.currentTimeMillis() >= startTimeMs + 1000) + + case _ => throw new IllegalStateException("Unexpected task type") + } + + assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala deleted file mode 100644 index 9149e50962255..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ /dev/null @@ -1,344 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.nio.file.Files -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.Offset -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.util.ManualClock - -class RateSourceSuite extends StreamTest { - - import testImplicits._ - - protected override def beforeAll(): Unit = { - super.beforeAll() - SparkSession.setActiveSession(spark) - } - - override def afterAll(): Unit = { - SparkSession.clearActiveSession() - super.afterAll() - } - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source - }.head - - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds( - rateSource.clock.getTimeMillis() - rateSource.creationTimeMs)) - (rateSource, offset) - } - } - - test("microbatch in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find read support for rate") - } - } - - test("compatible with old path in registry") { - DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", - spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - assert(ds.isInstanceOf[RateStreamProvider]) - case _ => - throw new IllegalStateException("Could not find read support for rate") - } - } - - test("microbatch - basic") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("microbatch - uniform distribution of event timestamps") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "1500") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - val expectedAnswer = (0 until 1500).map { v => - (math.round(v * (1000.0 / 1500)), v) - } - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch(expectedAnswer: _*) - ) - } - - test("microbatch - set offset") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - - test("microbatch - infer offsets") { - val tempFolder = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions( - Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), - tempFolder) - reader.clock.asInstanceOf[ManualClock].advance(100000) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: LongOffset => assert(r.offset === 0L) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: LongOffset => assert(r.offset >= 100) - case _ => throw new IllegalStateException("unexpected offset type") - } - } - - test("microbatch - predetermined batch size") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 1) - val dataReader = tasks.get(0).createDataReader() - val data = ArrayBuffer[Row]() - while (dataReader.next()) { - data.append(dataReader.get()) - } - assert(data.size === 20) - } - - test("microbatch - data read") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createDataReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } - - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) - } - - test("valueAtSecond") { - import RateStreamProvider._ - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) - assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) - assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) - } - - test("rampUpTime") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("rampUpTime", "4s") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch({ - Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) - }: _*), // speed = 6 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 - AdvanceRateManualClock(seconds = 1), - // Now we should reach full speed - CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 - ) - } - - test("numPartitions") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("numPartitions", "6") - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(1), - CheckLastBatch((0 until 6): _*) - ) - } - - testQuietly("overflow") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", Long.MaxValue.toString) - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(2), - ExpectFailure[TreeNodeException[_]](t => { - Seq("overflow", "rowsPerSecond").foreach { msg => - assert(t.getCause.getMessage.contains(msg)) - } - }) - ) - } - - testQuietly("illegal option values") { - def testIllegalOptionValue( - option: String, - value: String, - expectedMessages: Seq[String]): Unit = { - val e = intercept[IllegalArgumentException] { - spark.readStream - .format("rate") - .option(option, value) - .load() - .writeStream - .format("console") - .start() - .awaitTermination() - } - for (msg <- expectedMessages) { - assert(e.getMessage.contains(msg)) - } - } - - testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) - testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) - } - - test("user-specified schema given") { - val exception = intercept[AnalysisException] { - spark.readStream - .format("rate") - .schema(spark.range(1).schema) - .load() - } - assert(exception.getMessage.contains( - "rate source does not support a user-specified schema")) - } - - test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) - case _ => - throw new IllegalStateException("Could not find read support for continuous rate") - } - } - - test("continuous data") { - val reader = new RateStreamContinuousReader( - new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 2) - - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => - val startTimeMs = reader.getStartOffset() - .asInstanceOf[RateStreamOffset] - .partitionToValueAndRunTimeMs(t.partitionIndex) - .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] - for (rowIndex <- 0 to 9) { - r.next() - data.append(r.get()) - assert(r.getOffset() == - RateStreamPartitionOffset( - t.partitionIndex, - t.partitionIndex + rowIndex * 2, - startTimeMs + (rowIndex + 1) * 100)) - } - assert(System.currentTimeMillis() >= startTimeMs + 1000) - - case _ => throw new IllegalStateException("Unexpected task type") - } - - assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) - } -} From ea2fdc0d286e449884de44f22a908a26ab1248a5 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Wed, 28 Mar 2018 19:49:32 -0500 Subject: [PATCH 532/774] [SPARK-23675][WEB-UI] Title add spark logo, use spark logo image ## What changes were proposed in this pull request? Title add spark logo, use spark logo image. reference other big data system ui, so i think spark should add it. spark fix before: ![spark_fix_before](https://user-images.githubusercontent.com/26266482/37387866-2d5add0e-2799-11e8-9165-250f2b59df3f.png) spark fix after: ![spark_fix_after](https://user-images.githubusercontent.com/26266482/37387874-329e1876-2799-11e8-8bc5-c619fc1e680e.png) reference kafka ui: ![kafka](https://user-images.githubusercontent.com/26266482/37387878-35ca89d0-2799-11e8-834e-1598ae7158e1.png) reference storm ui: ![storm](https://user-images.githubusercontent.com/26266482/37387880-3854f12c-2799-11e8-8968-b428ba361995.png) reference yarn ui: ![yarn](https://user-images.githubusercontent.com/26266482/37387881-3a72e130-2799-11e8-97bb-dea85f573e95.png) reference nifi ui: ![nifi](https://user-images.githubusercontent.com/26266482/37387887-3cecfea0-2799-11e8-9a71-6c454d25840b.png) reference flink ui: ![flink](https://user-images.githubusercontent.com/26266482/37387888-3f16b1ee-2799-11e8-9d37-8355f0100548.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20818 from guoxiaolongzte/SPARK-23675. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index ba798df13c95d..02cf19e00ecde 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -224,6 +224,7 @@ private[spark] object UIUtils extends Logging { {commonHeaderNodes} {if (showVisualization) vizHeaderNodes else Seq.empty} {if (useDataTables) dataTablesHeaderNodes else Seq.empty} + {appName} - {title} @@ -265,6 +266,7 @@ private[spark] object UIUtils extends Logging { {commonHeaderNodes} {if (useDataTables) dataTablesHeaderNodes else Seq.empty} + {title} From 641aec68e8167546dbb922874c086c9b90198f08 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Thu, 29 Mar 2018 16:37:46 +0800 Subject: [PATCH 533/774] =?UTF-8?q?[SPARK-23806]=20Broadcast.unpersist=20c?= =?UTF-8?q?an=20cause=20fatal=20exception=20when=20used=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … with dynamic allocation ## What changes were proposed in this pull request? ignore errors when you are waiting for a broadcast.unpersist. This is handling it the same way as doing rdd.unpersist in https://issues.apache.org/jira/browse/SPARK-22618 ## How was this patch tested? Patch was tested manually against a couple jobs that exhibit this behavior, with the change the application no longer dies due to this and just prints the warning. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Thomas Graves Closes #20924 from tgravescs/SPARK-23806. --- .../spark/storage/BlockManagerMasterEndpoint.scala | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 89a6a71a589a1..56b95c31eb4c3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -192,11 +192,15 @@ class BlockManagerMasterEndpoint( val requiredBlockManagers = blockManagerInfo.values.filter { info => removeFromDriver || !info.blockManagerId.isDriver } - Future.sequence( - requiredBlockManagers.map { bm => - bm.slaveEndpoint.ask[Int](removeMsg) - }.toSeq - ) + val futures = requiredBlockManagers.map { bm => + bm.slaveEndpoint.ask[Int](removeMsg).recover { + case e: IOException => + logWarning(s"Error trying to remove broadcast $broadcastId", e) + 0 // zero blocks were removed + } + }.toSeq + + Future.sequence(futures) } private def removeBlockManager(blockManagerId: BlockManagerId) { From 505480cb578af9f23acc77bc82348afc9d8468e8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 29 Mar 2018 19:38:28 +0900 Subject: [PATCH 534/774] [SPARK-23770][R] Exposes repartitionByRange in SparkR ## What changes were proposed in this pull request? This PR proposes to expose `repartitionByRange`. ```R > df <- createDataFrame(iris) ... > getNumPartitions(repartitionByRange(df, 3, col = df$Species)) [1] 3 ``` ## How was this patch tested? Manually tested and the unit tests were added. The diff with `repartition` can be checked as below: ```R > df <- createDataFrame(mtcars) > take(repartition(df, 10, df$wt), 3) mpg cyl disp hp drat wt qsec vs am gear carb 1 14.3 8 360.0 245 3.21 3.570 15.84 0 0 3 4 2 10.4 8 460.0 215 3.00 5.424 17.82 0 0 3 4 3 32.4 4 78.7 66 4.08 2.200 19.47 1 1 4 1 > take(repartitionByRange(df, 10, df$wt), 3) mpg cyl disp hp drat wt qsec vs am gear carb 1 30.4 4 75.7 52 4.93 1.615 18.52 1 1 4 2 2 33.9 4 71.1 65 4.22 1.835 19.90 1 1 4 1 3 27.3 4 79.0 66 4.08 1.935 18.90 1 1 4 1 ``` Author: hyukjinkwon Closes #20902 from HyukjinKwon/r-repartitionByRange. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 65 ++++++++++++++++++++++++++- R/pkg/R/generics.R | 3 ++ R/pkg/tests/fulltests/test_sparkSQL.R | 45 +++++++++++++++++++ 4 files changed, 112 insertions(+), 2 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c51eb0f39c4b1..190c50ea10482 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -151,6 +151,7 @@ exportMethods("arrange", "registerTempTable", "rename", "repartition", + "repartitionByRange", "rollup", "sample", "sample_frac", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index c4852024c0f49..a1c9495b0795e 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -687,7 +687,7 @@ setMethod("storageLevel", #' @rdname coalesce #' @name coalesce #' @aliases coalesce,SparkDataFrame-method -#' @seealso \link{repartition} +#' @seealso \link{repartition}, \link{repartitionByRange} #' @examples #'\dontrun{ #' sparkR.session() @@ -723,7 +723,7 @@ setMethod("coalesce", #' @rdname repartition #' @name repartition #' @aliases repartition,SparkDataFrame-method -#' @seealso \link{coalesce} +#' @seealso \link{coalesce}, \link{repartitionByRange} #' @examples #'\dontrun{ #' sparkR.session() @@ -759,6 +759,67 @@ setMethod("repartition", dataFrame(sdf) }) + +#' Repartition by range +#' +#' The following options for repartition by range are possible: +#' \itemize{ +#' \item{1.} {Return a new SparkDataFrame range partitioned by +#' the given columns into \code{numPartitions}.} +#' \item{2.} {Return a new SparkDataFrame range partitioned by the given column(s), +#' using \code{spark.sql.shuffle.partitions} as number of partitions.} +#'} +#' +#' @param x a SparkDataFrame. +#' @param numPartitions the number of partitions to use. +#' @param col the column by which the range partitioning will be performed. +#' @param ... additional column(s) to be used in the range partitioning. +#' +#' @family SparkDataFrame functions +#' @rdname repartitionByRange +#' @name repartitionByRange +#' @aliases repartitionByRange,SparkDataFrame-method +#' @seealso \link{repartition}, \link{coalesce} +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' newDF <- repartitionByRange(df, col = df$col1, df$col2) +#' newDF <- repartitionByRange(df, 3L, col = df$col1, df$col2) +#'} +#' @note repartitionByRange since 2.4.0 +setMethod("repartitionByRange", + signature(x = "SparkDataFrame"), + function(x, numPartitions = NULL, col = NULL, ...) { + if (!is.null(numPartitions) && !is.null(col)) { + # number of partitions and columns both are specified + if (is.numeric(numPartitions) && class(col) == "Column") { + cols <- list(col, ...) + jcol <- lapply(cols, function(c) { c@jc }) + sdf <- callJMethod(x@sdf, "repartitionByRange", numToInt(numPartitions), jcol) + } else { + stop(paste("numPartitions and col must be numeric and Column; however, got", + class(numPartitions), "and", class(col))) + } + } else if (!is.null(col)) { + # only columns are specified + if (class(col) == "Column") { + cols <- list(col, ...) + jcol <- lapply(cols, function(c) { c@jc }) + sdf <- callJMethod(x@sdf, "repartitionByRange", jcol) + } else { + stop(paste("col must be Column; however, got", class(col))) + } + } else if (!is.null(numPartitions)) { + # only numPartitions is specified + stop("At least one partition-by column must be specified.") + } else { + stop("Please, specify a column(s) or the number of partitions with a column(s)") + } + dataFrame(sdf) + }) + #' toJSON #' #' Converts a SparkDataFrame into a SparkDataFrame of JSON string. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 6fba4b6c761dd..974beff1a3d76 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -531,6 +531,9 @@ setGeneric("rename", function(x, ...) { standardGeneric("rename") }) #' @rdname repartition setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) +#' @rdname repartitionByRange +setGeneric("repartitionByRange", function(x, ...) { standardGeneric("repartitionByRange") }) + #' @rdname sample setGeneric("sample", function(x, withReplacement = FALSE, fraction, seed) { diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 439191adb23ea..7105469ffc242 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -3104,6 +3104,51 @@ test_that("repartition by columns on DataFrame", { }) }) +test_that("repartitionByRange on a DataFrame", { + # The tasks here launch R workers with shuffles. So, we decrease the number of shuffle + # partitions to reduce the number of the tasks to speed up the test. This is particularly + # slow on Windows because the R workers are unable to be forked. See also SPARK-21693. + conf <- callJMethod(sparkSession, "conf") + shufflepartitionsvalue <- callJMethod(conf, "get", "spark.sql.shuffle.partitions") + callJMethod(conf, "set", "spark.sql.shuffle.partitions", "5") + tryCatch({ + df <- createDataFrame(mtcars) + expect_error(repartitionByRange(df, "haha", df$mpg), + "numPartitions and col must be numeric and Column.*") + expect_error(repartitionByRange(df), + ".*specify a column.*or the number of partitions with a column.*") + expect_error(repartitionByRange(df, col = "haha"), + "col must be Column; however, got.*") + expect_error(repartitionByRange(df, 3), + "At least one partition-by column must be specified.") + + # The order of rows should be different with a normal repartition. + actual <- repartitionByRange(df, 3, df$mpg) + expect_equal(getNumPartitions(actual), 3) + expect_false(identical(collect(actual), collect(repartition(df, 3, df$mpg)))) + + actual <- repartitionByRange(df, col = df$mpg) + expect_false(identical(collect(actual), collect(repartition(df, col = df$mpg)))) + + # They should have same data. + actual <- collect(repartitionByRange(df, 3, df$mpg)) + actual <- actual[order(actual$mpg), ] + expected <- collect(repartition(df, 3, df$mpg)) + expected <- expected[order(expected$mpg), ] + expect_true(all(actual == expected)) + + actual <- collect(repartitionByRange(df, col = df$mpg)) + actual <- actual[order(actual$mpg), ] + expected <- collect(repartition(df, col = df$mpg)) + expected <- expected[order(expected$mpg), ] + expect_true(all(actual == expected)) + }, + finally = { + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.sql.shuffle.partitions", shufflepartitionsvalue) + }) +}) + test_that("coalesce, repartition, numPartitions", { df <- as.DataFrame(cars, numPartitions = 5) expect_equal(getNumPartitions(df), 5) From 491ec114fd3886ebd9fa29a482e3d112fb5a088c Mon Sep 17 00:00:00 2001 From: Sahil Takiar Date: Thu, 29 Mar 2018 10:23:23 -0700 Subject: [PATCH 535/774] [SPARK-23785][LAUNCHER] LauncherBackend doesn't check state of connection before setting state ## What changes were proposed in this pull request? Changed `LauncherBackend` `set` method so that it checks if the connection is open or not before writing to it (uses `isConnected`). ## How was this patch tested? None Author: Sahil Takiar Closes #20893 from sahilTakiar/master. --- .../spark/launcher/LauncherBackend.scala | 6 +++--- .../spark/launcher/LauncherServerSuite.java | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala index aaae33ca4e6f3..1b049b786023a 100644 --- a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala +++ b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala @@ -67,13 +67,13 @@ private[spark] abstract class LauncherBackend { } def setAppId(appId: String): Unit = { - if (connection != null) { + if (connection != null && isConnected) { connection.send(new SetAppId(appId)) } } def setState(state: SparkAppHandle.State): Unit = { - if (connection != null && lastState != state) { + if (connection != null && isConnected && lastState != state) { connection.send(new SetState(state)) lastState = state } @@ -114,10 +114,10 @@ private[spark] abstract class LauncherBackend { override def close(): Unit = { try { + _isConnected = false super.close() } finally { onDisconnected() - _isConnected = false } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index d16337a319be3..5413d3a416545 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -185,6 +185,26 @@ public void testStreamFiltering() throws Exception { } } + @Test + public void testAppHandleDisconnect() throws Exception { + LauncherServer server = LauncherServer.getOrCreateServer(); + ChildProcAppHandle handle = new ChildProcAppHandle(server); + String secret = server.registerHandle(handle); + + TestClient client = null; + try { + Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort()); + client = new TestClient(s); + client.send(new Hello(secret, "1.4.0")); + handle.disconnect(); + waitForError(client, secret); + } finally { + handle.kill(); + close(client); + client.clientThread.join(); + } + } + private void close(Closeable c) { if (c != null) { try { From a7755fd8ce2f022118b9827aaac7d5d59f0f297a Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 29 Mar 2018 10:46:28 -0700 Subject: [PATCH 536/774] [SPARK-23639][SQL] Obtain token before init metastore client in SparkSQL CLI ## What changes were proposed in this pull request? In SparkSQLCLI, SessionState generates before SparkContext instantiating. When we use --proxy-user to impersonate, it's unable to initializing a metastore client to talk to the secured metastore for no kerberos ticket. This PR use real user ugi to obtain token for owner before talking to kerberized metastore. ## How was this patch tested? Manually verified with kerberized hive metasotre / hdfs. Author: Kent Yao Closes #20784 from yaooqinn/SPARK-23639. --- .../deploy/security/HiveDelegationTokenProvider.scala | 8 ++++---- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 9 +++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index ece5ce79c650d..7249eb85ac7c7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -36,7 +36,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.KEYTAB import org.apache.spark.util.Utils -private[security] class HiveDelegationTokenProvider +private[spark] class HiveDelegationTokenProvider extends HadoopDelegationTokenProvider with Logging { override def serviceName: String = "hive" @@ -124,9 +124,9 @@ private[security] class HiveDelegationTokenProvider val currentUser = UserGroupInformation.getCurrentUser() val realUser = Option(currentUser.getRealUser()).getOrElse(currentUser) - // For some reason the Scala-generated anonymous class ends up causing an - // UndeclaredThrowableException, even if you annotate the method with @throws. - try { + // For some reason the Scala-generated anonymous class ends up causing an + // UndeclaredThrowableException, even if you annotate the method with @throws. + try { realUser.doAs(new PrivilegedExceptionAction[T]() { override def run(): T = fn }) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 832a15d09599f..084f8200102ba 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -34,11 +34,13 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.log4j.{Level, Logger} import org.apache.thrift.transport.TSocket import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.security.HiveDelegationTokenProvider import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.HiveUtils @@ -121,6 +123,13 @@ private[hive] object SparkSQLCLIDriver extends Logging { } } + val tokenProvider = new HiveDelegationTokenProvider() + if (tokenProvider.delegationTokensRequired(sparkConf, hadoopConf)) { + val credentials = new Credentials() + tokenProvider.obtainDelegationTokens(hadoopConf, sparkConf, credentials) + UserGroupInformation.getCurrentUser.addCredentials(credentials) + } + SessionState.start(sessionState) // Clean up after we exit From b348901192b231153b58fe5720253168c87963d4 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 29 Mar 2018 21:36:56 -0700 Subject: [PATCH 537/774] [SPARK-23808][SQL] Set default Spark session in test-only spark sessions. ## What changes were proposed in this pull request? Set default Spark session in the TestSparkSession and TestHiveSparkSession constructors. ## How was this patch tested? new unit tests Author: Jose Torres Closes #20926 from jose-torres/test3. --- .../spark/sql/test/TestSQLContext.scala | 2 ++ .../sql/test/TestSparkSessionSuite.scala | 29 +++++++++++++++++++ .../apache/spark/sql/hive/test/TestHive.scala | 4 +++ 3 files changed, 35 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 4286e8a6ca2c8..3038b822beb4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -34,6 +34,8 @@ private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) this(new SparkConf) } + SparkSession.setDefaultSession(this) + @transient override lazy val sessionState: SessionState = { new TestSQLSessionStateBuilder(this, None).build() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala new file mode 100644 index 0000000000000..4019c6888da98 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession + +class TestSparkSessionSuite extends SparkFunSuite { + test("default session is set in constructor") { + val session = new TestSparkSession() + assert(SparkSession.getDefaultSession.contains(session)) + session.stop() + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index fcf2025d34432..814038d4ef7af 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -159,6 +159,10 @@ private[hive] class TestHiveSparkSession( private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => + // TODO(SPARK-23826): TestHiveSparkSession should set default session the same way as + // TestSparkSession, but doing this the same way breaks many tests in the package. We need + // to investigate and find a different strategy. + def this(sc: SparkContext, loadTestTables: Boolean) { this( sc, From df05fb63abe6018ccbe572c34cf65fc3ecbf1166 Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Fri, 30 Mar 2018 14:07:35 +0800 Subject: [PATCH 538/774] [SPARK-23743][SQL] Changed a comparison logic from containing 'slf4j' to starting with 'org.slf4j' ## What changes were proposed in this pull request? isSharedClass returns if some classes can/should be shared or not. It checks if the classes names have some keywords or start with some names. Following the logic, it can occur unintended behaviors when a custom package has `slf4j` inside the package or class name. As I guess, the first intention seems to figure out the class containing `org.slf4j`. It would be better to change the comparison logic to `name.startsWith("org.slf4j")` ## How was this patch tested? This patch should pass all of the current tests and keep all of the current behaviors. In my case, I'm using ProtobufDeserializer to get a table schema from hive tables. Thus some Protobuf packages and names have `slf4j` inside. Without this patch, it cannot be resolved because of ClassCastException from different classloaders. Author: Jongyoul Lee Closes #20860 from jongyoul/SPARK-23743. --- .../apache/spark/sql/hive/client/IsolatedClientLoader.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 12975bc85b971..c2690ec32b9e7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -179,8 +179,9 @@ private[hive] class IsolatedClientLoader( val isHadoopClass = name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.") - name.contains("slf4j") || - name.contains("log4j") || + name.startsWith("org.slf4j") || + name.startsWith("org.apache.log4j") || // log4j1.x + name.startsWith("org.apache.logging.log4j") || // log4j2 name.startsWith("org.apache.spark.") || (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || From b02e76cbffe9e589b7a4e60f91250ca12a4420b2 Mon Sep 17 00:00:00 2001 From: yucai Date: Fri, 30 Mar 2018 15:07:38 +0800 Subject: [PATCH 539/774] [SPARK-23727][SQL] Support for pushing down filters for DateType in parquet ## What changes were proposed in this pull request? This PR supports for pushing down filters for DateType in parquet ## How was this patch tested? Added UT and tested in local. Author: yucai Closes #20851 from yucai/SPARK-23727. --- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++ .../datasources/parquet/ParquetFilters.scala | 33 ++++++++++++ .../parquet/ParquetFilterSuite.scala | 50 +++++++++++++++++-- 3 files changed, 89 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9cb03b5bb6152..13f31a6b2eb93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -353,6 +353,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_DATE_ENABLED = buildConf("spark.sql.parquet.filterPushdown.date") + .doc("If true, enables Parquet filter push-down optimization for Date. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .booleanConf + .createWithDefault(true) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") @@ -1329,6 +1336,8 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) + def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 763841efbd9f3..ccc8306866d68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -17,10 +17,15 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.sql.Date + import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types._ @@ -29,6 +34,10 @@ import org.apache.spark.sql.types._ */ private[parquet] object ParquetFilters { + private def dateToDays(date: Date): SQLDate = { + DateTimeUtils.fromJavaDate(date) + } + private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { case BooleanType => (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) @@ -50,6 +59,10 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -72,6 +85,10 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -91,6 +108,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.lt( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -110,6 +131,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.ltEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -129,6 +154,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.gt( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -148,6 +177,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.gtEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 33801954ebd51..1d3476e747046 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.charset.StandardCharsets +import java.sql.Date import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ @@ -76,8 +77,10 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex expected: Seq[Row]): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df .select(output.map(e => Column(e)): _*) .where(Column(predicate)) @@ -102,7 +105,6 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex maybeFilter.exists(_.getClass === filterClass) } checker(stripSparkFilter(query), expected) - } } } @@ -313,6 +315,48 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - date") { + implicit class StringToDate(s: String) { + def date: Date = Date.valueOf(s) + } + + val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21") + + withParquetDataFrame(data.map(i => Tuple1(i.date))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i.date))) + + checkFilterPredicate('_1 === "2018-03-18".date, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate('_1 <=> "2018-03-18".date, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate('_1 =!= "2018-03-18".date, classOf[NotEq[_]], + Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(i.date))) + + checkFilterPredicate('_1 < "2018-03-19".date, classOf[Lt[_]], "2018-03-18".date) + checkFilterPredicate('_1 > "2018-03-20".date, classOf[Gt[_]], "2018-03-21".date) + checkFilterPredicate('_1 <= "2018-03-18".date, classOf[LtEq[_]], "2018-03-18".date) + checkFilterPredicate('_1 >= "2018-03-21".date, classOf[GtEq[_]], "2018-03-21".date) + + checkFilterPredicate( + Literal("2018-03-18".date) === '_1, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-18".date) <=> '_1, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-19".date) > '_1, classOf[Lt[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-20".date) < '_1, classOf[Gt[_]], "2018-03-21".date) + checkFilterPredicate( + Literal("2018-03-18".date) >= '_1, classOf[LtEq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-21".date) <= '_1, classOf[GtEq[_]], "2018-03-21".date) + + checkFilterPredicate(!('_1 < "2018-03-21".date), classOf[GtEq[_]], "2018-03-21".date) + checkFilterPredicate( + '_1 < "2018-03-19".date || '_1 > "2018-03-20".date, + classOf[Operators.Or], + Seq(Row("2018-03-18".date), Row("2018-03-21".date))) + } + } + test("SPARK-6554: don't push down predicates which reference partition columns") { import testImplicits._ From 5b5a36ed6d2bb0971edfeccddf0f280936d2275f Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 30 Mar 2018 21:54:26 +0800 Subject: [PATCH 540/774] Roll forward "[SPARK-23096][SS] Migrate rate source to V2" ## What changes were proposed in this pull request? Roll forward c68ec4e (#20688). There are two minor test changes required: * An error which used to be TreeNodeException[ArithmeticException] is no longer wrapped and is now just ArithmeticException. * The test framework simply does not set the active Spark session. (Or rather, it doesn't do so early enough - I think it only happens when a query is analyzed.) I've added the required logic to SQLTestUtils. ## How was this patch tested? existing tests Author: Jose Torres Author: jerryshao Closes #20922 from jose-torres/ratefix. --- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../execution/datasources/DataSource.scala | 6 +- .../streaming/RateSourceProvider.scala | 262 ------------------ .../ContinuousRateStreamSource.scala | 25 +- .../sources/RateStreamMicroBatchReader.scala | 222 +++++++++++++++ .../sources/RateStreamProvider.scala | 125 +++++++++ .../sources/RateStreamSourceV2.scala | 187 ------------- .../streaming/RateSourceV2Suite.scala | 191 ------------- .../RateStreamProviderSuite.scala} | 166 ++++++++++- 9 files changed, 524 insertions(+), 663 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/{RateSourceSuite.scala => sources/RateStreamProviderSuite.scala} (50%) diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1fe9c093af99f..1b37905543b4e 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,6 +5,5 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateStreamProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider -org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 31fa89b4570a6..b84ea769808f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider +import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -566,6 +566,7 @@ object DataSource extends Logging { val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName val socket = classOf[TextSocketSourceProvider].getCanonicalName + val rate = classOf[RateStreamProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -587,7 +588,8 @@ object DataSource extends Logging { "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, "com.databricks.spark.csv" -> csv, - "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, + "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala deleted file mode 100644 index 649fbbfa184ec..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.Optional -import java.util.concurrent.TimeUnit - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader -import org.apache.spark.sql.types._ -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * A source that generates increment long values with timestamps. Each generated row has two - * columns: a timestamp column for the generated time and an auto increment long column starting - * with 0L. - * - * This source supports the following options: - * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. - * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed - * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer - * seconds. - * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the - * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may - * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. - */ -class RateSourceProvider extends StreamSourceProvider with DataSourceRegister - with DataSourceV2 with ContinuousReadSupport { - - override def sourceSchema( - sqlContext: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { - if (schema.nonEmpty) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - (shortName(), RateSourceProvider.SCHEMA) - } - - override def createSource( - sqlContext: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - val params = CaseInsensitiveMap(parameters) - - val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) - if (rowsPerSecond <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + - "must be positive") - } - - val rampUpTimeSeconds = - params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) - if (rampUpTimeSeconds < 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + - "must not be negative") - } - - val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( - sqlContext.sparkContext.defaultParallelism) - if (numPartitions <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + - "must be positive") - } - - new RateStreamSource( - sqlContext, - metadataPath, - rowsPerSecond, - rampUpTimeSeconds, - numPartitions, - params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing - ) - } - - override def createContinuousReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { - new RateStreamContinuousReader(options) - } - - override def shortName(): String = "rate" -} - -object RateSourceProvider { - val SCHEMA = - StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) - - val VERSION = 1 -} - -class RateStreamSource( - sqlContext: SQLContext, - metadataPath: String, - rowsPerSecond: Long, - rampUpTimeSeconds: Long, - numPartitions: Int, - useManualClock: Boolean) extends Source with Logging { - - import RateSourceProvider._ - import RateStreamSource._ - - val clock = if (useManualClock) new ManualClock else new SystemClock - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private val startTimeMs = { - val metadataLog = - new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. - assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ - @volatile private var lastTimeMs = startTimeMs - - override def schema: StructType = RateSourceProvider.SCHEMA - - override def getOffset: Option[Offset] = { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) - } - - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) - } - - val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - - val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => - val relative = math.round((v - rangeStart) * relativeMsPerValue) - InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) - } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) - } - - override def stop(): Unit = {} - - override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" -} - -object RateStreamSource { - - /** Calculate the end value we will emit at the time `seconds`. */ - def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { - // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 - // Then speedDeltaPerSecond = 2 - // - // seconds = 0 1 2 3 4 5 6 - // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) - // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 - val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) - if (seconds <= rampUpTimeSeconds) { - // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to - // avoid overflow - if (seconds % 2 == 1) { - (seconds + 1) / 2 * speedDeltaPerSecond * seconds - } else { - seconds / 2 * speedDeltaPerSecond * (seconds + 1) - } - } else { - // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds - val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) - rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 20d90069163a6..2f0de2612c150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} @@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions) val creationTime = System.currentTimeMillis() - val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt + val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { @@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateSourceProvider.SCHEMA + override def readSchema(): StructType = RateStreamProvider.SCHEMA private var offset: Offset = _ override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) + this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) } override def getStartOffset(): Offset = offset @@ -98,6 +98,19 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def commit(end: Offset): Unit = {} override def stop(): Unit = {} + private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset( + Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value + // by the increment that will later be applied. The first row output in each + // partition will have a value equal to the partition index. + (i, + ValueRunTimeMsPair( + (i - numPartitions).toLong, + creationTimeMs)) + }.toMap) + } + } case class RateStreamContinuousDataReaderFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala new file mode 100644 index 0000000000000..6cf8520fc544f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{ManualClock, SystemClock} + +class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReader with Logging { + import RateStreamProvider._ + + private[sources] val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock + } + + private val rowsPerSecond = + options.get(ROWS_PER_SECOND).orElse("1").toLong + + private val rampUpTimeSeconds = + Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) + .map(JavaUtils.timeStringAsSec(_)) + .getOrElse(0L) + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private[sources] val creationTimeMs = { + val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) + require(session.isDefined) + + val metadataLog = + new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + @volatile private var lastTimeMs: Long = creationTimeMs + + private var start: LongOffset = _ + private var end: LongOffset = _ + + override def readSchema(): StructType = SCHEMA + + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { + this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] + this.end = end.orElse { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) + }.asInstanceOf[LongOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } + + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return List.empty.asJava + } + + val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + val numPartitions = { + val activeSession = SparkSession.getActiveSession + require(activeSession.isDefined) + Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) + .map(_.toInt) + .getOrElse(activeSession.get.sparkContext.defaultParallelism) + } + + (0 until numPartitions).map { p => + new RateStreamMicroBatchDataReaderFactory( + p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + : DataReaderFactory[Row] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = {} + + override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" +} + +class RateStreamMicroBatchDataReaderFactory( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReaderFactory[Row] { + + override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( + partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) +} + +class RateStreamMicroBatchDataReader( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReader[Row] { + private var count = 0 + + override def next(): Boolean = { + rangeStart + partitionId + numPartitions * count < rangeEnd + } + + override def get(): Row = { + val currValue = rangeStart + partitionId + numPartitions * count + count += 1 + val relative = math.round((currValue - rangeStart) * relativeMsPerValue) + Row( + DateTimeUtils.toJavaTimestamp( + DateTimeUtils.fromMillis(relative + localStartTimeMs)), + currValue + ) + } + + override def close(): Unit = {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala new file mode 100644 index 0000000000000..6bdd492f0cb35 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util.Optional + +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.types._ + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateStreamProvider extends DataSourceV2 + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + import RateStreamProvider._ + + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + if (options.get(ROWS_PER_SECOND).isPresent) { + val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") + } + } + + if (options.get(RAMP_UP_TIME).isPresent) { + val rampUpTimeSeconds = + JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") + } + } + + if (options.get(NUM_PARTITIONS).isPresent) { + val numPartitions = options.get(NUM_PARTITIONS).get().toInt + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") + } + } + + if (schema.isPresent) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + new RateStreamMicroBatchReader(options, checkpointLocation) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + + override def shortName(): String = "rate" +} + +object RateStreamProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 + + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + val RAMP_UP_TIME = "rampUpTime" + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala deleted file mode 100644 index 4e2459bb05bd6..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.util.Optional - -import scala.collection.JavaConverters._ -import scala.collection.mutable - -import org.json4s.DefaultFormats -import org.json4s.jackson.Serialization - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * This is a temporary register as we build out v2 migration. Microbatch read support should - * be implemented in the same register as v1. - */ -class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { - override def createMicroBatchReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { - new RateStreamMicroBatchReader(options) - } - - override def shortName(): String = "ratev2" -} - -class RateStreamMicroBatchReader(options: DataSourceOptions) - extends MicroBatchReader { - implicit val defaultFormats: DefaultFormats = DefaultFormats - - val clock = { - // The option to use a manual clock is provided only for unit testing purposes. - if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock - else new SystemClock - } - - private val numPartitions = - options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - private val rowsPerSecond = - options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong - - // The interval (in milliseconds) between rows in each partition. - // e.g. if there are 4 global rows per second, and 2 partitions, each partition - // should output rows every (1000 * 2 / 4) = 500 ms. - private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond - - override def readSchema(): StructType = { - StructType( - StructField("timestamp", TimestampType, false) :: - StructField("value", LongType, false) :: Nil) - } - - val creationTimeMs = clock.getTimeMillis() - - private var start: RateStreamOffset = _ - private var end: RateStreamOffset = _ - - override def setOffsetRange( - start: Optional[Offset], - end: Optional[Offset]): Unit = { - this.start = start.orElse( - RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) - .asInstanceOf[RateStreamOffset] - - this.end = end.orElse { - val currentTime = clock.getTimeMillis() - RateStreamOffset( - this.start.partitionToValueAndRunTimeMs.map { - case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - // Calculate the number of rows we should advance in this partition (based on the - // current time), and output a corresponding offset. - val readInterval = currentTime - currentReadTime - val numNewRows = readInterval / msPerPartitionBetweenRows - if (numNewRows <= 0) { - startOffset - } else { - (part, ValueRunTimeMsPair( - currentVal + (numNewRows * numPartitions), - currentReadTime + (numNewRows * msPerPartitionBetweenRows))) - } - } - ) - }.asInstanceOf[RateStreamOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end - } - - override def deserializeOffset(json: String): Offset = { - RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) - } - - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { - val startMap = start.partitionToValueAndRunTimeMs - val endMap = end.partitionToValueAndRunTimeMs - endMap.keys.toSeq.map { part => - val ValueRunTimeMsPair(endVal, _) = endMap(part) - val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) - - val packedRows = mutable.ListBuffer[(Long, Long)]() - var outVal = startVal + numPartitions - var outTimeMs = startTimeMs - while (outVal <= endVal) { - packedRows.append((outTimeMs, outVal)) - outVal += numPartitions - outTimeMs += msPerPartitionBetweenRows - } - - RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] - }.toList.asJava - } - - override def commit(end: Offset): Unit = {} - override def stop(): Unit = {} -} - -case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { - override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) -} - -class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { - private var currentIndex = -1 - - override def next(): Boolean = { - // Return true as long as the new index is in the seq. - currentIndex += 1 - currentIndex < vals.size - } - - override def get(): Row = { - Row( - DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)), - vals(currentIndex)._2) - } - - override def close(): Unit = {} -} - -object RateStreamSourceV2 { - val NUM_PARTITIONS = "numPartitions" - val ROWS_PER_SECOND = "rowsPerSecond" - - private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { - RateStreamOffset( - Range(0, numPartitions).map { i => - // Note that the starting offset is exclusive, so we have to decrement the starting value - // by the increment that will later be applied. The first row output in each - // partition will have a value equal to the partition index. - (i, - ValueRunTimeMsPair( - (i - numPartitions).toLong, - creationTimeMs)) - }.toMap) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala deleted file mode 100644 index 983ba1668f58f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.util.ManualClock - -class RateSourceV2Suite extends StreamTest { - import testImplicits._ - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source - }.head - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - rateSource.setOffsetRange(Optional.empty(), Optional.empty()) - (rateSource, rateSource.getEndOffset()) - } - } - - test("microbatch in registry") { - DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("basic microbatch execution") { - val input = spark.readStream - .format("rateV2") - .option("numPartitions", "1") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input, useV2Sink = true)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("microbatch - numPartitions propagated") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - } - - test("microbatch - set offset") { - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - - test("microbatch - infer offsets") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) - reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: RateStreamOffset => - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: RateStreamOffset => - // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted - // longer than 100ms. It should never be early. - assert(r.partitionToValueAndRunTimeMs(0).value >= 9) - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) - - case _ => throw new IllegalStateException("unexpected offset type") - } - } - - test("microbatch - predetermined batch size") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 1) - assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) - } - - test("microbatch - data read") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) - val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { - case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) - }.toMap) - - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createDataReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } - - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) - } - - test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("continuous data") { - val reader = new RateStreamContinuousReader( - new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 2) - - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => - val startTimeMs = reader.getStartOffset() - .asInstanceOf[RateStreamOffset] - .partitionToValueAndRunTimeMs(t.partitionIndex) - .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] - for (rowIndex <- 0 to 9) { - r.next() - data.append(r.get()) - assert(r.getOffset() == - RateStreamPartitionOffset( - t.partitionIndex, - t.partitionIndex + rowIndex * 2, - startTimeMs + (rowIndex + 1) * 100)) - } - assert(System.currentTimeMillis() >= startTimeMs + 1000) - - case _ => throw new IllegalStateException("Unexpected task type") - } - - assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala similarity index 50% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 03d0f63fa4d7f..ff14ec38e66a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -15,13 +15,24 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources +import java.nio.file.Files +import java.util.Optional import java.util.concurrent.TimeUnit -import org.apache.spark.sql.AnalysisException +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock class RateSourceSuite extends StreamTest { @@ -29,18 +40,40 @@ class RateSourceSuite extends StreamTest { import testImplicits._ case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => - source.asInstanceOf[RateStreamSource] + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - (rateSource, rateSource.getOffset.get) + val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds( + rateSource.clock.getTimeMillis() - rateSource.creationTimeMs)) + (rateSource, offset) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } + } + + test("compatible with old path in registry") { + DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", + spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + assert(ds.isInstanceOf[RateStreamProvider]) + case _ => + throw new IllegalStateException("Could not find read support for rate") } } - test("basic") { + test("microbatch - basic") { val input = spark.readStream .format("rate") .option("rowsPerSecond", "10") @@ -57,7 +90,7 @@ class RateSourceSuite extends StreamTest { ) } - test("uniform distribution of event timestamps") { + test("microbatch - uniform distribution of event timestamps") { val input = spark.readStream .format("rate") .option("rowsPerSecond", "1500") @@ -74,8 +107,74 @@ class RateSourceSuite extends StreamTest { ) } + test("microbatch - set offset") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + + test("microbatch - infer offsets") { + val tempFolder = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions( + Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), + tempFolder) + reader.clock.asInstanceOf[ManualClock].advance(100000) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: LongOffset => assert(r.offset === 0L) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: LongOffset => assert(r.offset >= 100) + case _ => throw new IllegalStateException("unexpected offset type") + } + } + + test("microbatch - predetermined batch size") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 1) + val dataReader = tasks.get(0).createDataReader() + val data = ArrayBuffer[Row]() + while (dataReader.next()) { + data.append(dataReader.get()) + } + assert(data.size === 20) + } + + test("microbatch - data read") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createDataReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } + + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } + test("valueAtSecond") { - import RateStreamSource._ + import RateStreamProvider._ assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) @@ -161,7 +260,7 @@ class RateSourceSuite extends StreamTest { option: String, value: String, expectedMessages: Seq[String]): Unit = { - val e = intercept[StreamingQueryException] { + val e = intercept[IllegalArgumentException] { spark.readStream .format("rate") .option(option, value) @@ -171,9 +270,8 @@ class RateSourceSuite extends StreamTest { .start() .awaitTermination() } - assert(e.getCause.isInstanceOf[IllegalArgumentException]) for (msg <- expectedMessages) { - assert(e.getCause.getMessage.contains(msg)) + assert(e.getMessage.contains(msg)) } } @@ -191,4 +289,46 @@ class RateSourceSuite extends StreamTest { assert(exception.getMessage.contains( "rate source does not support a user-specified schema")) } + + test("continuous in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) + case _ => + throw new IllegalStateException("Could not find read support for continuous rate") + } + } + + test("continuous data") { + val reader = new RateStreamContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 2) + + val data = scala.collection.mutable.ListBuffer[Row]() + tasks.asScala.foreach { + case t: RateStreamContinuousDataReaderFactory => + val startTimeMs = reader.getStartOffset() + .asInstanceOf[RateStreamOffset] + .partitionToValueAndRunTimeMs(t.partitionIndex) + .runTimeMs + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + for (rowIndex <- 0 to 9) { + r.next() + data.append(r.get()) + assert(r.getOffset() == + RateStreamPartitionOffset( + t.partitionIndex, + t.partitionIndex + rowIndex * 2, + startTimeMs + (rowIndex + 1) * 100)) + } + assert(System.currentTimeMillis() >= startTimeMs + 1000) + + case _ => throw new IllegalStateException("Unexpected task type") + } + + assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) + } } From bc8d0931170cfa20a4fb64b3b11a2027ddb0d6e9 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 30 Mar 2018 23:21:07 +0800 Subject: [PATCH 541/774] [SPARK-23500][SQL][FOLLOWUP] Fix complex type simplification rules to apply to entire plan ## What changes were proposed in this pull request? This PR is to improve the test coverage of the original PR https://github.com/apache/spark/pull/20687 ## How was this patch tested? N/A Author: gatorsmile Closes #20911 from gatorsmile/addTests. --- .../optimizer/complexTypesSuite.scala | 176 ++++++++++++------ .../apache/spark/sql/ComplexTypesSuite.scala | 109 +++++++++++ 2 files changed, 233 insertions(+), 52 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index e44a6692ad8e2..21ed987627b3b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -47,10 +47,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { SimplifyExtractValueOps) :: Nil } - val idAtt = ('id).long.notNull - val nullableIdAtt = ('nullable_id).long + private val idAtt = ('id).long.notNull + private val nullableIdAtt = ('nullable_id).long - lazy val relation = LocalRelation(idAtt, nullableIdAtt) + private val relation = LocalRelation(idAtt, nullableIdAtt) + private val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.double, 'e.int) + + private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = { + val optimized = Optimizer.execute(originalQuery.analyze) + assert(optimized.resolved, "optimized plans must be still resolvable") + comparePlans(optimized, correctAnswer.analyze) + } test("explicit get from namedStruct") { val query = relation @@ -58,31 +65,28 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { GetStructField( CreateNamedStruct(Seq("att", 'id )), 0, - None) as "outerAtt").analyze - val expected = relation.select('id as "outerAtt").analyze + None) as "outerAtt") + val expected = relation.select('id as "outerAtt") - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("explicit get from named_struct- expression maintains original deduced alias") { val query = relation .select(GetStructField(CreateNamedStruct(Seq("att", 'id)), 0, None)) - .analyze val expected = relation .select('id as "named_struct(att, id).att") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("collapsed getStructField ontop of namedStruct") { val query = relation .select(CreateNamedStruct(Seq("att", 'id)) as "struct1") .select(GetStructField('struct1, 0, None) as "struct1Att") - .analyze - val expected = relation.select('id as "struct1Att").analyze - comparePlans(Optimizer execute query, expected) + val expected = relation.select('id as "struct1Att") + checkRule(query, expected) } test("collapse multiple CreateNamedStruct/GetStructField pairs") { @@ -94,16 +98,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetStructField('struct1, 0, None) as "struct1Att1", GetStructField('struct1, 1, None) as "struct1Att2") - .analyze val expected = relation. select( 'id as "struct1Att1", ('id * 'id) as "struct1Att2") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("collapsed2 - deduced names") { @@ -115,16 +117,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetStructField('struct1, 0, None), GetStructField('struct1, 1, None)) - .analyze val expected = relation. select( 'id as "struct1.att1", ('id * 'id) as "struct1.att2") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplified array ops") { @@ -151,7 +151,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { 1, false), 1) as "a4") - .analyze val expected = relation .select( @@ -161,8 +160,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { "att2", (('id + 1L) * ('id + 1L)))) as "a2", ('id + 1L) as "a3", ('id + 1L) as "a4") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("SPARK-22570: CreateArray should not create a lot of global variables") { @@ -188,7 +186,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { GetStructField(GetMapValue('m, "r1"), 0, None) as "a2", GetMapValue('m, "r32") as "a3", GetStructField(GetMapValue('m, "r32"), 0, None) as "a4") - .analyze val expected = relation.select( @@ -201,8 +198,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ) ) as "a3", Literal.create(null, LongType) as "a4") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, constant lookup, dynamic keys") { @@ -216,7 +212,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 13L) as "a") - .analyze val expected = relation .select( @@ -225,8 +220,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo(13L, ('id + 1L)), ('id + 2L)), (EqualTo(13L, ('id + 2L)), ('id + 3L)), (Literal(true), 'id))) as "a") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, dynamic lookup, dynamic keys, lookup is equivalent to one of the keys") { @@ -240,7 +234,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), ('id + 3L)) as "a") - .analyze val expected = relation .select( CaseWhen(Seq( @@ -248,8 +241,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo('id + 3L, ('id + 1L)), ('id + 2L)), (EqualTo('id + 3L, ('id + 2L)), ('id + 3L)), (Literal(true), ('id + 4L)))) as "a") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, no positive match") { @@ -263,7 +255,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 'id + 30L) as "a") - .analyze val expected = relation.select( CaseWhen(Seq( (EqualTo('id + 30L, 'id), ('id + 1L)), @@ -271,8 +262,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo('id + 30L, ('id + 2L)), ('id + 3L)), (EqualTo('id + 30L, ('id + 3L)), ('id + 4L)), (EqualTo('id + 30L, ('id + 4L)), ('id + 5L)))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) } test("simplify map ops, constant lookup, mixed keys, eliminated constants") { @@ -287,7 +277,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 13L) as "a") - .analyze val expected = relation .select( @@ -297,9 +286,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 2L), ('id + 3L), ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) } test("simplify map ops, potential dynamic match with null value + an absolute constant match") { @@ -314,7 +302,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 2L ) as "a") - .analyze val expected = relation .select( @@ -327,18 +314,69 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // but it cannot override a potential match with ('id + 2L), // which is exactly what [[Coalesce]] would do in this case. (Literal.TrueLiteral, 'id))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) + } + + test("SPARK-23500: Simplify array ops that are not at the top node") { + val query = LocalRelation('id.long) + .select( + CreateArray(Seq( + CreateNamedStruct(Seq( + "att1", 'id, + "att2", 'id * 'id)), + CreateNamedStruct(Seq( + "att1", 'id + 1, + "att2", ('id + 1) * ('id + 1)) + )) + ) as "arr") + .select( + GetStructField(GetArrayItem('arr, 1), 0, None) as "a1", + GetArrayItem( + GetArrayStructFields('arr, + StructField("att1", LongType, nullable = false), + ordinal = 0, + numFields = 1, + containsNull = false), + ordinal = 1) as "a2") + .orderBy('id.asc) + + val expected = LocalRelation('id.long) + .select( + ('id + 1L) as "a1", + ('id + 1L) as "a2") + .orderBy('id.asc) + checkRule(query, expected) + } + + test("SPARK-23500: Simplify map ops that are not top nodes") { + val query = + LocalRelation('id.long) + .select( + CreateMap(Seq( + "r1", 'id, + "r2", 'id + 1L)) as "m") + .select( + GetMapValue('m, "r1") as "a1", + GetMapValue('m, "r32") as "a2") + .orderBy('id.asc) + .select('a1, 'a2) + + val expected = + LocalRelation('id.long).select( + 'id as "a1", + Literal.create(null, LongType) as "a2") + .orderBy('id.asc) + checkRule(query, expected) } test("SPARK-23500: Simplify complex ops that aren't at the plan root") { val structRel = relation .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo") - .groupBy($"foo")("1").analyze + .groupBy($"foo")("1") val structExpected = relation .select('nullable_id as "foo") - .groupBy($"foo")("1").analyze - comparePlans(Optimizer execute structRel, structExpected) + .groupBy($"foo")("1") + checkRule(structRel, structExpected) // These tests must use nullable attributes from the base relation for the following reason: // in the 'original' plans below, the Aggregate node produced by groupBy() has a @@ -351,17 +389,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // SPARK-23634. val arrayRel = relation .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") - .groupBy($"a1")("1").analyze - val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze - comparePlans(Optimizer execute arrayRel, arrayExpected) + .groupBy($"a1")("1") + val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1") + checkRule(arrayRel, arrayExpected) val mapRel = relation .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1") - .groupBy($"m1")("1").analyze + .groupBy($"m1")("1") val mapExpected = relation .select('nullable_id as "m1") - .groupBy($"m1")("1").analyze - comparePlans(Optimizer execute mapRel, mapExpected) + .groupBy($"m1")("1") + checkRule(mapRel, mapExpected) } test("SPARK-23500: Ensure that aggregation expressions are not simplified") { @@ -369,11 +407,45 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // grouping exprs so aren't tested here. val structAggRel = relation.groupBy( CreateNamedStruct(Seq("att1", 'nullable_id)))( - GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze - comparePlans(Optimizer execute structAggRel, structAggRel) + GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)) + checkRule(structAggRel, structAggRel) val arrayAggRel = relation.groupBy( - CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze - comparePlans(Optimizer execute arrayAggRel, arrayAggRel) + CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)) + checkRule(arrayAggRel, arrayAggRel) + + // This could be done if we had a more complex rule that checks that + // the CreateMap does not come from key. + val originalQuery = relation + .groupBy('id)( + GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a" + ) + checkRule(originalQuery, originalQuery) + } + + test("SPARK-23500: namedStruct and getField in the same Project #1") { + val originalQuery = + testRelation + .select( + namedStruct("col1", 'b, "col2", 'c).as("s1"), 'a, 'b) + .select('s1 getField "col2" as 's1Col2, + namedStruct("col1", 'a, "col2", 'b).as("s2")) + .select('s1Col2, 's2 getField "col2" as 's2Col2) + val correctAnswer = + testRelation + .select('c as 's1Col2, 'b as 's2Col2) + checkRule(originalQuery, correctAnswer) + } + + test("SPARK-23500: namedStruct and getField in the same Project #2") { + val originalQuery = + testRelation + .select( + namedStruct("col1", 'b, "col2", 'c) getField "col2" as 'sCol2, + namedStruct("col1", 'a, "col2", 'c) getField "col1" as 'sCol1) + val correctAnswer = + testRelation + .select('c as 'sCol2, 'a as 'sCol1) + checkRule(originalQuery, correctAnswer) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala new file mode 100644 index 0000000000000..b74fe2f90df23 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.test.SharedSQLContext + +class ComplexTypesSuite extends QueryTest with SharedSQLContext { + + override def beforeAll() { + super.beforeAll() + spark.range(10).selectExpr( + "id + 1 as i1", "id + 2 as i2", "id + 3 as i3", "id + 4 as i4", "id + 5 as i5") + .write.saveAsTable("tab") + } + + override def afterAll() { + try { + spark.sql("DROP TABLE IF EXISTS tab") + } finally { + super.afterAll() + } + } + + def checkNamedStruct(plan: LogicalPlan, expectedCount: Int): Unit = { + var count = 0 + plan.foreach { operator => + operator.transformExpressions { + case c: CreateNamedStruct => + count += 1 + c + } + } + + if (expectedCount != count) { + fail(s"expect $expectedCount CreateNamedStruct but got $count.") + } + } + + test("simple case") { + val df = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4) as col2") + .filter("col2.c > 11").selectExpr("col1.a") + checkAnswer(df, Row(9) :: Row(10) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + } + + test("named_struct is used in the top Project") { + val df = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)") + .selectExpr("col1.a", "col1") + .filter("col1.a > 8") + checkAnswer(df, Row(9, Row(9, 10)) :: Row(10, Row(10, 11)) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 1) + + val df1 = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)") + .sort("col1") + .selectExpr("col1.a") + .filter("col1.a > 8") + checkAnswer(df1, Row(9) :: Row(10) :: Nil) + checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 1) + } + + test("expression in named_struct") { + val df = spark.table("tab") + .selectExpr("i5", "struct(i1 as exp, i2, i3) as cola") + .selectExpr("cola.exp", "cola.i3").filter("cola.i3 > 10") + checkAnswer(df, Row(9, 11) :: Row(10, 12) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + + val df1 = spark.table("tab") + .selectExpr("i5", "struct(i1 + 1 as exp, i2, i3) as cola") + .selectExpr("cola.i3").filter("cola.exp > 10") + checkAnswer(df1, Row(12) :: Nil) + checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 0) + } + + test("nested case") { + val df = spark.table("tab") + .selectExpr("struct(struct(i2, i3) as exp, i4) as cola") + .selectExpr("cola.exp.i2", "cola.i4").filter("cola.exp.i2 > 10") + checkAnswer(df, Row(11, 13) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + + val df1 = spark.table("tab") + .selectExpr("struct(i2, i3) as exp", "i4") + .selectExpr("struct(exp, i4) as cola") + .selectExpr("cola.exp.i2", "cola.i4").filter("cola.i4 > 11") + checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + } +} From ae9172017c361e5c1039bc2ca94048117021974a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 30 Mar 2018 14:09:14 -0700 Subject: [PATCH 542/774] [SPARK-23640][CORE] Fix hadoop config may override spark config ## What changes were proposed in this pull request? It may be get `spark.shuffle.service.port` from https://github.com/apache/spark/blob/9745ec3a61c99be59ef6a9d5eebd445e8af65b7a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala#L459 Therefore, the client configuration `spark.shuffle.service.port` does not working unless the configuration is `spark.hadoop.spark.shuffle.service.port`. - This configuration is not working: ``` bin/spark-sql --master yarn --conf spark.shuffle.service.port=7338 ``` - This configuration works: ``` bin/spark-sql --master yarn --conf spark.hadoop.spark.shuffle.service.port=7338 ``` This PR fix this issue. ## How was this patch tested? It's difficult to carry out unit testing. But I've tested it manually. Author: Yuming Wang Closes #20785 from wangyum/SPARK-23640. --- .../scala/org/apache/spark/util/Utils.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5caedeb526469..d2be93226e2a2 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2302,16 +2302,20 @@ private[spark] object Utils extends Logging { } /** - * Return the value of a config either through the SparkConf or the Hadoop configuration - * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf - * if the key is not set in the Hadoop configuration. + * Return the value of a config either through the SparkConf or the Hadoop configuration. + * We Check whether the key is set in the SparkConf before look at any Hadoop configuration. + * If the key is set in SparkConf, no matter whether it is running on YARN or not, + * gets the value from SparkConf. + * Only when the key is not set in SparkConf and running on YARN, + * gets the value from Hadoop configuration. */ def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = { - val sparkValue = conf.get(key, default) - if (conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn") { - new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(conf)).get(key, sparkValue) + if (conf.contains(key)) { + conf.get(key, default) + } else if (conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn") { + new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(conf)).get(key, default) } else { - sparkValue + default } } From 15298b99ac8944e781328423289586176cf824d7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 30 Mar 2018 16:48:26 -0700 Subject: [PATCH 543/774] [SPARK-23827][SS] StreamingJoinExec should ensure that input data is partitioned into specific number of partitions ## What changes were proposed in this pull request? Currently, the requiredChildDistribution does not specify the partitions. This can cause the weird corner cases where the child's distribution is `SinglePartition` which satisfies the required distribution of `ClusterDistribution(no-num-partition-requirement)`, thus eliminating the shuffle needed to repartition input data into the required number of partitions (i.e. same as state stores). That can lead to "file not found" errors on the state store delta files as the micro-batch-with-no-shuffle will not run certain tasks and therefore not generate the expected state store delta files. This PR adds the required constraint on the number of partitions. ## How was this patch tested? Modified test harness to always check that ANY stateful operator should have a constraint on the number of partitions. As part of that, the existing opt-in checks on child output partitioning were removed, as they are redundant. Author: Tathagata Das Closes #20941 from tdas/SPARK-23827. --- .../streaming/IncrementalExecution.scala | 2 +- .../StreamingSymmetricHashJoinExec.scala | 3 +- .../sql/streaming/DeduplicateSuite.scala | 8 +-- .../FlatMapGroupsWithStateSuite.scala | 5 +- .../sql/streaming/StatefulOperatorTest.scala | 49 ------------------- .../spark/sql/streaming/StreamTest.scala | 19 +++++++ .../streaming/StreamingAggregationSuite.scala | 4 +- 7 files changed, 25 insertions(+), 65 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a10ed5f2df1b5..1a83c884d55bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -62,7 +62,7 @@ class IncrementalExecution( StreamingDeduplicationStrategy :: Nil } - private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) + private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter) .getOrElse(sparkSession.sessionState.conf.numShufflePartitions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index c351f658cb955..fa7c8ee906ecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -167,7 +167,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: + ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala index caf2bab8a5859..0088b64d6195e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -25,9 +25,7 @@ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplic import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ -class DeduplicateSuite extends StateStoreMetricsTest - with BeforeAndAfterAll - with StatefulOperatorTest { +class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -44,8 +42,6 @@ class DeduplicateSuite extends StateStoreMetricsTest AddData(inputData, "a"), CheckLastBatch("a"), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))), AddData(inputData, "a"), CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), @@ -63,8 +59,6 @@ class DeduplicateSuite extends StateStoreMetricsTest AddData(inputData, "a" -> 1), CheckLastBatch("a" -> 1), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))), AddData(inputData, "a" -> 2), // Dropped CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index de2b51678cea6..b1416bff87ee7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -42,8 +42,7 @@ case class RunningCount(count: Long) case class Result(key: Long, count: Int) class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest - with BeforeAndAfterAll - with StatefulOperatorTest { + with BeforeAndAfterAll { import testImplicits._ import GroupStateImpl._ @@ -618,8 +617,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec]( - sq, Seq("value"))), AddData(inputData, "a", "b"), CheckLastBatch(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala deleted file mode 100644 index 45142278993bb..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming - -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.streaming._ - -trait StatefulOperatorTest { - /** - * Check that the output partitioning of a child operator of a Stateful operator satisfies the - * distribution that we expect for our Stateful operator. - */ - protected def checkChildOutputHashPartitioning[T <: StatefulOperator]( - sq: StreamingQuery, - colNames: Seq[String]): Boolean = { - val attr = sq.asInstanceOf[StreamExecution].lastExecution.analyzed.output - val partitions = sq.sparkSession.sessionState.conf.numShufflePartitions - val groupingAttr = attr.filter(a => colNames.contains(a.name)) - checkChildOutputPartitioning(sq, HashPartitioning(groupingAttr, partitions)) - } - - /** - * Check that the output partitioning of a child operator of a Stateful operator satisfies the - * distribution that we expect for our Stateful operator. - */ - protected def checkChildOutputPartitioning[T <: StatefulOperator]( - sq: StreamingQuery, - expectedPartitioning: Partitioning): Boolean = { - val operator = sq.asInstanceOf[StreamExecution].lastExecution - .executedPlan.collect { case p: T => p } - operator.head.children.forall( - _.outputPartitioning.numPartitions == expectedPartitioning.numPartitions) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index e44aef09f1f3c..00741d660dd2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -37,6 +37,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical.AllTuples import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ @@ -444,6 +445,24 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } + val lastExecution = currentStream.lastExecution + if (currentStream.isInstanceOf[MicroBatchExecution] && lastExecution != null) { + // Verify if stateful operators have correct metadata and distribution + // This can often catch hard to debug errors when developing stateful operators + lastExecution.executedPlan.collect { case s: StatefulOperator => s }.foreach { s => + assert(s.stateInfo.map(_.numPartitions).contains(lastExecution.numStateStores)) + s.requiredChildDistribution.foreach { d => + withClue(s"$s specifies incorrect # partitions in requiredChildDistribution $d") { + assert(d.requiredNumPartitions.isDefined) + assert(d.requiredNumPartitions.get >= 1) + if (d != AllTuples) { + assert(d.requiredNumPartitions.get == s.stateInfo.get.numPartitions) + } + } + } + } + } + val (latestBatchData, allData) = sink match { case s: MemorySink => (s.latestBatchData, s.allData) case s: MemorySinkV2 => (s.latestBatchData, s.allData) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 97e065193fd05..1cae8cb8d47f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -44,7 +44,7 @@ object FailureSingleton { } class StreamingAggregationSuite extends StateStoreMetricsTest - with BeforeAndAfterAll with Assertions with StatefulOperatorTest { + with BeforeAndAfterAll with Assertions { override def afterAll(): Unit = { super.afterAll() @@ -281,8 +281,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest AddData(inputData, 0L, 5L, 5L, 10L), AdvanceManualClock(10 * 1000), CheckLastBatch((0L, 1), (5L, 2), (10L, 1)), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))), // advance clock to 20 seconds, should retain keys >= 10 AddData(inputData, 15L, 15L, 20L), From 529f847105fa8d98a5dc4d20955e4870df6bc1c5 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sat, 31 Mar 2018 10:34:01 +0800 Subject: [PATCH 544/774] [SPARK-23040][CORE][FOLLOW-UP] Avoid double wrap result Iterator. ## What changes were proposed in this pull request? Address https://github.com/apache/spark/pull/20449#discussion_r172414393, If `resultIter` is already a `InterruptibleIterator`, don't double wrap it. ## How was this patch tested? Existing tests. Author: Xingbo Jiang Closes #20920 from jiangxb1987/SPARK-23040. --- .../spark/shuffle/BlockStoreShuffleReader.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 85e7e56a04a7d..4103dfb10175e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -111,8 +111,13 @@ private[spark] class BlockStoreShuffleReader[K, C]( case None => aggregatedIter } - // Use another interruptible iterator here to support task cancellation as aggregator or(and) - // sorter may have consumed previous interruptible iterator. - new InterruptibleIterator[Product2[K, C]](context, resultIter) + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) + } } } From 44a9f8e6e82c300dc61ca18515aee16f17f27501 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 2 Apr 2018 09:53:37 -0700 Subject: [PATCH 545/774] [SPARK-15009][PYTHON][FOLLOWUP] Add default param checks for CountVectorizerModel ## What changes were proposed in this pull request? Adding test for default params for `CountVectorizerModel` constructed from vocabulary. This required that the param `maxDF` be added, which was done in SPARK-23615. ## How was this patch tested? Added an explicit test for CountVectorizerModel in DefaultValuesTests. Author: Bryan Cutler Closes #20942 from BryanCutler/pyspark-CountVectorizerModel-default-param-test-SPARK-15009. --- python/pyspark/ml/tests.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6b4376cbf14e8..c2c4861e2aff4 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -2096,6 +2096,11 @@ def test_java_params(self): # NOTE: disable check_params_exist until there is parity with Scala API ParamTests.check_params(self, cls(), check_params_exist=False) + # Additional classes that need explicit construction + from pyspark.ml.feature import CountVectorizerModel + ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), + check_params_exist=False) + def _squared_distance(a, b): if isinstance(a, Vector): From 6151f29f9f589301159482044fc32717f430db6e Mon Sep 17 00:00:00 2001 From: David Vogelbacher Date: Mon, 2 Apr 2018 12:00:37 -0700 Subject: [PATCH 546/774] [SPARK-23825][K8S] Requesting memory + memory overhead for pod memory ## What changes were proposed in this pull request? Kubernetes driver and executor pods should request `memory + memoryOverhead` as their resources instead of just `memory`, see https://issues.apache.org/jira/browse/SPARK-23825 ## How was this patch tested? Existing unit tests were adapted. Author: David Vogelbacher Closes #20943 from dvogelbacher/spark-23825. --- .../k8s/submit/steps/BasicDriverConfigurationStep.scala | 5 +---- .../spark/scheduler/cluster/k8s/ExecutorPodFactory.scala | 5 +---- .../submit/steps/BasicDriverConfigurationStepSuite.scala | 2 +- .../scheduler/cluster/k8s/ExecutorPodFactorySuite.scala | 6 ++++-- 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index 347c4d2d66826..b811db324108c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -93,9 +93,6 @@ private[spark] class BasicDriverConfigurationStep( .withAmount(driverCpuCores) .build() val driverMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${driverMemoryMiB}Mi") - .build() - val driverMemoryLimitQuantity = new QuantityBuilder(false) .withAmount(s"${driverMemoryWithOverheadMiB}Mi") .build() val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => @@ -117,7 +114,7 @@ private[spark] class BasicDriverConfigurationStep( .withNewResources() .addToRequests("cpu", driverCpuQuantity) .addToRequests("memory", driverMemoryQuantity) - .addToLimits("memory", driverMemoryLimitQuantity) + .addToLimits("memory", driverMemoryQuantity) .addToLimits(maybeCpuLimitQuantity.toMap.asJava) .endResources() .addToArgs("driver") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 98cbd5607da00..ac42385459dda 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -108,9 +108,6 @@ private[spark] class ExecutorPodFactory( SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ executorLabels val executorMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${executorMemoryMiB}Mi") - .build() - val executorMemoryLimitQuantity = new QuantityBuilder(false) .withAmount(s"${executorMemoryWithOverhead}Mi") .build() val executorCpuQuantity = new QuantityBuilder(false) @@ -167,7 +164,7 @@ private[spark] class ExecutorPodFactory( .withImagePullPolicy(imagePullPolicy) .withNewResources() .addToRequests("memory", executorMemoryQuantity) - .addToLimits("memory", executorMemoryLimitQuantity) + .addToLimits("memory", executorMemoryQuantity) .addToRequests("cpu", executorCpuQuantity) .endResources() .addAllToEnv(executorEnv.asJava) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index ce068531c7673..e59c6d28a8cc2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -91,7 +91,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { val resourceRequirements = preparedDriverSpec.driverContainer.getResources val requests = resourceRequirements.getRequests.asScala assert(requests("cpu").getAmount === "2") - assert(requests("memory").getAmount === "256Mi") + assert(requests("memory").getAmount === "456Mi") val limits = resourceRequirements.getLimits.asScala assert(limits("memory").getAmount === "456Mi") assert(limits("cpu").getAmount === "4") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 7755b93835047..cee8fe27039c9 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -66,12 +66,14 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef assert(executor.getMetadata.getLabels.size() === 3) assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1") - // There is exactly 1 container with no volume mounts and default memory limits. - // Default memory limit is 1024M + 384M (minimum overhead constant). + // There is exactly 1 container with no volume mounts and default memory limits and requests. + // Default memory limit/request is 1024M + 384M (minimum overhead constant). assert(executor.getSpec.getContainers.size() === 1) assert(executor.getSpec.getContainers.get(0).getImage === executorImage) assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty) assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources + .getRequests.get("memory").getAmount === "1408Mi") assert(executor.getSpec.getContainers.get(0).getResources .getLimits.get("memory").getAmount === "1408Mi") From fe2b7a4568d65a62da6e6eb00fff05f248b4332c Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Mon, 2 Apr 2018 12:20:55 -0700 Subject: [PATCH 547/774] [SPARK-23285][K8S] Add a config property for specifying physical executor cores ## What changes were proposed in this pull request? As mentioned in SPARK-23285, this PR introduces a new configuration property `spark.kubernetes.executor.cores` for specifying the physical CPU cores requested for each executor pod. This is to avoid changing the semantics of `spark.executor.cores` and `spark.task.cpus` and their role in task scheduling, task parallelism, dynamic resource allocation, etc. The new configuration property only determines the physical CPU cores available to an executor. An executor can still run multiple tasks simultaneously by using appropriate values for `spark.executor.cores` and `spark.task.cpus`. ## How was this patch tested? Unit tests. felixcheung srowen jiangxb1987 jerryshao mccheah foxish Author: Yinan Li Author: Yinan Li Closes #20553 from liyinan926/master. --- docs/running-on-kubernetes.md | 15 ++++++++--- .../org/apache/spark/deploy/k8s/Config.scala | 6 +++++ .../cluster/k8s/ExecutorPodFactory.scala | 12 ++++++--- .../cluster/k8s/ExecutorPodFactorySuite.scala | 27 +++++++++++++++++++ 4 files changed, 53 insertions(+), 7 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 975b28de47e20..9c4644947c911 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -549,14 +549,23 @@ specific to Spark on Kubernetes. spark.kubernetes.driver.limit.cores (none) - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. + Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. + + spark.kubernetes.executor.request.cores + (none) + + Specify the cpu request for each executor pod. Values conform to the Kubernetes [convention](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#meaning-of-cpu). + Example values include 0.1, 500m, 1.5, 5, etc., with the definition of cpu units documented in [CPU units](https://kubernetes.io/docs/tasks/configure-pod-container/assign-cpu-resource/#cpu-units). + This is distinct from spark.executor.cores: it is only used and takes precedence over spark.executor.cores for specifying the executor pod cpu request if set. Task + parallelism, e.g., number of tasks an executor can run concurrently is not affected by this. + spark.kubernetes.executor.limit.cores (none) - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. + Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. @@ -593,4 +602,4 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. - \ No newline at end of file + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index da34a7e06238a..405ea476351bb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -91,6 +91,12 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_EXECUTOR_REQUEST_CORES = + ConfigBuilder("spark.kubernetes.executor.request.cores") + .doc("Specify the cpu request for each executor pod") + .stringConf + .createOptional + val KUBERNETES_DRIVER_POD_NAME = ConfigBuilder("spark.kubernetes.driver.pod.name") .doc("Name of the driver pod.") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index ac42385459dda..7143f7a6f0b71 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -83,7 +83,12 @@ private[spark] class ExecutorPodFactory( MEMORY_OVERHEAD_MIN_MIB)) private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB - private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) + private val executorCores = sparkConf.getInt("spark.executor.cores", 1) + private val executorCoresRequest = if (sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { + sparkConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get + } else { + executorCores.toString + } private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) /** @@ -111,7 +116,7 @@ private[spark] class ExecutorPodFactory( .withAmount(s"${executorMemoryWithOverhead}Mi") .build() val executorCpuQuantity = new QuantityBuilder(false) - .withAmount(executorCores.toString) + .withAmount(executorCoresRequest) .build() val executorExtraClasspathEnv = executorExtraClasspath.map { cp => new EnvVarBuilder() @@ -130,8 +135,7 @@ private[spark] class ExecutorPodFactory( }.getOrElse(Seq.empty[EnvVar]) val executorEnv = (Seq( (ENV_DRIVER_URL, driverUrl), - // Executor backend expects integral value for executor cores, so round it up to an int. - (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), + (ENV_EXECUTOR_CORES, executorCores.toString), (ENV_EXECUTOR_MEMORY, executorMemoryString), (ENV_APPLICATION_ID, applicationId), // This is to set the SPARK_CONF_DIR to be /opt/spark/conf diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index cee8fe27039c9..a71a2a1b888bc 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -85,6 +85,33 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkOwnerReferences(executor, driverPodUid) } + test("executor core request specification") { + var factory = new ExecutorPodFactory(baseConf, None) + var executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "1") + + val conf = baseConf.clone() + + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "0.1") + factory = new ExecutorPodFactory(conf, None) + executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "0.1") + + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") + factory = new ExecutorPodFactory(conf, None) + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") + executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "100m") + } + test("executor pod hostnames get truncated to 63 characters") { val conf = baseConf.clone() conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, From a7c19d9c21d59fd0109a7078c80b33d3da03fafd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 2 Apr 2018 21:48:44 +0200 Subject: [PATCH 548/774] [SPARK-23713][SQL] Cleanup UnsafeWriter and BufferHolder classes ## What changes were proposed in this pull request? This PR implemented the following cleanups related to `UnsafeWriter` class: - Remove code duplication between `UnsafeRowWriter` and `UnsafeArrayWriter` - Make `BufferHolder` class internal by delegating its accessor methods to `UnsafeWriter` - Replace `UnsafeRow.setTotalSize(...)` with `UnsafeRowWriter.setTotalSize()` ## How was this patch tested? Tested by existing UTs Author: Kazuaki Ishizaki Closes #20850 from kiszk/SPARK-23713. --- .../sql/kafka010/KafkaContinuousReader.scala | 3 - .../KafkaRecordToUnsafeRowConverter.scala | 11 +- .../expressions/codegen/BufferHolder.java | 32 +-- .../codegen/UnsafeArrayWriter.java | 133 +++--------- .../expressions/codegen/UnsafeRowWriter.java | 189 +++++++----------- .../expressions/codegen/UnsafeWriter.java | 157 ++++++++++++++- .../InterpretedUnsafeProjection.scala | 90 ++++----- .../codegen/GenerateUnsafeProjection.scala | 124 +++++------- .../RowBasedKeyValueBatchSuite.java | 28 +-- .../aggregate/RowBasedHashMapGenerator.scala | 12 +- .../columnar/GenerateColumnAccessor.scala | 9 +- .../datasources/text/TextFileFormat.scala | 11 +- 12 files changed, 391 insertions(+), 408 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index e7e27876088f3..f26c134c2f6e9 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -27,13 +27,10 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String /** * A [[ContinuousReader]] for data from kafka. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala index 1acdd56125741..f35a143e00374 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala @@ -20,18 +20,16 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.clients.consumer.ConsumerRecord import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.unsafe.types.UTF8String /** A simple class for converting Kafka ConsumerRecord to UnsafeRow */ private[kafka010] class KafkaRecordToUnsafeRowConverter { - private val sharedRow = new UnsafeRow(7) - private val bufferHolder = new BufferHolder(sharedRow) - private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + private val rowWriter = new UnsafeRowWriter(7) def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = { - bufferHolder.reset() + rowWriter.reset() if (record.key == null) { rowWriter.setNullAt(0) @@ -46,7 +44,6 @@ private[kafka010] class KafkaRecordToUnsafeRowConverter { 5, DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) rowWriter.write(6, record.timestampType.id) - sharedRow.setTotalSize(bufferHolder.totalSize) - sharedRow + rowWriter.getRow() } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 259976118c12f..537ef244b7e81 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -30,25 +30,21 @@ * this class per writing program, so that the memory segment/data buffer can be reused. Note that * for each incoming record, we should call `reset` of BufferHolder instance before write the record * and reuse the data buffer. - * - * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update - * the size of the result row, after writing a record to the buffer. However, we can skip this step - * if the fields of row are all fixed-length, as the size of result row is also fixed. */ -public class BufferHolder { +final class BufferHolder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; - public byte[] buffer; - public int cursor = Platform.BYTE_ARRAY_OFFSET; + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; private final UnsafeRow row; private final int fixedSize; - public BufferHolder(UnsafeRow row) { + BufferHolder(UnsafeRow row) { this(row, 64); } - public BufferHolder(UnsafeRow row, int initialSize) { + BufferHolder(UnsafeRow row, int initialSize) { int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()); if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) { throw new UnsupportedOperationException( @@ -64,7 +60,7 @@ public BufferHolder(UnsafeRow row, int initialSize) { /** * Grows the buffer by at least neededSize and points the row to the buffer. */ - public void grow(int neededSize) { + void grow(int neededSize) { if (neededSize > ARRAY_MAX - totalSize()) { throw new UnsupportedOperationException( "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " + @@ -86,11 +82,23 @@ public void grow(int neededSize) { } } - public void reset() { + byte[] getBuffer() { + return buffer; + } + + int getCursor() { + return cursor; + } + + void increaseCursor(int val) { + cursor += val; + } + + void reset() { cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; } - public int totalSize() { + int totalSize() { return cursor - Platform.BYTE_ARRAY_OFFSET; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 82cd1b24607e1..a78dd970d23e4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -21,8 +21,6 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes; @@ -32,14 +30,12 @@ */ public final class UnsafeArrayWriter extends UnsafeWriter { - private BufferHolder holder; - - // The offset of the global buffer where we start to write this array. - private int startingOffset; - // The number of elements in this array private int numElements; + // The element size in this array + private int elementSize; + private int headerInBytes; private void assertIndexIsValid(int index) { @@ -47,13 +43,17 @@ private void assertIndexIsValid(int index) { assert index < numElements : "index (" + index + ") should < " + numElements; } - public void initialize(BufferHolder holder, int numElements, int elementSize) { + public UnsafeArrayWriter(UnsafeWriter writer, int elementSize) { + super(writer.getBufferHolder()); + this.elementSize = elementSize; + } + + public void initialize(int numElements) { // We need 8 bytes to store numElements in header this.numElements = numElements; this.headerInBytes = calculateHeaderPortionInBytes(numElements); - this.holder = holder; - this.startingOffset = holder.cursor; + this.startingOffset = cursor(); // Grows the global buffer ahead for header and fixed size data. int fixedPartInBytes = @@ -61,112 +61,92 @@ public void initialize(BufferHolder holder, int numElements, int elementSize) { holder.grow(headerInBytes + fixedPartInBytes); // Write numElements and clear out null bits to header - Platform.putLong(holder.buffer, startingOffset, numElements); + Platform.putLong(getBuffer(), startingOffset, numElements); for (int i = 8; i < headerInBytes; i += 8) { - Platform.putLong(holder.buffer, startingOffset + i, 0L); + Platform.putLong(getBuffer(), startingOffset + i, 0L); } // fill 0 into reminder part of 8-bytes alignment in unsafe array for (int i = elementSize * numElements; i < fixedPartInBytes; i++) { - Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0); + Platform.putByte(getBuffer(), startingOffset + headerInBytes + i, (byte) 0); } - holder.cursor += (headerInBytes + fixedPartInBytes); + increaseCursor(headerInBytes + fixedPartInBytes); } - private void zeroOutPaddingBytes(int numBytes) { - if ((numBytes & 0x07) > 0) { - Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); - } - } - - private long getElementOffset(int ordinal, int elementSize) { + private long getElementOffset(int ordinal) { return startingOffset + headerInBytes + ordinal * elementSize; } - public void setOffsetAndSize(int ordinal, int currentCursor, int size) { - assertIndexIsValid(ordinal); - final long relativeOffset = currentCursor - startingOffset; - final long offsetAndSize = (relativeOffset << 32) | (long)size; - - write(ordinal, offsetAndSize); - } - private void setNullBit(int ordinal) { assertIndexIsValid(ordinal); - BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal); + BitSetMethods.set(getBuffer(), startingOffset + 8, ordinal); } public void setNull1Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0); + writeByte(getElementOffset(ordinal), (byte)0); } public void setNull2Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0); + writeShort(getElementOffset(ordinal), (short)0); } public void setNull4Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0); + writeInt(getElementOffset(ordinal), 0); } public void setNull8Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0); + writeLong(getElementOffset(ordinal), 0); } public void setNull(int ordinal) { setNull8Bytes(ordinal); } public void write(int ordinal, boolean value) { assertIndexIsValid(ordinal); - Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value); + writeBoolean(getElementOffset(ordinal), value); } public void write(int ordinal, byte value) { assertIndexIsValid(ordinal); - Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value); + writeByte(getElementOffset(ordinal), value); } public void write(int ordinal, short value) { assertIndexIsValid(ordinal); - Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value); + writeShort(getElementOffset(ordinal), value); } public void write(int ordinal, int value) { assertIndexIsValid(ordinal); - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value); + writeInt(getElementOffset(ordinal), value); } public void write(int ordinal, long value) { assertIndexIsValid(ordinal); - Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value); + writeLong(getElementOffset(ordinal), value); } public void write(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } assertIndexIsValid(ordinal); - Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value); + writeFloat(getElementOffset(ordinal), value); } public void write(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } assertIndexIsValid(ordinal); - Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value); + writeDouble(getElementOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { // make sure Decimal object has the same scale as DecimalType assertIndexIsValid(ordinal); - if (input.changePrecision(precision, scale)) { + if (input != null && input.changePrecision(precision, scale)) { if (precision <= Decimal.MAX_LONG_DIGITS()) { write(ordinal, input.toUnscaledLong()); } else { @@ -180,65 +160,14 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - setOffsetAndSize(ordinal, holder.cursor, numBytes); + bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); + setOffsetAndSize(ordinal, numBytes); // move the cursor forward with 8-bytes boundary - holder.cursor += roundedSize; + increaseCursor(roundedSize); } } else { setNull(ordinal); } } - - public void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - - setOffsetAndSize(ordinal, holder.cursor, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, byte[] input) { - final int numBytes = input.length; - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - - setOffsetAndSize(ordinal, holder.cursor, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, CalendarInterval input) { - // grow the global buffer before writing data. - holder.grow(16); - - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - - setOffsetAndSize(ordinal, holder.cursor, 16); - - // move the cursor forward. - holder.cursor += 16; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 2620bbcfb87a2..71c49d8ed0177 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -20,10 +20,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; /** * A helper class to write data into global row buffer using `UnsafeRow` format. @@ -31,7 +28,7 @@ * It will remember the offset of row buffer which it starts to write, and move the cursor of row * buffer while writing. If new data(can be the input record if this is the outermost writer, or * nested struct if this is an inner writer) comes, the starting cursor of row buffer may be - * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the + * changed, so we need to call `UnsafeRowWriter.resetRowWriter` before writing, to update the * `startingOffset` and clear out null bits. * * Note that if this is the outermost writer, which means we will always write from the very @@ -40,29 +37,58 @@ */ public final class UnsafeRowWriter extends UnsafeWriter { - private final BufferHolder holder; - // The offset of the global buffer where we start to write this row. - private int startingOffset; + private final UnsafeRow row; + private final int nullBitsSize; private final int fixedSize; - public UnsafeRowWriter(BufferHolder holder, int numFields) { - this.holder = holder; + public UnsafeRowWriter(int numFields) { + this(new UnsafeRow(numFields)); + } + + public UnsafeRowWriter(int numFields, int initialBufferSize) { + this(new UnsafeRow(numFields), initialBufferSize); + } + + public UnsafeRowWriter(UnsafeWriter writer, int numFields) { + this(null, writer.getBufferHolder(), numFields); + } + + private UnsafeRowWriter(UnsafeRow row) { + this(row, new BufferHolder(row), row.numFields()); + } + + private UnsafeRowWriter(UnsafeRow row, int initialBufferSize) { + this(row, new BufferHolder(row, initialBufferSize), row.numFields()); + } + + private UnsafeRowWriter(UnsafeRow row, BufferHolder holder, int numFields) { + super(holder); + this.row = row; this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields); this.fixedSize = nullBitsSize + 8 * numFields; - this.startingOffset = holder.cursor; + this.startingOffset = cursor(); + } + + /** + * Updates total size of the UnsafeRow using the size collected by BufferHolder, and returns + * the UnsafeRow created at a constructor + */ + public UnsafeRow getRow() { + row.setTotalSize(totalSize()); + return row; } /** * Resets the `startingOffset` according to the current cursor of row buffer, and clear out null * bits. This should be called before we write a new nested struct to the row buffer. */ - public void reset() { - this.startingOffset = holder.cursor; + public void resetRowWriter() { + this.startingOffset = cursor(); // grow the global buffer to make sure it has enough space to write fixed-length data. - holder.grow(fixedSize); - holder.cursor += fixedSize; + grow(fixedSize); + increaseCursor(fixedSize); zeroOutNullBytes(); } @@ -72,25 +98,17 @@ public void reset() { */ public void zeroOutNullBytes() { for (int i = 0; i < nullBitsSize; i += 8) { - Platform.putLong(holder.buffer, startingOffset + i, 0L); - } - } - - private void zeroOutPaddingBytes(int numBytes) { - if ((numBytes & 0x07) > 0) { - Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); + Platform.putLong(getBuffer(), startingOffset + i, 0L); } } - public BufferHolder holder() { return holder; } - public boolean isNullAt(int ordinal) { - return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); + return BitSetMethods.isSet(getBuffer(), startingOffset, ordinal); } public void setNullAt(int ordinal) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); - Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L); + BitSetMethods.set(getBuffer(), startingOffset, ordinal); + write(ordinal, 0L); } @Override @@ -117,67 +135,49 @@ public long getFieldOffset(int ordinal) { return startingOffset + nullBitsSize + 8 * ordinal; } - public void setOffsetAndSize(int ordinal, int size) { - setOffsetAndSize(ordinal, holder.cursor, size); - } - - public void setOffsetAndSize(int ordinal, int currentCursor, int size) { - final long relativeOffset = currentCursor - startingOffset; - final long fieldOffset = getFieldOffset(ordinal); - final long offsetAndSize = (relativeOffset << 32) | (long) size; - - Platform.putLong(holder.buffer, fieldOffset, offsetAndSize); - } - public void write(int ordinal, boolean value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putBoolean(holder.buffer, offset, value); + writeLong(offset, 0L); + writeBoolean(offset, value); } public void write(int ordinal, byte value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putByte(holder.buffer, offset, value); + writeLong(offset, 0L); + writeByte(offset, value); } public void write(int ordinal, short value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putShort(holder.buffer, offset, value); + writeLong(offset, 0L); + writeShort(offset, value); } public void write(int ordinal, int value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putInt(holder.buffer, offset, value); + writeLong(offset, 0L); + writeInt(offset, value); } public void write(int ordinal, long value) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), value); + writeLong(getFieldOffset(ordinal), value); } public void write(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putFloat(holder.buffer, offset, value); + writeLong(offset, 0); + writeFloat(offset, value); } public void write(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } - Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value); + writeDouble(getFieldOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { if (precision <= Decimal.MAX_LONG_DIGITS()) { // make sure Decimal object has the same scale as DecimalType - if (input.changePrecision(precision, scale)) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong()); + if (input != null && input.changePrecision(precision, scale)) { + write(ordinal, input.toUnscaledLong()); } else { setNullAt(ordinal); } @@ -185,82 +185,31 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // grow the global buffer before writing data. holder.grow(16); - // zero-out the bytes - Platform.putLong(holder.buffer, holder.cursor, 0L); - Platform.putLong(holder.buffer, holder.cursor + 8, 0L); - // Make sure Decimal object has the same scale as DecimalType. // Note that we may pass in null Decimal object to set null for it. if (input == null || !input.changePrecision(precision, scale)) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); + // zero-out the bytes + Platform.putLong(getBuffer(), cursor(), 0L); + Platform.putLong(getBuffer(), cursor() + 8, 0L); + + BitSetMethods.set(getBuffer(), startingOffset, ordinal); // keep the offset for future update setOffsetAndSize(ordinal, 0); } else { final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - assert bytes.length <= 16; + final int numBytes = bytes.length; + assert numBytes <= 16; + + zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); + bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); setOffsetAndSize(ordinal, bytes.length); } // move the cursor forward. - holder.cursor += 16; + increaseCursor(16); } } - - public void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, byte[] input) { - write(ordinal, input, 0, input.length); - } - - public void write(int ordinal, byte[] input, int offset, int numBytes) { - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset, - holder.buffer, holder.cursor, numBytes); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, CalendarInterval input) { - // grow the global buffer before writing data. - holder.grow(16); - - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - - setOffsetAndSize(ordinal, 16); - - // move the cursor forward. - holder.cursor += 16; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index c94b5c7a367ef..de0eb6dbb76be 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -24,10 +26,73 @@ * Base class for writing Unsafe* structures. */ public abstract class UnsafeWriter { + // Keep internal buffer holder + protected final BufferHolder holder; + + // The offset of the global buffer where we start to write this structure. + protected int startingOffset; + + protected UnsafeWriter(BufferHolder holder) { + this.holder = holder; + } + + /** + * Accessor methods are delegated from BufferHolder class + */ + public final BufferHolder getBufferHolder() { + return holder; + } + + public final byte[] getBuffer() { + return holder.getBuffer(); + } + + public final void reset() { + holder.reset(); + } + + public final int totalSize() { + return holder.totalSize(); + } + + public final void grow(int neededSize) { + holder.grow(neededSize); + } + + public final int cursor() { + return holder.getCursor(); + } + + public final void increaseCursor(int val) { + holder.increaseCursor(val); + } + + public final void setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor) { + setOffsetAndSize(ordinal, previousCursor, cursor() - previousCursor); + } + + protected void setOffsetAndSize(int ordinal, int size) { + setOffsetAndSize(ordinal, cursor(), size); + } + + protected void setOffsetAndSize(int ordinal, int currentCursor, int size) { + final long relativeOffset = currentCursor - startingOffset; + final long offsetAndSize = (relativeOffset << 32) | (long)size; + + write(ordinal, offsetAndSize); + } + + protected final void zeroOutPaddingBytes(int numBytes) { + if ((numBytes & 0x07) > 0) { + Platform.putLong(getBuffer(), cursor() + ((numBytes >> 3) << 3), 0L); + } + } + public abstract void setNull1Bytes(int ordinal); public abstract void setNull2Bytes(int ordinal); public abstract void setNull4Bytes(int ordinal); public abstract void setNull8Bytes(int ordinal); + public abstract void write(int ordinal, boolean value); public abstract void write(int ordinal, byte value); public abstract void write(int ordinal, short value); @@ -36,8 +101,92 @@ public abstract class UnsafeWriter { public abstract void write(int ordinal, float value); public abstract void write(int ordinal, double value); public abstract void write(int ordinal, Decimal input, int precision, int scale); - public abstract void write(int ordinal, UTF8String input); - public abstract void write(int ordinal, byte[] input); - public abstract void write(int ordinal, CalendarInterval input); - public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size); + + public final void write(int ordinal, UTF8String input) { + final int numBytes = input.numBytes(); + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + + // grow the global buffer before writing data. + grow(roundedSize); + + zeroOutPaddingBytes(numBytes); + + // Write the bytes to the variable length portion. + input.writeToMemory(getBuffer(), cursor()); + + setOffsetAndSize(ordinal, numBytes); + + // move the cursor forward. + increaseCursor(roundedSize); + } + + public final void write(int ordinal, byte[] input) { + write(ordinal, input, 0, input.length); + } + + public final void write(int ordinal, byte[] input, int offset, int numBytes) { + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); + + // grow the global buffer before writing data. + grow(roundedSize); + + zeroOutPaddingBytes(numBytes); + + // Write the bytes to the variable length portion. + Platform.copyMemory( + input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), numBytes); + + setOffsetAndSize(ordinal, numBytes); + + // move the cursor forward. + increaseCursor(roundedSize); + } + + public final void write(int ordinal, CalendarInterval input) { + // grow the global buffer before writing data. + grow(16); + + // Write the months and microseconds fields of Interval to the variable length portion. + Platform.putLong(getBuffer(), cursor(), input.months); + Platform.putLong(getBuffer(), cursor() + 8, input.microseconds); + + setOffsetAndSize(ordinal, 16); + + // move the cursor forward. + increaseCursor(16); + } + + protected final void writeBoolean(long offset, boolean value) { + Platform.putBoolean(getBuffer(), offset, value); + } + + protected final void writeByte(long offset, byte value) { + Platform.putByte(getBuffer(), offset, value); + } + + protected final void writeShort(long offset, short value) { + Platform.putShort(getBuffer(), offset, value); + } + + protected final void writeInt(long offset, int value) { + Platform.putInt(getBuffer(), offset, value); + } + + protected final void writeLong(long offset, long value) { + Platform.putLong(getBuffer(), offset, value); + } + + protected final void writeFloat(long offset, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } + Platform.putFloat(getBuffer(), offset, value); + } + + protected final void writeDouble(long offset, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } + Platform.putDouble(getBuffer(), offset, value); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 0da5ece7e47fe..b31466f5c92d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.{UserDefinedType, _} import org.apache.spark.unsafe.Platform @@ -42,17 +42,12 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe /** The row representing the expression results. */ private[this] val intermediate = new GenericInternalRow(values) - /** The row returned by the projection. */ - private[this] val result = new UnsafeRow(numFields) - - /** The buffer which holds the resulting row's backing data. */ - private[this] val holder = new BufferHolder(result, numFields * 32) + /* The row writer for UnsafeRow result */ + private[this] val rowWriter = new UnsafeRowWriter(numFields, numFields * 32) /** The writer that writes the intermediate result to the result row. */ private[this] val writer: InternalRow => Unit = { - val rowWriter = new UnsafeRowWriter(holder, numFields) val baseWriter = generateStructWriter( - holder, rowWriter, expressions.map(e => StructField("", e.dataType, e.nullable))) if (!expressions.exists(_.nullable)) { @@ -83,10 +78,9 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe } // Write the intermediate row to an unsafe row. - holder.reset() + rowWriter.reset() writer(intermediate) - result.setTotalSize(holder.totalSize()) - result + rowWriter.getRow() } } @@ -111,14 +105,13 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * given buffer using the given [[UnsafeRowWriter]]. */ private def generateStructWriter( - bufferHolder: BufferHolder, rowWriter: UnsafeRowWriter, fields: Array[StructField]): InternalRow => Unit = { val numFields = fields.length // Create field writers. val fieldWriters = fields.map { field => - generateFieldWriter(bufferHolder, rowWriter, field.dataType, field.nullable) + generateFieldWriter(rowWriter, field.dataType, field.nullable) } // Create basic writer. row => { @@ -136,7 +129,6 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * or array) to the given buffer using the given [[UnsafeWriter]]. */ private def generateFieldWriter( - bufferHolder: BufferHolder, writer: UnsafeWriter, dt: DataType, nullable: Boolean): (SpecializedGetters, Int) => Unit = { @@ -178,81 +170,79 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { case StructType(fields) => val numFields = fields.length - val rowWriter = new UnsafeRowWriter(bufferHolder, numFields) - val structWriter = generateStructWriter(bufferHolder, rowWriter, fields) + val rowWriter = new UnsafeRowWriter(writer, numFields) + val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - val tmpCursor = bufferHolder.cursor + val previousCursor = writer.cursor() v.getStruct(i, fields.length) match { case row: UnsafeRow => writeUnsafeData( - bufferHolder, + rowWriter, row.getBaseObject, row.getBaseOffset, row.getSizeInBytes) case row => // Nested struct. We don't know where this will start because a row can be // variable length, so we need to update the offsets and zero out the bit mask. - rowWriter.reset() + rowWriter.resetRowWriter() structWriter.apply(row) } - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case ArrayType(elementType, containsNull) => - val arrayWriter = new UnsafeArrayWriter - val elementSize = getElementSize(elementType) + val arrayWriter = new UnsafeArrayWriter(writer, getElementSize(elementType)) val elementWriter = generateFieldWriter( - bufferHolder, arrayWriter, elementType, containsNull) (v, i) => { - val tmpCursor = bufferHolder.cursor - writeArray(bufferHolder, arrayWriter, elementWriter, v.getArray(i), elementSize) - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + val previousCursor = writer.cursor() + writeArray(arrayWriter, elementWriter, v.getArray(i)) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case MapType(keyType, valueType, valueContainsNull) => - val keyArrayWriter = new UnsafeArrayWriter - val keySize = getElementSize(keyType) + val keyArrayWriter = new UnsafeArrayWriter(writer, getElementSize(keyType)) val keyWriter = generateFieldWriter( - bufferHolder, keyArrayWriter, keyType, nullable = false) - val valueArrayWriter = new UnsafeArrayWriter - val valueSize = getElementSize(valueType) + val valueArrayWriter = new UnsafeArrayWriter(writer, getElementSize(valueType)) val valueWriter = generateFieldWriter( - bufferHolder, valueArrayWriter, valueType, valueContainsNull) (v, i) => { - val tmpCursor = bufferHolder.cursor + val previousCursor = writer.cursor() v.getMap(i) match { case map: UnsafeMapData => writeUnsafeData( - bufferHolder, + valueArrayWriter, map.getBaseObject, map.getBaseOffset, map.getSizeInBytes) case map => // preserve 8 bytes to write the key array numBytes later. - bufferHolder.grow(8) - bufferHolder.cursor += 8 + valueArrayWriter.grow(8) + valueArrayWriter.increaseCursor(8) // Write the keys and write the numBytes of key array into the first 8 bytes. - writeArray(bufferHolder, keyArrayWriter, keyWriter, map.keyArray(), keySize) - Platform.putLong(bufferHolder.buffer, tmpCursor, bufferHolder.cursor - tmpCursor - 8) + writeArray(keyArrayWriter, keyWriter, map.keyArray()) + Platform.putLong( + valueArrayWriter.getBuffer, + previousCursor, + valueArrayWriter.cursor - previousCursor - 8 + ) // Write the values. - writeArray(bufferHolder, valueArrayWriter, valueWriter, map.valueArray(), valueSize) + writeArray(valueArrayWriter, valueWriter, map.valueArray()) } - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case udt: UserDefinedType[_] => - generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable) + generateFieldWriter(writer, udt.sqlType, nullable) case NullType => (_, _) => {} @@ -324,20 +314,18 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * copy. */ private def writeArray( - bufferHolder: BufferHolder, arrayWriter: UnsafeArrayWriter, elementWriter: (SpecializedGetters, Int) => Unit, - array: ArrayData, - elementSize: Int): Unit = array match { + array: ArrayData): Unit = array match { case unsafe: UnsafeArrayData => writeUnsafeData( - bufferHolder, + arrayWriter, unsafe.getBaseObject, unsafe.getBaseOffset, unsafe.getSizeInBytes) case _ => val numElements = array.numElements() - arrayWriter.initialize(bufferHolder, numElements, elementSize) + arrayWriter.initialize(numElements) var i = 0 while (i < numElements) { elementWriter.apply(array, i) @@ -350,17 +338,17 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. */ private def writeUnsafeData( - bufferHolder: BufferHolder, + writer: UnsafeWriter, baseObject: AnyRef, baseOffset: Long, sizeInBytes: Int) : Unit = { - bufferHolder.grow(sizeInBytes) + writer.grow(sizeInBytes) Platform.copyMemory( baseObject, baseOffset, - bufferHolder.buffer, - bufferHolder.cursor, + writer.getBuffer, + writer.cursor, sizeInBytes) - bufferHolder.cursor += sizeInBytes + writer.increaseCursor(sizeInBytes) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 6682ba55b18b1..ab2254cd9f70a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -48,19 +48,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, fieldTypes: Seq[DataType], - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) } + val rowWriterClass = classOf[UnsafeRowWriter].getName + val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") + s""" final InternalRow $tmpInput = $input; if ($tmpInput instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", structRowWriter)} } else { - ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, bufferHolder)} + ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} } """ } @@ -70,12 +74,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro row: String, inputs: Seq[ExprCode], inputTypes: Seq[DataType], - bufferHolder: String, + rowWriter: String, isTopLevel: Boolean = false): String = { - val rowWriterClass = classOf[UnsafeRowWriter].getName - val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", - v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});") - val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, // which means its fixed-size region always in the same position, so we don't need to call @@ -88,7 +88,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.zeroOutNullBytes();" } } else { - s"$rowWriter.reset();" + s"$rowWriter.resetRowWriter();" } val writeFields = inputs.zip(inputTypes).zipWithIndex.map { @@ -97,7 +97,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case udt: UserDefinedType[_] => udt.sqlType case other => other } - val tmpCursor = ctx.freshName("tmpCursor") val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => @@ -105,33 +104,34 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" case _ => s"$rowWriter.setNullAt($index);" } + val previousCursor = ctx.freshName("previousCursor") val writeField = dt match { case t: StructType => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case a @ ArrayType(et, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeArrayToBuffer(ctx, input.value, et, rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case m @ MapType(kt, vt, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case t: DecimalType => @@ -181,12 +181,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, elementType: DataType, - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") - val arrayWriterClass = classOf[UnsafeArrayWriter].getName - val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", - v => s"$v = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") @@ -203,28 +200,32 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => 8 // we need 8 bytes to store offset and length } - val tmpCursor = ctx.freshName("tmpCursor") + val arrayWriterClass = classOf[UnsafeArrayWriter].getName + val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", + v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") + val previousCursor = ctx.freshName("previousCursor") + val element = CodeGenerator.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case a @ ArrayType(et, _) => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, element, et, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeArrayToBuffer(ctx, element, et, arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case m @ MapType(kt, vt, _) => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case t: DecimalType => @@ -240,10 +241,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" final ArrayData $tmpInput = $input; if ($tmpInput instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)} } else { final int $numElements = $tmpInput.numElements(); - $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); + $arrayWriter.initialize($numElements); for (int $index = 0; $index < $numElements; $index++) { if ($tmpInput.isNullAt($index)) { @@ -262,7 +263,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro input: String, keyType: DataType, valueType: DataType, - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val tmpCursor = ctx.freshName("tmpCursor") @@ -271,20 +272,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" final MapData $tmpInput = $input; if ($tmpInput instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", rowWriter)} } else { // preserve 8 bytes to write the key array numBytes later. - $bufferHolder.grow(8); - $bufferHolder.cursor += 8; + $rowWriter.grow(8); + $rowWriter.increaseCursor(8); // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $bufferHolder.cursor; + final int $tmpCursor = $rowWriter.cursor(); - ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} // Write the numBytes of key array into the first 8 bytes. - Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor); + Platform.putLong($rowWriter.getBuffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor); - ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} } """ } @@ -293,14 +294,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro * If the input is already in unsafe format, we don't need to go through all elements/fields, * we can directly write it. */ - private def writeUnsafeData(ctx: CodegenContext, input: String, bufferHolder: String) = { + private def writeUnsafeData(ctx: CodegenContext, input: String, rowWriter: String) = { val sizeInBytes = ctx.freshName("sizeInBytes") s""" final int $sizeInBytes = $input.getSizeInBytes(); // grow the global buffer before writing data. - $bufferHolder.grow($sizeInBytes); - $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor); - $bufferHolder.cursor += $sizeInBytes; + $rowWriter.grow($sizeInBytes); + $input.writeToMemory($rowWriter.getBuffer(), $rowWriter.cursor()); + $rowWriter.increaseCursor($sizeInBytes); """ } @@ -317,38 +318,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => true } - val result = ctx.addMutableState("UnsafeRow", "result", - v => s"$v = new UnsafeRow(${expressions.length});") - - val holderClass = classOf[BufferHolder].getName - val holder = ctx.addMutableState(holderClass, "holder", - v => s"$v = new $holderClass($result, ${numVarLenFields * 32});") - - val resetBufferHolder = if (numVarLenFields == 0) { - "" - } else { - s"$holder.reset();" - } - val updateRowSize = if (numVarLenFields == 0) { - "" - } else { - s"$result.setTotalSize($holder.totalSize());" - } + val rowWriterClass = classOf[UnsafeRowWriter].getName + val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});") // Evaluate all the subexpression. val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val writeExpressions = - writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) + val writeExpressions = writeExpressionsToBuffer( + ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = s""" - $resetBufferHolder + $rowWriter.reset(); $evalSubexpr $writeExpressions - $updateRowSize """ - ExprCode(code, "false", result) + ExprCode(code, "false", s"$rowWriter.getRow()") } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java index fb3dbe8ed1996..2da87113c6229 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -27,7 +27,6 @@ import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.unsafe.types.UTF8String; @@ -55,36 +54,27 @@ private String getRandomString(int length) { } private UnsafeRow makeKeyRow(long k1, String k2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 32); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, k1); writer.write(1, UTF8String.fromString(k2)); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow makeKeyRow(long k1, long k2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 0); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, k1); writer.write(1, k2); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow makeValueRow(long v1, long v2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 0); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, v1); writer.write(1, v2); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 8617be88f3570..d5508275c48c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -165,18 +165,14 @@ class RowBasedHashMapGenerator( | if (buckets[idx] == -1) { | if (numRows < capacity && !isBatchFull) { | // creating the unsafe for new entry - | UnsafeRow agg_result = new UnsafeRow(${groupingKeySchema.length}); - | org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder - | = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, - | ${numVarLenFields * 32}); | org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter | = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( - | agg_holder, - | ${groupingKeySchema.length}); - | agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed + | ${groupingKeySchema.length}, ${numVarLenFields * 32}); + | agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed | agg_rowWriter.zeroOutNullBytes(); | ${createUnsafeRowForKey}; - | agg_result.setTotalSize(agg_holder.totalSize()); + | org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result + | = agg_rowWriter.getRow(); | Object kbase = agg_result.getBaseObject(); | long koff = agg_result.getBaseOffset(); | int klen = agg_result.getSizeInBytes(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 3b5655ba0582e..2d699e8a9d088 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -165,9 +165,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; - private UnsafeRow unsafeRow = new UnsafeRow($numFields); - private BufferHolder bufferHolder = new BufferHolder(unsafeRow); - private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter($numFields); private MutableUnsafeRow mutableRow = null; private int currentRow = 0; @@ -212,11 +210,10 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public InternalRow next() { currentRow += 1; - bufferHolder.reset(); + rowWriter.reset(); rowWriter.zeroOutNullBytes(); ${extractorCalls} - unsafeRow.setTotalSize(bufferHolder.totalSize()); - return unsafeRow; + return rowWriter.getRow(); } ${ctx.declareAddedFunctions()} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 9647f09867643..e93908da43535 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -26,7 +26,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -130,16 +130,13 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { val emptyUnsafeRow = new UnsafeRow(0) reader.map(_ => emptyUnsafeRow) } else { - val unsafeRow = new UnsafeRow(1) - val bufferHolder = new BufferHolder(unsafeRow) - val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + val unsafeRowWriter = new UnsafeRowWriter(1) reader.map { line => // Writes to an UnsafeRow directly - bufferHolder.reset() + unsafeRowWriter.reset() unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.setTotalSize(bufferHolder.totalSize()) - unsafeRow + unsafeRowWriter.getRow() } } } From 28ea4e3142b88eb396aa8dd5daf7b02b556204ba Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 2 Apr 2018 14:35:07 -0700 Subject: [PATCH 549/774] [SPARK-23834][TEST] Wait for connection before disconnect in LauncherServer test. It was possible that the disconnect() was called on the handle before the server had received the handshake messages, so no connection was yet attached to the handle. The fix waits until we're sure the handle has been mapped to a client connection. Author: Marcelo Vanzin Closes #20950 from vanzin/SPARK-23834. --- .../org/apache/spark/launcher/LauncherServerSuite.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 5413d3a416545..f8dc0ec7a0bf6 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -196,6 +196,14 @@ public void testAppHandleDisconnect() throws Exception { Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort()); client = new TestClient(s); client.send(new Hello(secret, "1.4.0")); + client.send(new SetAppId("someId")); + + // Wait until we know the server has received the messages and matched the handle to the + // connection before disconnecting. + eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { + assertEquals("someId", handle.getAppId()); + }); + handle.disconnect(); waitForError(client, secret); } finally { From a1351828d376a01e5ee0959cf608f767d756dd86 Mon Sep 17 00:00:00 2001 From: Yogesh Garg Date: Mon, 2 Apr 2018 16:41:26 -0700 Subject: [PATCH 550/774] [SPARK-23690][ML] Add handleinvalid to VectorAssembler ## What changes were proposed in this pull request? Introduce `handleInvalid` parameter in `VectorAssembler` that can take in `"keep", "skip", "error"` options. "error" throws an error on seeing a row containing a `null`, "skip" filters out all such rows, and "keep" adds relevant number of NaN. "keep" figures out an example to find out what this number of NaN s should be added and throws an error when no such number could be found. ## How was this patch tested? Unit tests are added to check the behavior of `assemble` on specific rows and the transformer is called on `DataFrame`s of different configurations to test different corner cases. Author: Yogesh Garg Author: Bago Amirbekian Author: Yogesh Garg <1059168+yogeshg@users.noreply.github.com> Closes #20829 from yogeshg/rformula_handleinvalid. --- .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/feature/VectorAssembler.scala | 198 ++++++++++++++---- .../ml/feature/VectorAssemblerSuite.scala | 131 +++++++++++- 3 files changed, 284 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 1cdcdfcaeab78..67cdb097217a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -234,7 +234,7 @@ class StringIndexerModel ( val metadata = NominalAttribute.defaultAttr .withName($(outputCol)).withValues(filteredLabels).toMetadata() // If we are skipping invalid records, filter them out. - val (filteredDataset, keepInvalid) = getHandleInvalid match { + val (filteredDataset, keepInvalid) = $(handleInvalid) match { case StringIndexer.SKIP_INVALID => val filterer = udf { label: String => labelToIndex.contains(label) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index b373ae921ed38..6bf4aa38b1fcb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -17,14 +17,17 @@ package org.apache.spark.ml.feature -import scala.collection.mutable.ArrayBuilder +import java.util.NoSuchElementException + +import scala.collection.mutable +import scala.language.existentials import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -33,10 +36,14 @@ import org.apache.spark.sql.types._ /** * A feature transformer that merges multiple columns into a vector column. + * + * This requires one pass over the entire dataset. In case we need to infer column lengths from the + * data we require an additional call to the 'first' Dataset method, see 'handleInvalid' parameter. */ @Since("1.4.0") class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { + extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid + with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("vecAssembler")) @@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.4.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + /** + * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with + * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the + * output). Column lengths are taken from the size of ML Attribute Group, which can be set using + * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred + * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. + * Default: "error" + * @group param + */ + @Since("2.4.0") + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with + |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the + |output). Column lengths are taken from the size of ML Attribute Group, which can be set using + |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred + |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. + |""".stripMargin.replaceAll("\n", " "), + ParamValidators.inArray(VectorAssembler.supportedHandleInvalids)) + + setDefault(handleInvalid, VectorAssembler.ERROR_INVALID) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // Schema transformation. val schema = dataset.schema - lazy val first = dataset.toDF.first() - val attrs = $(inputCols).flatMap { c => + + val vectorCols = $(inputCols).filter { c => + schema(c).dataType match { + case _: VectorUDT => true + case _ => false + } + } + val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid)) + + val featureAttributesMap = $(inputCols).map { c => val field = schema(c) - val index = schema.fieldIndex(c) field.dataType match { case DoubleType => - val attr = Attribute.fromStructField(field) - // If the input column doesn't have ML attribute, assume numeric. - if (attr == UnresolvedAttribute) { - Some(NumericAttribute.defaultAttr.withName(c)) - } else { - Some(attr.withName(c)) + val attribute = Attribute.fromStructField(field) + attribute match { + case UnresolvedAttribute => + Seq(NumericAttribute.defaultAttr.withName(c)) + case _ => + Seq(attribute.withName(c)) } case _: NumericType | BooleanType => // If the input column type is a compatible scalar type, assume numeric. - Some(NumericAttribute.defaultAttr.withName(c)) + Seq(NumericAttribute.defaultAttr.withName(c)) case _: VectorUDT => - val group = AttributeGroup.fromStructField(field) - if (group.attributes.isDefined) { - // If attributes are defined, copy them with updated names. - group.attributes.get.zipWithIndex.map { case (attr, i) => + val attributeGroup = AttributeGroup.fromStructField(field) + if (attributeGroup.attributes.isDefined) { + attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) => if (attr.name.isDefined) { // TODO: Define a rigorous naming scheme. attr.withName(c + "_" + attr.name.get) @@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) } else { // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes // from metadata, check the first row. - val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) - Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i)) + (0 until vectorColsLengths(c)).map { i => + NumericAttribute.defaultAttr.withName(c + "_" + i) + } } case otherType => throw new SparkException(s"VectorAssembler does not support the $otherType type") } } - val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() - + val featureAttributes = featureAttributesMap.flatten[Attribute].toArray + val lengths = featureAttributesMap.map(a => a.length).toArray + val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata() + val (filteredDataset, keepInvalid) = $(handleInvalid) match { + case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false) + case VectorAssembler.KEEP_INVALID => (dataset, true) + case VectorAssembler.ERROR_INVALID => (dataset, false) + } // Data transformation. val assembleFunc = udf { r: Row => - VectorAssembler.assemble(r.toSeq: _*) + VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*) }.asNondeterministic() val args = $(inputCols).map { c => schema(c).dataType match { @@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) } } - dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) + filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) } @Since("1.4.0") @@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.6.0") object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + + /** + * Infers lengths of vector columns from the first row of the dataset + * @param dataset the dataset + * @param columns name of vector columns whose lengths need to be inferred + * @return map of column names to lengths + */ + private[feature] def getVectorLengthsFromFirstRow( + dataset: Dataset[_], + columns: Seq[String]): Map[String, Int] = { + try { + val first_row = dataset.toDF().select(columns.map(col): _*).first() + columns.zip(first_row.toSeq).map { + case (c, x) => c -> x.asInstanceOf[Vector].size + }.toMap + } catch { + case e: NullPointerException => throw new NullPointerException( + s"""Encountered null value while inferring lengths from the first row. Consider using + |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """ + .stripMargin.replaceAll("\n", " ") + e.toString) + case e: NoSuchElementException => throw new NoSuchElementException( + s"""Encountered empty dataframe while inferring lengths from the first row. Consider using + |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """ + .stripMargin.replaceAll("\n", " ") + e.toString) + } + } + + private[feature] def getLengths( + dataset: Dataset[_], + columns: Seq[String], + handleInvalid: String): Map[String, Int] = { + val groupSizes = columns.map { c => + c -> AttributeGroup.fromStructField(dataset.schema(c)).size + }.toMap + val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq + val firstSizes = (missingColumns.nonEmpty, handleInvalid) match { + case (true, VectorAssembler.ERROR_INVALID) => + getVectorLengthsFromFirstRow(dataset, missingColumns) + case (true, VectorAssembler.SKIP_INVALID) => + getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns) + case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException( + s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint + |to add metadata for columns: ${columns.mkString("[", ", ", "]")}.""" + .stripMargin.replaceAll("\n", " ")) + case (_, _) => Map.empty + } + groupSizes ++ firstSizes + } + + @Since("1.6.0") override def load(path: String): VectorAssembler = super.load(path) - private[feature] def assemble(vv: Any*): Vector = { - val indices = ArrayBuilder.make[Int] - val values = ArrayBuilder.make[Double] - var cur = 0 + /** + * Returns a function that has the required information to assemble each row. + * @param lengths an array of lengths of input columns, whose size should be equal to the number + * of cells in the row (vv) + * @param keepInvalid indicate whether to throw an error or not on seeing a null in the rows + * @return a udf that can be applied on each row + */ + private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = { + val indices = mutable.ArrayBuilder.make[Int] + val values = mutable.ArrayBuilder.make[Double] + var featureIndex = 0 + + var inputColumnIndex = 0 vv.foreach { case v: Double => - if (v != 0.0) { - indices += cur + if (v.isNaN && !keepInvalid) { + throw new SparkException( + s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider + |removing NaNs from dataset or using handleInvalid = "keep" or "skip".""" + .stripMargin) + } else if (v != 0.0) { + indices += featureIndex values += v } - cur += 1 + inputColumnIndex += 1 + featureIndex += 1 case vec: Vector => vec.foreachActive { case (i, v) => if (v != 0.0) { - indices += cur + i + indices += featureIndex + i values += v } } - cur += vec.size + inputColumnIndex += 1 + featureIndex += vec.size case null => - // TODO: output Double.NaN? - throw new SparkException("Values to assemble cannot be null.") + if (keepInvalid) { + val length: Int = lengths(inputColumnIndex) + Array.range(0, length).foreach { i => + indices += featureIndex + i + values += Double.NaN + } + inputColumnIndex += 1 + featureIndex += length + } else { + throw new SparkException( + s"""Encountered null while assembling a row with handleInvalid = "keep". Consider + |removing nulls from dataset or using handleInvalid = "keep" or "skip".""" + .stripMargin) + } case o => throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") } - Vectors.sparse(cur, indices.result(), values.result()).compressed + Vectors.sparse(featureIndex, indices.result(), values.result()).compressed } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index eca065f7e775d..91fb24a268b8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, udf} class VectorAssemblerSuite @@ -31,30 +31,49 @@ class VectorAssemblerSuite import testImplicits._ + @transient var dfWithNullsAndNaNs: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val sv = Vectors.sparse(2, Array(1), Array(3.0)) + dfWithNullsAndNaNs = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long, String)]( + (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", sv, 7L, null), + (2, 1, 0.0, null, "a", sv, 6L, null), + (3, 3, null, Vectors.dense(1.0, 2.0), "a", sv, 8L, null), + (4, 4, null, null, "a", sv, 9L, null), + (5, 5, java.lang.Double.NaN, Vectors.dense(1.0, 2.0), "a", sv, 7L, null), + (6, 6, java.lang.Double.NaN, null, "a", sv, 8L, null)) + .toDF("id1", "id2", "x", "y", "name", "z", "n", "nulls") + } + test("params") { ParamsSuite.checkParams(new VectorAssembler) } test("assemble") { import org.apache.spark.ml.feature.VectorAssembler.assemble - assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty)) - assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0))) + assert(assemble(Array(1), keepInvalid = true)(0.0) + === Vectors.sparse(1, Array.empty, Array.empty)) + assert(assemble(Array(1, 1), keepInvalid = true)(0.0, 1.0) + === Vectors.sparse(2, Array(1), Array(1.0))) val dv = Vectors.dense(2.0, 0.0) - assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) + assert(assemble(Array(1, 2, 1), keepInvalid = true)(0.0, dv, 1.0) === + Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0)) - assert(assemble(0.0, dv, 1.0, sv) === + assert(assemble(Array(1, 2, 1, 2), keepInvalid = true)(0.0, dv, 1.0, sv) === Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0))) - for (v <- Seq(1, "a", null)) { - intercept[SparkException](assemble(v)) - intercept[SparkException](assemble(1.0, v)) + for (v <- Seq(1, "a")) { + intercept[SparkException](assemble(Array(1), keepInvalid = true)(v)) + intercept[SparkException](assemble(Array(1, 1), keepInvalid = true)(1.0, v)) } } test("assemble should compress vectors") { import org.apache.spark.ml.feature.VectorAssembler.assemble - val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0)) + val v1 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(0.0, 0.0, 0.0, Vectors.dense(4.0)) assert(v1.isInstanceOf[SparseVector]) - val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0))) + val sv = Vectors.sparse(1, Array(0), Array(4.0)) + val v2 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(1.0, 2.0, 3.0, sv) assert(v2.isInstanceOf[DenseVector]) } @@ -147,4 +166,94 @@ class VectorAssemblerSuite .filter(vectorUDF($"features") > 1) .count() == 1) } + + test("assemble should keep nulls when keepInvalid is true") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + assert(assemble(Array(1, 1), keepInvalid = true)(1.0, null) === Vectors.dense(1.0, Double.NaN)) + assert(assemble(Array(1, 2), keepInvalid = true)(1.0, null) + === Vectors.dense(1.0, Double.NaN, Double.NaN)) + assert(assemble(Array(1), keepInvalid = true)(null) === Vectors.dense(Double.NaN)) + assert(assemble(Array(2), keepInvalid = true)(null) === Vectors.dense(Double.NaN, Double.NaN)) + } + + test("assemble should throw errors when keepInvalid is false") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + intercept[SparkException](assemble(Array(1, 1), keepInvalid = false)(1.0, null)) + intercept[SparkException](assemble(Array(1, 2), keepInvalid = false)(1.0, null)) + intercept[SparkException](assemble(Array(1), keepInvalid = false)(null)) + intercept[SparkException](assemble(Array(2), keepInvalid = false)(null)) + } + + test("get lengths functions") { + import org.apache.spark.ml.feature.VectorAssembler._ + val df = dfWithNullsAndNaNs + assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2)) + assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"), Seq("y"))) + .getMessage.contains("VectorSizeHint")) + assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 6"), + Seq("y"))).getMessage.contains("VectorSizeHint")) + + assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ == "y" -> 2)) + assert(intercept[NullPointerException](getLengths(df.sort("id2"), Seq("y"), ERROR_INVALID)) + .getMessage.contains("VectorSizeHint")) + assert(intercept[RuntimeException](getLengths(df.sort("id2"), Seq("y"), KEEP_INVALID)) + .getMessage.contains("VectorSizeHint")) + } + + test("Handle Invalid should behave properly") { + val assembler = new VectorAssembler() + .setInputCols(Array("x", "y", "z", "n")) + .setOutputCol("features") + + def runWithMetadata(mode: String, additional_filter: String = "true"): Dataset[_] = { + val attributeY = new AttributeGroup("y", 2) + val attributeZ = new AttributeGroup( + "z", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val dfWithMetadata = dfWithNullsAndNaNs.withColumn("y", col("y"), attributeY.toMetadata()) + .withColumn("z", col("z"), attributeZ.toMetadata()).filter(additional_filter) + val output = assembler.setHandleInvalid(mode).transform(dfWithMetadata) + output.collect() + output + } + + def runWithFirstRow(mode: String): Dataset[_] = { + val output = assembler.setHandleInvalid(mode).transform(dfWithNullsAndNaNs) + output.collect() + output + } + + def runWithAllNullVectors(mode: String): Dataset[_] = { + val output = assembler.setHandleInvalid(mode) + .transform(dfWithNullsAndNaNs.filter("0 == id1 % 2")) + output.collect() + output + } + + // behavior when vector size hint is given + assert(runWithMetadata("keep").count() == 6, "should keep all rows") + assert(runWithMetadata("skip").count() == 1, "should skip rows with nulls") + // should throw error with nulls + intercept[SparkException](runWithMetadata("error")) + // should throw error with NaNs + intercept[SparkException](runWithMetadata("error", additional_filter = "id1 > 4")) + + // behavior when first row has information + assert(intercept[RuntimeException](runWithFirstRow("keep").count()) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + assert(runWithFirstRow("skip").count() == 1, "should infer size and skip rows with nulls") + intercept[SparkException](runWithFirstRow("error")) + + // behavior when vector column is all null + assert(intercept[RuntimeException](runWithAllNullVectors("skip")) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + assert(intercept[NullPointerException](runWithAllNullVectors("error")) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + + // behavior when scalar column is all null + assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4) + } + } From 441d0d0766e9a6ac4c6ff79680394999ff7191fd Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 3 Apr 2018 09:31:47 +0800 Subject: [PATCH 551/774] [SPARK-19964][CORE] Avoid reading from remote repos in SparkSubmitSuite. These tests can fail with a timeout if the remote repos are not responding, or slow. The tests don't need anything from those repos, so use an empty ivy config file to avoid setting up the defaults. The tests are passing reliably for me locally now, and failing more often than not today without this change since http://dl.bintray.com/spark-packages/maven doesn't seem to be loading from my machine. Author: Marcelo Vanzin Closes #20916 from vanzin/SPARK-19964. --- .../org/apache/spark/deploy/DependencyUtils.scala | 13 ++++++++----- .../scala/org/apache/spark/deploy/SparkSubmit.scala | 3 ++- .../apache/spark/deploy/SparkSubmitArguments.scala | 2 ++ .../apache/spark/deploy/worker/DriverWrapper.scala | 13 +++++++++---- .../org/apache/spark/deploy/SparkSubmitSuite.scala | 9 ++++++--- 5 files changed, 27 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index ab319c860ee69..fac834a70b893 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -33,7 +33,8 @@ private[deploy] object DependencyUtils { packagesExclusions: String, packages: String, repositories: String, - ivyRepoPath: String): String = { + ivyRepoPath: String, + ivySettingsPath: Option[String]): String = { val exclusions: Seq[String] = if (!StringUtils.isBlank(packagesExclusions)) { packagesExclusions.split(",") @@ -41,10 +42,12 @@ private[deploy] object DependencyUtils { Nil } // Create the IvySettings, either load from file or build defaults - val ivySettings = sys.props.get("spark.jars.ivySettings").map { ivySettingsFile => - SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(repositories), Option(ivyRepoPath)) - }.getOrElse { - SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) + val ivySettings = ivySettingsPath match { + case Some(path) => + SparkSubmitUtils.loadIvySettings(path, Option(repositories), Option(ivyRepoPath)) + + case None => + SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) } SparkSubmitUtils.resolveMavenCoordinates(packages, ivySettings, exclusions = exclusions) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 3965f17f4b56e..eddbedeb1024d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -359,7 +359,8 @@ object SparkSubmit extends CommandLineUtils with Logging { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies( - args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath) + args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath, + args.ivySettingsPath) if (!StringUtils.isBlank(resolvedMavenCoordinates)) { args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index e7796d4ddbe34..8e7070593687b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -63,6 +63,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var packages: String = null var repositories: String = null var ivyRepoPath: String = null + var ivySettingsPath: Option[String] = None var packagesExclusions: String = null var verbose: Boolean = false var isPython: Boolean = false @@ -184,6 +185,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull files = Option(files).orElse(sparkProperties.get("spark.files")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull + ivySettingsPath = sparkProperties.get("spark.jars.ivySettings") packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull packagesExclusions = Option(packagesExclusions) .orElse(sparkProperties.get("spark.jars.excludes")).orNull diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index b19c9904d5982..3f71237164a15 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -79,12 +79,17 @@ object DriverWrapper extends Logging { val secMgr = new SecurityManager(sparkConf) val hadoopConf = SparkHadoopUtil.newConfiguration(sparkConf) - val Seq(packagesExclusions, packages, repositories, ivyRepoPath) = - Seq("spark.jars.excludes", "spark.jars.packages", "spark.jars.repositories", "spark.jars.ivy") - .map(sys.props.get(_).orNull) + val Seq(packagesExclusions, packages, repositories, ivyRepoPath, ivySettingsPath) = + Seq( + "spark.jars.excludes", + "spark.jars.packages", + "spark.jars.repositories", + "spark.jars.ivy", + "spark.jars.ivySettings" + ).map(sys.props.get(_).orNull) val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies(packagesExclusions, - packages, repositories, ivyRepoPath) + packages, repositories, ivyRepoPath, Option(ivySettingsPath)) val jars = { val jarsProp = sys.props.get("spark.jars").orNull if (!StringUtils.isBlank(resolvedMavenCoordinates)) { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d86ef907b4492..0d7c342a5eacd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -106,6 +106,9 @@ class SparkSubmitSuite // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x implicit val defaultSignaler: Signaler = ThreadSignaler + private val emptyIvySettings = File.createTempFile("ivy", ".xml") + FileUtils.write(emptyIvySettings, "", StandardCharsets.UTF_8) + override def beforeEach() { super.beforeEach() } @@ -520,6 +523,7 @@ class SparkSubmitSuite "--repositories", repo, "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -530,7 +534,6 @@ class SparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") - // Test using "spark.jars.packages" and "spark.jars.repositories" configurations. IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo => val args = Seq( "--class", JarCreationTest.getClass.getName.stripSuffix("$"), @@ -540,6 +543,7 @@ class SparkSubmitSuite "--conf", s"spark.jars.repositories=$repo", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -550,7 +554,6 @@ class SparkSubmitSuite // See https://gist.github.com/shivaram/3a2fecce60768a603dac for a error log ignore("correctly builds R packages included in a jar with --packages") { assume(RUtils.isRInstalled, "R isn't installed on this machine.") - // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -563,6 +566,7 @@ class SparkSubmitSuite "--master", "local-cluster[2,1,1024]", "--packages", main.toString, "--repositories", repo, + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", "--verbose", "--conf", "spark.ui.enabled=false", rScriptDir) @@ -573,7 +577,6 @@ class SparkSubmitSuite test("include an external JAR in SparkR") { assume(RUtils.isRInstalled, "R isn't installed on this machine.") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val rScriptDir = Seq(sparkHome, "R", "pkg", "tests", "fulltests", "jarTest.R").mkString(File.separator) From 8020f66fc47140a1b5f843fb18c34ec80541d5ca Mon Sep 17 00:00:00 2001 From: lemonjing <932191671@qq.com> Date: Tue, 3 Apr 2018 09:36:44 +0800 Subject: [PATCH 552/774] [MINOR][DOC] Fix a few markdown typos ## What changes were proposed in this pull request? Easy fix in the markdown. ## How was this patch tested? jekyII build test manually. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: lemonjing <932191671@qq.com> Closes #20897 from Lemonjing/master. --- docs/ml-guide.md | 2 +- docs/mllib-feature-extraction.md | 4 ++-- docs/mllib-pmml-model-export.md | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 702bcf748fc74..aea07be34cb86 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -111,7 +111,7 @@ and the migration guide below will explain all changes between releases. * The class and trait hierarchy for logistic regression model summaries was changed to be cleaner and better accommodate the addition of the multi-class summary. This is a breaking change for user code that casts a `LogisticRegressionTrainingSummary` to a -` BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` +`BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail (_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which will still work correctly for both multinomial and binary cases. diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 75aea70601875..8b89296b14cdd 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -278,8 +278,8 @@ for details on the API. multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `scalingVec`, to yield a result vector. -Qu8T948*1# -Denoting the `scalingVec` as "`w`," this transformation may be written as: + +Denoting the `scalingVec` as "`w`", this transformation may be written as: `\[ \begin{pmatrix} v_1 \\ diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md index d3530908706d0..f567565437927 100644 --- a/docs/mllib-pmml-model-export.md +++ b/docs/mllib-pmml-model-export.md @@ -7,7 +7,7 @@ displayTitle: PMML model export - RDD-based API * Table of contents {:toc} -## `spark.mllib` supported models +## spark.mllib supported models `spark.mllib` supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). @@ -15,7 +15,7 @@ The table below outlines the `spark.mllib` models that can be exported to PMML a - + From 7cf9fab33457ccc9b2d548f15dd5700d5e8d08ef Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 3 Apr 2018 21:26:49 +0800 Subject: [PATCH 553/774] [MINOR][CORE] Show block manager id when remove RDD/Broadcast fails. ## What changes were proposed in this pull request? Address https://github.com/apache/spark/pull/20924#discussion_r177987175, show block manager id when remove RDD/Broadcast fails. ## How was this patch tested? N/A Author: Xingbo Jiang Closes #20960 from jiangxb1987/bmid. --- .../apache/spark/storage/BlockManagerMasterEndpoint.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 56b95c31eb4c3..8e8f7d197c9ef 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -164,7 +164,8 @@ class BlockManagerMasterEndpoint( val futures = blockManagerInfo.values.map { bm => bm.slaveEndpoint.ask[Int](removeMsg).recover { case e: IOException => - logWarning(s"Error trying to remove RDD $rddId", e) + logWarning(s"Error trying to remove RDD $rddId from block manager ${bm.blockManagerId}", + e) 0 // zero blocks were removed } }.toSeq @@ -195,7 +196,8 @@ class BlockManagerMasterEndpoint( val futures = requiredBlockManagers.map { bm => bm.slaveEndpoint.ask[Int](removeMsg).recover { case e: IOException => - logWarning(s"Error trying to remove broadcast $broadcastId", e) + logWarning(s"Error trying to remove broadcast $broadcastId from block manager " + + s"${bm.blockManagerId}", e) 0 // zero blocks were removed } }.toSeq From 66a3a5a2dc83e03dedcee9839415c1ddc1fb8125 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 3 Apr 2018 11:05:29 -0700 Subject: [PATCH 554/774] [SPARK-23099][SS] Migrate foreach sink to DataSourceV2 ## What changes were proposed in this pull request? Migrate foreach sink to DataSourceV2. Since the previous attempt at this PR #20552, we've changed and strictly defined the lifecycle of writer components. This means we no longer need the complicated lifecycle shim from that PR; it just naturally works. ## How was this patch tested? existing tests Author: Jose Torres Closes #20951 from jose-torres/foreach. --- .../sql/execution/streaming/ForeachSink.scala | 68 ----------- .../sources/ForeachWriterProvider.scala | 111 ++++++++++++++++++ .../sql/streaming/DataStreamWriter.scala | 4 +- .../ForeachWriterSuite.scala} | 83 ++++++------- .../sql/streaming/StreamingQuerySuite.scala | 1 + 5 files changed, 156 insertions(+), 111 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/{ForeachSinkSuite.scala => sources/ForeachWriterSuite.scala} (77%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala deleted file mode 100644 index 2cc54107f8b83..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.TaskContext -import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter} -import org.apache.spark.sql.catalyst.encoders.encoderFor - -/** - * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by - * [[ForeachWriter]]. - * - * @param writer The [[ForeachWriter]] to process all data. - * @tparam T The expected type of the sink. - */ -class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable { - - override def addBatch(batchId: Long, data: DataFrame): Unit = { - // This logic should've been as simple as: - // ``` - // data.as[T].foreachPartition { iter => ... } - // ``` - // - // Unfortunately, doing that would just break the incremental planing. The reason is, - // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will - // create a new plan. Because StreamExecution uses the existing plan to collect metrics and - // update watermark, we should never create a new plan. Otherwise, metrics and watermark are - // updated in the new plan, and StreamExecution cannot retrieval them. - // - // Hence, we need to manually convert internal rows to objects using encoder. - val encoder = encoderFor[T].resolveAndBind( - data.logicalPlan.output, - data.sparkSession.sessionState.analyzer) - data.queryExecution.toRdd.foreachPartition { iter => - if (writer.open(TaskContext.getPartitionId(), batchId)) { - try { - while (iter.hasNext) { - writer.process(encoder.fromRow(iter.next())) - } - } catch { - case e: Throwable => - writer.close(e) - throw e - } - writer.close(null) - } else { - writer.close(null) - } - } - } - - override def toString(): String = "ForeachSink" -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala new file mode 100644 index 0000000000000..df5d69d57e36f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + +/** + * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified + * [[ForeachWriter]]. + * + * @param writer The [[ForeachWriter]] to process all data. + * @tparam T The expected type of the sink. + */ +case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { + override def createStreamWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceOptions): StreamWriter = { + new StreamWriter with SupportsWriteInternalRow { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + val encoder = encoderFor[T].resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + ForeachWriterFactory(writer, encoder) + } + + override def toString: String = "ForeachSink" + } + } +} + +case class ForeachWriterFactory[T: Encoder]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]) + extends DataWriterFactory[InternalRow] { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): ForeachDataWriter[T] = { + new ForeachDataWriter(writer, encoder, partitionId, epochId) + } +} + +/** + * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * @param writer The [[ForeachWriter]] to process all data. + * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param partitionId + * @param epochId + * @tparam T The type expected by the writer. + */ +class ForeachDataWriter[T : Encoder]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T], + partitionId: Int, + epochId: Long) + extends DataWriter[InternalRow] { + + // If open returns false, we should skip writing rows. + private val opened = writer.open(partitionId, epochId) + + override def write(record: InternalRow): Unit = { + if (!opened) return + + try { + writer.process(encoder.fromRow(record)) + } catch { + case t: Throwable => + writer.close(t) + throw t + } + } + + override def commit(): WriterCommitMessage = { + writer.close(null) + ForeachWriterCommitMessage + } + + override def abort(): Unit = {} +} + +/** + * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination. + */ +case object ForeachWriterCommitMessage extends WriterCommitMessage diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 2fc903168cfa0..effc1471e8e12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} import org.apache.spark.sql.sources.v2.StreamWriteSupport /** @@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc) + val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala similarity index 77% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index b249dd41a84a6..03bf71b3f4b78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources import java.util.concurrent.ConcurrentLinkedQueue @@ -25,11 +25,12 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext -class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { +class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { import testImplicits._ @@ -47,9 +48,9 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .start() def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = { - import ForeachSinkSuite._ + import ForeachWriterSuite._ - val events = ForeachSinkSuite.allEvents() + val events = ForeachWriterSuite.allEvents() assert(events.size === 2) // one seq of events for each of the 2 partitions // Verify both seq of events have an Open event as the first event @@ -64,13 +65,13 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } // -- batch 0 --------------------------------------- - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() input.addData(1, 2, 3, 4) query.processAllAvailable() verifyOutput(expectedVersion = 0, expectedData = 1 to 4) // -- batch 1 --------------------------------------- - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() input.addData(5, 6, 7, 8) query.processAllAvailable() verifyOutput(expectedVersion = 1, expectedData = 5 to 8) @@ -95,27 +96,27 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf input.addData(1, 2, 3, 4) query.processAllAvailable() - var allEvents = ForeachSinkSuite.allEvents() + var allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) var expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 4), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Process(value = 4), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() // -- batch 1 --------------------------------------- input.addData(5, 6, 7, 8) query.processAllAvailable() - allEvents = ForeachSinkSuite.allEvents() + allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Process(value = 8), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 1), + ForeachWriterSuite.Process(value = 8), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) @@ -131,7 +132,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .foreach(new TestForeachWriter() { override def process(value: Int): Unit = { super.process(value) - throw new RuntimeException("error") + throw new RuntimeException("ForeachSinkSuite error") } }).start() input.addData(1, 2, 3, 4) @@ -141,18 +142,18 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf query.processAllAvailable() } assert(e.getCause.isInstanceOf[SparkException]) - assert(e.getCause.getCause.getMessage === "error") + assert(e.getCause.getCause.getCause.getMessage === "ForeachSinkSuite error") assert(query.isActive === false) - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) - assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + assert(allEvents(0)(0) === ForeachWriterSuite.Open(partition = 0, version = 0)) + assert(allEvents(0)(1) === ForeachWriterSuite.Process(value = 1)) // `close` should be called with the error - val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] + val errorEvent = allEvents(0)(2).asInstanceOf[ForeachWriterSuite.Close] assert(errorEvent.error.get.isInstanceOf[RuntimeException]) - assert(errorEvent.error.get.getMessage === "error") + assert(errorEvent.error.get.getMessage === "ForeachSinkSuite error") } } @@ -177,12 +178,12 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf inputData.addData(10, 11, 12) query.processAllAvailable() - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) val expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Process(value = 3), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) } finally { @@ -216,21 +217,21 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf query.processAllAvailable() // There should be 3 batches and only does the last batch contain a value. - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 3) val expectedEvents = Seq( Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Close(None) ), Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 1), + ForeachWriterSuite.Close(None) ), Seq( - ForeachSinkSuite.Open(partition = 0, version = 2), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 2), + ForeachWriterSuite.Process(value = 3), + ForeachWriterSuite.Close(None) ) ) assert(allEvents === expectedEvents) @@ -258,7 +259,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } /** A global object to collect events in the executor */ -object ForeachSinkSuite { +object ForeachWriterSuite { trait Event @@ -285,21 +286,21 @@ object ForeachSinkSuite { /** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */ class TestForeachWriter extends ForeachWriter[Int] { - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() - private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]() + private val events = mutable.ArrayBuffer[ForeachWriterSuite.Event]() override def open(partitionId: Long, version: Long): Boolean = { - events += ForeachSinkSuite.Open(partition = partitionId, version = version) + events += ForeachWriterSuite.Open(partition = partitionId, version = version) true } override def process(value: Int): Unit = { - events += ForeachSinkSuite.Process(value) + events += ForeachWriterSuite.Process(value) } override def close(errorOrNull: Throwable): Unit = { - events += ForeachSinkSuite.Close(error = Option(errorOrNull)) - ForeachSinkSuite.addEvents(events) + events += ForeachWriterSuite.Close(error = Option(errorOrNull)) + ForeachWriterSuite.addEvents(events) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 08749b49997e0..20942ed93897c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.reader.DataReaderFactory From 1035aaa61704b2790192d3186fe37e678553d36d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 4 Apr 2018 01:36:58 +0200 Subject: [PATCH 555/774] [SPARK-23587][SQL] Add interpreted execution for MapObjects expression ## What changes were proposed in this pull request? Add interpreted execution for `MapObjects` expression. ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20771 from viirya/SPARK-23587. --- .../expressions/objects/objects.scala | 110 ++++++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 67 ++++++++++- 2 files changed, 165 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index adf9ddf327c96..0e9d357c19c63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects import java.lang.reflect.Modifier +import scala.collection.JavaConverters._ import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag @@ -501,12 +502,22 @@ case class LambdaVariable( value: String, isNull: String, dataType: DataType, - nullable: Boolean = true) extends LeafExpression - with Unevaluable with NonSQLExpression { + nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + + // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. + override def eval(input: InternalRow): Any = { + assert(input.numFields == 1, + "The input row of interpreted LambdaVariable should have only 1 field.") + input.get(0, dataType) + } override def genCode(ctx: CodegenContext): ExprCode = { ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false") } + + // This won't be called as `genCode` is overrided, just overriding it to make + // `LambdaVariable` non-abstract. + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev } /** @@ -599,8 +610,92 @@ case class MapObjects private( override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + // The data with UserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use it. + lazy private val inputDataType = inputData.dataType match { + case u: UserDefinedType[_] => u.sqlType + case _ => inputData.dataType + } + + private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = { + val row = new GenericInternalRow(1) + inputCollection.toIterator.map { element => + row.update(0, element) + lambdaFunction.eval(row) + } + } + + private lazy val convertToSeq: Any => Seq[_] = inputDataType match { + case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + _.asInstanceOf[Seq[_]] + case ObjectType(cls) if cls.isArray => + _.asInstanceOf[Array[_]].toSeq + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + _.asInstanceOf[java.util.List[_]].asScala + case ObjectType(cls) if cls == classOf[Object] => + (inputCollection) => { + if (inputCollection.getClass.isArray) { + inputCollection.asInstanceOf[Array[_]].toSeq + } else { + inputCollection.asInstanceOf[Seq[_]] + } + } + case ArrayType(et, _) => + _.asInstanceOf[ArrayData].array + } + + private lazy val mapElements: Seq[_] => Any = customCollectionCls match { + case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + // Scala sequence + executeFuncOnCollection(_).toSeq + case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + // Scala set + executeFuncOnCollection(_).toSet + case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + // Java list + if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] || + cls == classOf[java.util.AbstractSequentialList[_]]) { + // Specifying non concrete implementations of `java.util.List` + executeFuncOnCollection(_).toSeq.asJava + } else { + val constructors = cls.getConstructors() + val intParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int] + } + val noParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 0 + } + + val constructor = intParamConstructor.map { intConstructor => + (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object]) + }.getOrElse { + (_: Int) => noParamConstructor.get.newInstance() + } + + // Specifying concrete implementations of `java.util.List` + (inputs) => { + val results = executeFuncOnCollection(inputs) + val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]] + results.foreach(builder.add(_)) + builder + } + } + case None => + // array + x => new GenericArrayData(executeFuncOnCollection(x).toArray) + case Some(cls) => + throw new RuntimeException(s"class `${cls.getName}` is not supported by `MapObjects` as " + + "resulting collection.") + } + + override def eval(input: InternalRow): Any = { + val inputCollection = inputData.eval(input) + + if (inputCollection == null) { + return null + } + mapElements(convertToSeq(inputCollection)) + } override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse( @@ -647,13 +742,6 @@ case class MapObjects private( case _ => "" } - // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. - // When we want to apply MapObjects on it, we have to use it. - val inputDataType = inputData.dataType match { - case p: PythonUserDefinedType => p.sqlType - case _ => inputData.dataType - } - // `MapObjects` generates a while loop to traverse the elements of the input collection. We // need to take care of Seq and List because they may have O(n) complexity for indexed accessing // like `list.get(1)`. Here we use Iterator to traverse Seq and List. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 1f6964dfef598..0edd27c8241e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.{SparkConf, SparkFunSuite} @@ -25,7 +26,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -135,6 +136,70 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-23587: MapObjects should support interpreted execution") { + def testMapObjects(collection: Any, collectionCls: Class[_], inputType: DataType): Unit = { + val function = (lambda: Expression) => Add(lambda, Literal(1)) + val elementType = IntegerType + val expected = Seq(2, 3, 4) + + val inputObject = BoundReference(0, inputType, nullable = true) + val optClass = Option(collectionCls) + val mapObj = MapObjects(function, inputObject, elementType, true, optClass) + val row = InternalRow.fromSeq(Seq(collection)) + val result = mapObj.eval(row) + + collectionCls match { + case null => + assert(result.asInstanceOf[ArrayData].array.toSeq == expected) + case l if classOf[java.util.List[_]].isAssignableFrom(l) => + assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected) + case s if classOf[Seq[_]].isAssignableFrom(s) => + assert(result.asInstanceOf[Seq[_]].toSeq == expected) + case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) => + assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet) + } + } + + val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]], + classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]], + classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]], + classOf[java.util.Stack[Int]], null) + + val list = new java.util.ArrayList[Int]() + list.add(1) + list.add(2) + list.add(3) + val arrayData = new GenericArrayData(Array(1, 2, 3)) + val vector = new java.util.Vector[Int]() + vector.add(1) + vector.add(2) + vector.add(3) + val stack = new java.util.Stack[Int]() + stack.add(1) + stack.add(2) + stack.add(3) + + Seq( + (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])), + (Array(1, 2, 3), ObjectType(classOf[Array[Int]])), + (Seq(1, 2, 3), ObjectType(classOf[Object])), + (Array(1, 2, 3), ObjectType(classOf[Object])), + (list, ObjectType(classOf[java.util.List[Int]])), + (vector, ObjectType(classOf[java.util.Vector[Int]])), + (stack, ObjectType(classOf[java.util.Stack[Int]])), + (arrayData, ArrayType(IntegerType)) + ).foreach { case (collection, inputType) => + customCollectionClasses.foreach(testMapObjects(collection, _, inputType)) + + // Unsupported custom collection class + val errMsg = intercept[RuntimeException] { + testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType) + }.getMessage() + assert(errMsg.contains("`scala.collection.Map` is not supported by `MapObjects` " + + "as resulting collection.")) + } + } + test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") { val cls = classOf[java.lang.Integer] val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true) From 359375eff74630c9f0ea5a90ab7d45bf1b281ed0 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 3 Apr 2018 17:09:12 -0700 Subject: [PATCH 556/774] [SPARK-23809][SQL] Active SparkSession should be set by getOrCreate ## What changes were proposed in this pull request? Currently, the active spark session is set inconsistently (e.g., in createDataFrame, prior to query execution). Many places in spark also incorrectly query active session when they should be calling activeSession.getOrElse(defaultSession) and so might get None even if a Spark session exists. The semantics here can be cleaned up if we also set the active session when the default session is set. Related: https://github.com/apache/spark/pull/20926/files ## How was this patch tested? Unit test, existing test. Note that if https://github.com/apache/spark/pull/20926 merges first we should also update the tests there. Author: Eric Liang Closes #20927 from ericl/active-session-cleanup. --- .../org/apache/spark/sql/SparkSession.scala | 14 +++++++++++++- .../spark/sql/SparkSessionBuilderSuite.scala | 18 ++++++++++++++++++ .../apache/spark/sql/test/TestSQLContext.scala | 1 + .../apache/spark/sql/hive/test/TestHive.scala | 3 +++ 4 files changed, 35 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 734573ba31f71..b107492fbb330 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -951,7 +951,8 @@ object SparkSession { session = new SparkSession(sparkContext, None, None, extensions) options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) } - defaultSession.set(session) + setDefaultSession(session) + setActiveSession(session) // Register a successfully instantiated context to the singleton. This should be at the // end of the class definition so that the singleton is updated only if there is no @@ -1027,6 +1028,17 @@ object SparkSession { */ def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + /** + * Returns the currently active SparkSession, otherwise the default one. If there is no default + * SparkSession, throws an exception. + * + * @since 2.4.0 + */ + def active: SparkSession = { + getActiveSession.getOrElse(getDefaultSession.getOrElse( + throw new IllegalStateException("No active or default Spark session found"))) + } + //////////////////////////////////////////////////////////////////////////////////////// // Private methods from now on //////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index c0301f2ce2d66..44bf8624a6bcd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -50,6 +50,24 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach { assert(SparkSession.builder().getOrCreate() == session) } + test("sets default and active session") { + assert(SparkSession.getDefaultSession == None) + assert(SparkSession.getActiveSession == None) + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.getDefaultSession == Some(session)) + assert(SparkSession.getActiveSession == Some(session)) + } + + test("get active or default session") { + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.active == session) + SparkSession.clearActiveSession() + assert(SparkSession.active == session) + SparkSession.clearDefaultSession() + intercept[IllegalStateException](SparkSession.active) + session.stop() + } + test("config options are propagated to existing SparkSession") { val session1 = SparkSession.builder().master("local").config("spark-config1", "a").getOrCreate() assert(session1.conf.get("spark-config1") == "a") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 3038b822beb4a..17603deacdcdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -35,6 +35,7 @@ private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) } SparkSession.setDefaultSession(this) + SparkSession.setActiveSession(this) @transient override lazy val sessionState: SessionState = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 814038d4ef7af..a7006a16d7b73 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -179,6 +179,9 @@ private[hive] class TestHiveSparkSession( loadTestTables) } + SparkSession.setDefaultSession(this) + SparkSession.setActiveSession(this) + { // set the metastore temporary configuration val metastoreTempConf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false) ++ Map( ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", From 5cfd5fabcdbd77a806b98a6dd59b02772d2f6dee Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Tue, 3 Apr 2018 17:25:54 -0700 Subject: [PATCH 557/774] [SPARK-23802][SQL] PropagateEmptyRelation can leave query plan in unresolved state ## What changes were proposed in this pull request? Add cast to nulls introduced by PropagateEmptyRelation so in cases they're part of coalesce they will not break its type checking rules ## How was this patch tested? Added unit test Author: Robert Kruszewski Closes #20914 from robert3005/rk/propagate-empty-fix. --- .../optimizer/PropagateEmptyRelation.scala | 8 ++++-- .../PropagateEmptyRelationSuite.scala | 26 ++++++++++++++----- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index a6e5aa6daca65..c3fdb924243df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf /** * Collapse plans consisting empty local relations generated by [[PruneFilters]]. @@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules._ * - Aggregate with all empty children and at least one grouping expression. * - Generate(Explode) with all empty children. Others like Hive UDTF may return results. */ -object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { +object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper with CastSupport { private def isEmptyLocalRelation(plan: LogicalPlan): Boolean = plan match { case p: LocalRelation => p.data.isEmpty case _ => false @@ -43,7 +45,9 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { // Construct a project list from plan's output, while the value is always NULL. private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] = - plan.output.map{ a => Alias(Literal(null), a.name)(a.exprId) } + plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) } + + override def conf: SQLConf = SQLConf.get def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p: Union if p.children.forall(isEmptyLocalRelation) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 3964508e3a55e..f1ce7543ffdc1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructType} class PropagateEmptyRelationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -37,7 +37,8 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceIntersectWithSemiJoin, PushDownPredicate, PruneFilters, - PropagateEmptyRelation) :: Nil + PropagateEmptyRelation, + CollapseProject) :: Nil } object OptimizeWithoutPropagateEmptyRelation extends RuleExecutor[LogicalPlan] { @@ -48,7 +49,8 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters) :: Nil + PruneFilters, + CollapseProject) :: Nil } val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) @@ -79,9 +81,11 @@ class PropagateEmptyRelationSuite extends PlanTest { (true, false, Inner, Some(LocalRelation('a.int, 'b.int))), (true, false, Cross, Some(LocalRelation('a.int, 'b.int))), - (true, false, LeftOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), + (true, false, LeftOuter, + Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)), (true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))), - (true, false, FullOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), + (true, false, FullOuter, + Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)), (true, false, LeftAnti, Some(testRelation1)), (true, false, LeftSemi, Some(LocalRelation('a.int))), @@ -89,8 +93,9 @@ class PropagateEmptyRelationSuite extends PlanTest { (false, true, Cross, Some(LocalRelation('a.int, 'b.int))), (false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))), (false, true, RightOuter, - Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), - (false, true, FullOuter, Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), + Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)), + (false, true, FullOuter, + Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)), (false, true, LeftAnti, Some(LocalRelation('a.int))), (false, true, LeftSemi, Some(LocalRelation('a.int))), @@ -209,4 +214,11 @@ class PropagateEmptyRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("propagate empty relation keeps the plan resolved") { + val query = testRelation1.join( + LocalRelation('a.int, 'b.int), UsingJoin(FullOuter, "a" :: Nil), None) + val optimized = Optimize.execute(query.analyze) + assert(optimized.resolved) + } } From 16ef6baa36ac11c72cfeafaa2363e6b69f0ba573 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 4 Apr 2018 14:31:03 +0800 Subject: [PATCH 558/774] [SPARK-23826][TEST] TestHiveSparkSession should set default session ## What changes were proposed in this pull request? In TestHive, the base spark session does this in getOrCreate(), we emulate that behavior for tests. ## How was this patch tested? N/A Author: gatorsmile Closes #20969 from gatorsmile/setDefault. --- .../main/scala/org/apache/spark/sql/hive/test/TestHive.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index a7006a16d7b73..965aea2b61456 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -159,10 +159,6 @@ private[hive] class TestHiveSparkSession( private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => - // TODO(SPARK-23826): TestHiveSparkSession should set default session the same way as - // TestSparkSession, but doing this the same way breaks many tests in the package. We need - // to investigate and find a different strategy. - def this(sc: SparkContext, loadTestTables: Boolean) { this( sc, From 5197562afe8534b29f5a0d72683c2859f796275d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 4 Apr 2018 14:39:19 +0800 Subject: [PATCH 559/774] [SPARK-21351][SQL] Update nullability based on children's output ## What changes were proposed in this pull request? This pr added a new optimizer rule `UpdateNullabilityInAttributeReferences ` to update the nullability that `Filter` changes when having `IsNotNull`. In the master, optimized plans do not respect the nullability when `Filter` has `IsNotNull`. This wrongly generates unnecessary code. For example: ``` scala> val df = Seq((Some(1), Some(2))).toDF("a", "b") scala> val bIsNotNull = df.where($"b" =!= 2).select($"b") scala> val targetQuery = bIsNotNull.distinct scala> val targetQuery.queryExecution.optimizedPlan.output(0).nullable res5: Boolean = true scala> targetQuery.debugCodegen Found 2 WholeStageCodegen subtrees. == Subtree 1 / 2 == *HashAggregate(keys=[b#19], functions=[], output=[b#19]) +- Exchange hashpartitioning(b#19, 200) +- *HashAggregate(keys=[b#19], functions=[], output=[b#19]) +- *Project [_2#16 AS b#19] +- *Filter isnotnull(_2#16) +- LocalTableScan [_1#15, _2#16] Generated code: ... /* 124 */ protected void processNext() throws java.io.IOException { ... /* 132 */ // output the result /* 133 */ /* 134 */ while (agg_mapIter.next()) { /* 135 */ wholestagecodegen_numOutputRows.add(1); /* 136 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 137 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue(); /* 138 */ /* 139 */ boolean agg_isNull4 = agg_aggKey.isNullAt(0); /* 140 */ int agg_value4 = agg_isNull4 ? -1 : (agg_aggKey.getInt(0)); /* 141 */ agg_rowWriter1.zeroOutNullBytes(); /* 142 */ // We don't need this NULL check because NULL is filtered out in `$"b" =!=2` /* 143 */ if (agg_isNull4) { /* 144 */ agg_rowWriter1.setNullAt(0); /* 145 */ } else { /* 146 */ agg_rowWriter1.write(0, agg_value4); /* 147 */ } /* 148 */ append(agg_result1); /* 149 */ /* 150 */ if (shouldStop()) return; /* 151 */ } /* 152 */ /* 153 */ agg_mapIter.close(); /* 154 */ if (agg_sorter == null) { /* 155 */ agg_hashMap.free(); /* 156 */ } /* 157 */ } /* 158 */ /* 159 */ } ``` In the line 143, we don't need this NULL check because NULL is filtered out in `$"b" =!=2`. This pr could remove this NULL check; ``` scala> val targetQuery.queryExecution.optimizedPlan.output(0).nullable res5: Boolean = false scala> targetQuery.debugCodegen ... Generated code: ... /* 144 */ protected void processNext() throws java.io.IOException { ... /* 152 */ // output the result /* 153 */ /* 154 */ while (agg_mapIter.next()) { /* 155 */ wholestagecodegen_numOutputRows.add(1); /* 156 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 157 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue(); /* 158 */ /* 159 */ int agg_value4 = agg_aggKey.getInt(0); /* 160 */ agg_rowWriter1.write(0, agg_value4); /* 161 */ append(agg_result1); /* 162 */ /* 163 */ if (shouldStop()) return; /* 164 */ } /* 165 */ /* 166 */ agg_mapIter.close(); /* 167 */ if (agg_sorter == null) { /* 168 */ agg_hashMap.free(); /* 169 */ } /* 170 */ } ``` ## How was this patch tested? Added `UpdateNullabilityInAttributeReferencesSuite` for unit tests. Author: Takeshi Yamamuro Closes #18576 from maropu/SPARK-21351. --- .../sql/catalyst/optimizer/Optimizer.scala | 19 ++++++- ...ullabilityInAttributeReferencesSuite.scala | 57 +++++++++++++++++++ .../optimizer/complexTypesSuite.scala | 9 --- .../org/apache/spark/sql/DataFrameSuite.scala | 5 -- 4 files changed, 75 insertions(+), 15 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2829d1d81eb1a..9a1bbc675e397 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -153,7 +153,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RewritePredicateSubquery, ColumnPruning, CollapseProject, - RemoveRedundantProject) + RemoveRedundantProject) :+ + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) } /** @@ -1309,3 +1311,18 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { } } } + +/** + * Updates nullability in [[AttributeReference]]s if nullability is different between + * non-leaf plan's expressions and the children output. + */ +object UpdateNullabilityInAttributeReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case p if !p.isInstanceOf[LeafNode] => + val nullabilityMap = AttributeMap(p.children.flatMap(_.output).map { x => x -> x.nullable }) + p transformExpressions { + case ar: AttributeReference if nullabilityMap.contains(ar) => + ar.withNullability(nullabilityMap(ar)) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala new file mode 100644 index 0000000000000..09b11f5aba2a0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{CreateArray, GetArrayItem} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class UpdateNullabilityInAttributeReferencesSuite extends PlanTest { + + object Optimizer extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Constant Folding", FixedPoint(10), + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals, + SimplifyBinaryComparison, + SimplifyExtractValueOps) :: + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) :: Nil + } + + test("update nullability in AttributeReference") { + val rel = LocalRelation('a.long.notNull) + // In the 'original' plans below, the Aggregate node produced by groupBy() has a + // nullable AttributeReference to `b`, because both array indexing and map lookup are + // nullable expressions. After optimization, the same attribute is now non-nullable, + // but the AttributeReference is not updated to reflect this. So, we need to update nullability + // by the `UpdateNullabilityInAttributeReferences` rule. + val original = rel + .select(GetArrayItem(CreateArray(Seq('a, 'a + 1L)), 0) as "b") + .groupBy($"b")("1") + val expected = rel.select('a as "b").groupBy($"b")("1").analyze + val optimized = Optimizer.execute(original.analyze) + comparePlans(optimized, expected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 21ed987627b3b..633d86d495581 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -378,15 +378,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .groupBy($"foo")("1") checkRule(structRel, structExpected) - // These tests must use nullable attributes from the base relation for the following reason: - // in the 'original' plans below, the Aggregate node produced by groupBy() has a - // nullable AttributeReference to a1, because both array indexing and map lookup are - // nullable expressions. After optimization, the same attribute is now non-nullable, - // but the AttributeReference is not updated to reflect this. In the 'expected' plans, - // the grouping expressions have the same nullability as the original attribute in the - // relation. If that attribute is non-nullable, the tests will fail as the plans will - // compare differently, so for these tests we must use a nullable attribute. See - // SPARK-23634. val arrayRel = relation .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") .groupBy($"a1")("1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f7b3393f65cb1..60e84e6ee7504 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2055,11 +2055,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { expr: String, expectedNonNullableColumns: Seq[String]): Unit = { val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr) - // In the logical plan, all the output columns of input dataframe are nullable - dfWithFilter.queryExecution.optimizedPlan.collect { - case e: Filter => assert(e.output.forall(_.nullable)) - } - dfWithFilter.queryExecution.executedPlan.collect { // When the child expression in isnotnull is null-intolerant (i.e. any null input will // result in null output), the involved columns are converted to not nullable; From a35523653cdac039ee2ddff316bc2c25d6514a91 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 4 Apr 2018 18:36:15 +0200 Subject: [PATCH 560/774] [SPARK-23583][SQL] Invoke should support interpreted execution ## What changes were proposed in this pull request? This pr added interpreted execution for `Invoke`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Kazuaki Ishizaki Closes #20797 from kiszk/SPARK-28583. --- .../spark/sql/catalyst/ScalaReflection.scala | 48 +++++++++++++- .../expressions/objects/objects.scala | 56 ++++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 65 +++++++++++++++++++ 3 files changed, 163 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9a4bf0075a178..1aae3aea3a31a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -794,6 +794,52 @@ object ScalaReflection extends ScalaReflection { "interface", "long", "native", "new", "null", "package", "private", "protected", "public", "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw", "throws", "transient", "true", "try", "void", "volatile", "while") + + val typeJavaMapping = Map[DataType, Class[_]]( + BooleanType -> classOf[Boolean], + ByteType -> classOf[Byte], + ShortType -> classOf[Short], + IntegerType -> classOf[Int], + LongType -> classOf[Long], + FloatType -> classOf[Float], + DoubleType -> classOf[Double], + StringType -> classOf[UTF8String], + DateType -> classOf[DateType.InternalType], + TimestampType -> classOf[TimestampType.InternalType], + BinaryType -> classOf[BinaryType.InternalType], + CalendarIntervalType -> classOf[CalendarInterval] + ) + + val typeBoxedJavaMapping = Map[DataType, Class[_]]( + BooleanType -> classOf[java.lang.Boolean], + ByteType -> classOf[java.lang.Byte], + ShortType -> classOf[java.lang.Short], + IntegerType -> classOf[java.lang.Integer], + LongType -> classOf[java.lang.Long], + FloatType -> classOf[java.lang.Float], + DoubleType -> classOf[java.lang.Double], + DateType -> classOf[java.lang.Integer], + TimestampType -> classOf[java.lang.Long] + ) + + def dataTypeJavaClass(dt: DataType): Class[_] = { + dt match { + case _: DecimalType => classOf[Decimal] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayData] + case _: MapType => classOf[MapData] + case ObjectType(cls) => cls + case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + } + + def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { + if (arguments != Nil) { + arguments.map(e => dataTypeJavaClass(e.dataType)) + } else { + Seq.empty + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 0e9d357c19c63..a455c1c821a26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.objects -import java.lang.reflect.Modifier +import java.lang.reflect.{Method, Modifier} import scala.collection.JavaConverters._ import scala.collection.mutable.Builder @@ -28,7 +28,7 @@ import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ @@ -104,6 +104,38 @@ trait InvokeLike extends Expression with NonSQLExpression { (argCode, argValues.mkString(", "), resultIsNull) } + + /** + * Evaluate each argument with a given row, invoke a method with a given object and arguments, + * and cast a return value if the return type can be mapped to a Java Boxed type + * + * @param obj the object for the method to be called. If null, perform s static method call + * @param method the method object to be called + * @param arguments the arguments used for the method call + * @param input the row used for evaluating arguments + * @param dataType the data type of the return object + * @return the return object of a method call + */ + def invoke( + obj: Any, + method: Method, + arguments: Seq[Expression], + input: InternalRow, + dataType: DataType): Any = { + val args = arguments.map(e => e.eval(input).asInstanceOf[Object]) + if (needNullCheck && args.exists(_ == null)) { + // return null if one of arguments is null + null + } else { + val ret = method.invoke(obj, args: _*) + val boxedClass = ScalaReflection.typeBoxedJavaMapping.get(dataType) + if (boxedClass.isDefined) { + boxedClass.get.cast(ret) + } else { + ret + } + } + } } /** @@ -264,12 +296,11 @@ case class Invoke( propagateNull: Boolean = true, returnNullable : Boolean = true) extends InvokeLike { + lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) + override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable override def children: Seq[Expression] = targetObject +: arguments - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - private lazy val encodedFunctionName = TermName(functionName).encodedName.toString @transient lazy val method = targetObject.dataType match { @@ -283,6 +314,21 @@ case class Invoke( case _ => None } + override def eval(input: InternalRow): Any = { + val obj = targetObject.eval(input) + if (obj == null) { + // return null if obj is null + null + } else { + val invokeMethod = if (method.isDefined) { + method.get + } else { + obj.getClass.getDeclaredMethod(functionName, argClasses: _*) + } + invoke(obj, invokeMethod, arguments, input, dataType) + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) val obj = targetObject.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 0edd27c8241e8..9bfe2916b0820 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -24,11 +24,23 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +class InvokeTargetClass extends Serializable { + def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 + def filterPrimitiveInt(e: Int): Boolean = e > 0 + def binOp(e1: Int, e2: Double): Double = e1 + e2 +} + +class InvokeTargetSubClass extends InvokeTargetClass { + override def binOp(e1: Int, e2: Double): Double = e1 - e2 +} class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -81,6 +93,41 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed } + test("SPARK-23583: Invoke should support interpreted execution") { + val targetObject = new InvokeTargetClass + val funcClass = classOf[InvokeTargetClass] + val funcObj = Literal.create(targetObject, ObjectType(funcClass)) + val targetSubObject = new InvokeTargetSubClass + val funcSubObj = Literal.create(targetSubObject, ObjectType(classOf[InvokeTargetSubClass])) + val funcNullObj = Literal.create(null, ObjectType(funcClass)) + + val inputInt = Seq(BoundReference(0, ObjectType(classOf[Any]), true)) + val inputPrimitiveInt = Seq(BoundReference(0, IntegerType, false)) + val inputSum = Seq(BoundReference(0, IntegerType, false), BoundReference(1, DoubleType, false)) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt), + java.lang.Boolean.valueOf(true), InternalRow.fromSeq(Seq(Integer.valueOf(1)))) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterPrimitiveInt", BooleanType, inputPrimitiveInt), + false, InternalRow.fromSeq(Seq(-1))) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt), + null, InternalRow.fromSeq(Seq(null))) + + checkObjectExprEvaluation( + Invoke(funcNullObj, "filterInt", ObjectType(classOf[Any]), inputInt), + null, InternalRow.fromSeq(Seq(Integer.valueOf(1)))) + + checkObjectExprEvaluation( + Invoke(funcObj, "binOp", DoubleType, inputSum), 1.25, InternalRow.apply(1, 0.25)) + + checkObjectExprEvaluation( + Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) + } + test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -105,6 +152,24 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq())) } + // by scala values instead of catalyst values. + private def checkObjectExprEvaluation( + expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val serializer = new JavaSerializer(new SparkConf()).newInstance + val resolver = ResolveTimeZone(new SQLConf) + val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + checkEvaluationWithoutCodegen(expr, expected, inputRow) + checkEvaluationWithGeneratedMutableProjection(expr, expected, inputRow) + if (GenerateUnsafeProjection.canSupport(expr.dataType)) { + checkEvaluationWithUnsafeProjection( + expr, + expected, + inputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + } + checkEvaluationWithOptimization(expr, expected, inputRow) + } + test("SPARK-23594 GetExternalRowField should support interpreted execution") { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") From cccaaa14ad775fb981e501452ba2cc06ff5c0f0a Mon Sep 17 00:00:00 2001 From: Andrew Korzhuev Date: Wed, 4 Apr 2018 12:30:52 -0700 Subject: [PATCH 561/774] [SPARK-23668][K8S] Add config option for passing through k8s Pod.spec.imagePullSecrets ## What changes were proposed in this pull request? Pass through the `imagePullSecrets` option to the k8s pod in order to allow user to access private image registries. See https://kubernetes.io/docs/tasks/configure-pod-container/pull-image-private-registry/ ## How was this patch tested? Unit tests + manual testing. Manual testing procedure: 1. Have private image registry. 2. Spark-submit application with no `spark.kubernetes.imagePullSecret` set. Do `kubectl describe pod ...`. See the error message: ``` Error syncing pod, skipping: failed to "StartContainer" for "spark-kubernetes-driver" with ErrImagePull: "rpc error: code = 2 desc = Error: Status 400 trying to pull repository ...: \"{\\n \\\"errors\\\" : [ {\\n \\\"status\\\" : 400,\\n \\\"message\\\" : \\\"Unsupported docker v1 repository request for '...'\\\"\\n } ]\\n}\"" ``` 3. Create secret `kubectl create secret docker-registry ...` 4. Spark-submit with `spark.kubernetes.imagePullSecret` set to the new secret. See that deployment was successful. Author: Andrew Korzhuev Author: Andrew Korzhuev Closes #20811 from andrusha/spark-23668-image-pull-secrets. --- .../org/apache/spark/deploy/k8s/Config.scala | 7 ++++ .../spark/deploy/k8s/KubernetesUtils.scala | 13 +++++++ .../steps/BasicDriverConfigurationStep.scala | 7 +++- .../cluster/k8s/ExecutorPodFactory.scala | 4 +++ .../deploy/k8s/KubernetesUtilsTest.scala | 36 +++++++++++++++++++ .../BasicDriverConfigurationStepSuite.scala | 8 ++++- .../cluster/k8s/ExecutorPodFactorySuite.scala | 5 +++ 7 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 405ea476351bb..82f6c714f3555 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -54,6 +54,13 @@ private[spark] object Config extends Logging { .checkValues(Set("Always", "Never", "IfNotPresent")) .createWithDefault("IfNotPresent") + val IMAGE_PULL_SECRETS = + ConfigBuilder("spark.kubernetes.container.image.pullSecrets") + .doc("Comma separated list of the Kubernetes secrets used " + + "to access private image registries.") + .stringConf + .createOptional + val KUBERNETES_AUTH_DRIVER_CONF_PREFIX = "spark.kubernetes.authenticate.driver" val KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX = diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 5bc070147d3a8..5b2bb819cdb14 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.deploy.k8s +import io.fabric8.kubernetes.api.model.LocalObjectReference + import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -35,6 +37,17 @@ private[spark] object KubernetesUtils { sparkConf.getAllWithPrefix(prefix).toMap } + /** + * Parses comma-separated list of imagePullSecrets into K8s-understandable format + */ + def parseImagePullSecrets(imagePullSecrets: Option[String]): List[LocalObjectReference] = { + imagePullSecrets match { + case Some(secretsCommaSeparated) => + secretsCommaSeparated.split(',').map(_.trim).map(new LocalObjectReference(_)).toList + case None => Nil + } + } + def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { opt1.foreach { _ => require(opt2.isEmpty, errMessage) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index b811db324108c..fcb1db8008053 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s.submit.steps import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, PodBuilder, QuantityBuilder} +import io.fabric8.kubernetes.api.model._ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.Config._ @@ -51,6 +51,8 @@ private[spark] class BasicDriverConfigurationStep( .get(DRIVER_CONTAINER_IMAGE) .getOrElse(throw new SparkException("Must specify the driver container image")) + private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) + // CPU settings private val driverCpuCores = sparkConf.getOption("spark.driver.cores").getOrElse("1") private val driverLimitCores = sparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) @@ -129,6 +131,8 @@ private[spark] class BasicDriverConfigurationStep( case _ => driverContainerWithoutArgs.addToArgs(appArgs: _*).build() } + val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) + val baseDriverPod = new PodBuilder(driverSpec.driverPod) .editOrNewMetadata() .withName(driverPodName) @@ -138,6 +142,7 @@ private[spark] class BasicDriverConfigurationStep( .withNewSpec() .withRestartPolicy("Never") .withNodeSelector(nodeSelector.asJava) + .withImagePullSecrets(parsedImagePullSecrets.asJava) .endSpec() .build() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 7143f7a6f0b71..8607d6fba3234 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -68,6 +68,7 @@ private[spark] class ExecutorPodFactory( .get(EXECUTOR_CONTAINER_IMAGE) .getOrElse(throw new SparkException("Must specify the executor container image")) private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) + private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) private val blockManagerPort = sparkConf .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) @@ -103,6 +104,8 @@ private[spark] class ExecutorPodFactory( nodeToLocalTaskCount: Map[String, Int]): Pod = { val name = s"$executorPodNamePrefix-exec-$executorId" + val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod // name as the hostname. This preserves uniqueness since the end of name contains // executorId @@ -194,6 +197,7 @@ private[spark] class ExecutorPodFactory( .withHostname(hostname) .withRestartPolicy("Never") .withNodeSelector(nodeSelector.asJava) + .withImagePullSecrets(parsedImagePullSecrets.asJava) .endSpec() .build() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala new file mode 100644 index 0000000000000..cf41b22e241af --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.LocalObjectReference + +import org.apache.spark.SparkFunSuite + +class KubernetesUtilsTest extends SparkFunSuite { + + test("testParseImagePullSecrets") { + val noSecrets = KubernetesUtils.parseImagePullSecrets(None) + assert(noSecrets === Nil) + + val oneSecret = KubernetesUtils.parseImagePullSecrets(Some("imagePullSecret")) + assert(oneSecret === new LocalObjectReference("imagePullSecret") :: Nil) + + val commaSeparatedSecrets = KubernetesUtils.parseImagePullSecrets(Some("s1, s2 , s3,s4")) + assert(commaSeparatedSecrets.map(_.getName) === "s1" :: "s2" :: "s3" :: "s4" :: Nil) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index e59c6d28a8cc2..ee450fff8d376 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -51,6 +51,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") + .set(IMAGE_PULL_SECRETS, "imagePullSecret1, imagePullSecret2") val submissionStep = new BasicDriverConfigurationStep( APP_ID, @@ -103,7 +104,12 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE, SPARK_APP_NAME_ANNOTATION -> APP_NAME) assert(driverPodMetadata.getAnnotations.asScala === expectedAnnotations) - assert(preparedDriverSpec.driverPod.getSpec.getRestartPolicy === "Never") + + val driverPodSpec = preparedDriverSpec.driverPod.getSpec + assert(driverPodSpec.getRestartPolicy === "Never") + assert(driverPodSpec.getImagePullSecrets.size() === 2) + assert(driverPodSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") + assert(driverPodSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") val resolvedSparkConf = preparedDriverSpec.driverSparkConf.getAll.toMap val expectedSparkConf = Map( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index a71a2a1b888bc..d73df20f0f956 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -33,6 +33,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef private val driverPodUid: String = "driver-uid" private val executorPrefix: String = "base" private val executorImage: String = "executor-image" + private val imagePullSecrets: String = "imagePullSecret1, imagePullSecret2" private val driverPod = new PodBuilder() .withNewMetadata() .withName(driverPodName) @@ -54,6 +55,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) .set(CONTAINER_IMAGE, executorImage) .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) + .set(IMAGE_PULL_SECRETS, imagePullSecrets) } test("basic executor pod has reasonable defaults") { @@ -76,6 +78,9 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef .getRequests.get("memory").getAmount === "1408Mi") assert(executor.getSpec.getContainers.get(0).getResources .getLimits.get("memory").getAmount === "1408Mi") + assert(executor.getSpec.getImagePullSecrets.size() === 2) + assert(executor.getSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") + assert(executor.getSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") // The pod has no node selector, volumes. assert(executor.getSpec.getNodeSelector.isEmpty) From d8379e5bc3629f4e8233ad42831bdaf68c24cfeb Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 4 Apr 2018 15:43:58 -0700 Subject: [PATCH 562/774] [SPARK-23838][WEBUI] Running SQL query is displayed as "completed" in SQL tab ## What changes were proposed in this pull request? A running SQL query would appear as completed in the Spark UI: ![image1](https://user-images.githubusercontent.com/1097932/38170733-3d7cb00c-35bf-11e8-994c-43f2d4fa285d.png) We can see the query in "Completed queries", while in in the job page we see it's still running Job 132. ![image2](https://user-images.githubusercontent.com/1097932/38170735-48f2c714-35bf-11e8-8a41-6fae23543c46.png) After some time in the query still appears in "Completed queries" (while it's still running), but the "Duration" gets increased. ![image3](https://user-images.githubusercontent.com/1097932/38170737-50f87ea4-35bf-11e8-8b60-000f6f918964.png) To reproduce, we can run a query with multiple jobs. E.g. Run TPCDS q6. The reason is that updates from executions are written into kvstore periodically, and the job start event may be missed. ## How was this patch tested? Manually run the job again and check the SQL Tab. The fix is pretty simple. Author: Gengliang Wang Closes #20955 from gengliangwang/jobCompleted. --- .../apache/spark/sql/execution/ui/AllExecutionsPage.scala | 3 ++- .../spark/sql/execution/ui/SQLAppStatusListener.scala | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index e751ce39cd5d7..582528777f90e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -39,7 +39,8 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L val failed = new mutable.ArrayBuffer[SQLExecutionUIData]() sqlStore.executionsList().foreach { e => - val isRunning = e.jobs.exists { case (_, status) => status == JobExecutionStatus.RUNNING } + val isRunning = e.completionTime.isEmpty || + e.jobs.exists { case (_, status) => status == JobExecutionStatus.RUNNING } val isFailed = e.jobs.exists { case (_, status) => status == JobExecutionStatus.FAILED } if (isRunning) { running += e diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 71e9f93c4566e..2b6bb48467eb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -88,7 +88,7 @@ class SQLAppStatusListener( exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING) exec.stages ++= event.stageIds.toSet - update(exec) + update(exec, force = true) } override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { @@ -308,11 +308,13 @@ class SQLAppStatusListener( }) } - private def update(exec: LiveExecutionData): Unit = { + private def update(exec: LiveExecutionData, force: Boolean = false): Unit = { val now = System.nanoTime() if (exec.endEvents >= exec.jobs.size + 1) { exec.write(kvstore, now) liveExecutions.remove(exec.executionId) + } else if (force) { + exec.write(kvstore, now) } else if (liveUpdatePeriodNs >= 0) { if (now - exec.lastWriteTime > liveUpdatePeriodNs) { exec.write(kvstore, now) From d3bd0435ee4ff3d414f32cce3f58b6b9f67e68bc Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 4 Apr 2018 15:51:27 -0700 Subject: [PATCH 563/774] [SPARK-23637][YARN] Yarn might allocate more resource if a same executor is killed multiple times. ## What changes were proposed in this pull request? `YarnAllocator` uses `numExecutorsRunning` to track the number of running executor. `numExecutorsRunning` is used to check if there're executors missing and need to allocate more. In current code, `numExecutorsRunning` can be negative when driver asks to kill a same idle executor multiple times. ## How was this patch tested? UT added Author: jinxing Closes #20781 from jinxing64/SPARK-23637. --- .../spark/deploy/yarn/YarnAllocator.scala | 36 +++++++------- .../deploy/yarn/YarnAllocatorSuite.scala | 48 ++++++++++++++++++- 2 files changed, 66 insertions(+), 18 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index a537243d641cb..ebee3d431744d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -81,7 +81,8 @@ private[yarn] class YarnAllocator( private val releasedContainers = Collections.newSetFromMap[ContainerId]( new ConcurrentHashMap[ContainerId, java.lang.Boolean]) - private val numExecutorsRunning = new AtomicInteger(0) + private val runningExecutors = Collections.newSetFromMap[String]( + new ConcurrentHashMap[String, java.lang.Boolean]()) private val numExecutorsStarting = new AtomicInteger(0) @@ -166,7 +167,7 @@ private[yarn] class YarnAllocator( clock = newClock } - def getNumExecutorsRunning: Int = numExecutorsRunning.get() + def getNumExecutorsRunning: Int = runningExecutors.size() def getNumExecutorsFailed: Int = synchronized { val endTime = clock.getTimeMillis() @@ -242,12 +243,11 @@ private[yarn] class YarnAllocator( * Request that the ResourceManager release the container running the specified executor. */ def killExecutor(executorId: String): Unit = synchronized { - if (executorIdToContainer.contains(executorId)) { - val container = executorIdToContainer.get(executorId).get - internalReleaseContainer(container) - numExecutorsRunning.decrementAndGet() - } else { - logWarning(s"Attempted to kill unknown executor $executorId!") + executorIdToContainer.get(executorId) match { + case Some(container) if !releasedContainers.contains(container.getId) => + internalReleaseContainer(container) + runningExecutors.remove(executorId) + case _ => logWarning(s"Attempted to kill unknown executor $executorId!") } } @@ -274,7 +274,7 @@ private[yarn] class YarnAllocator( "Launching executor count: %d. Cluster resources: %s.") .format( allocatedContainers.size, - numExecutorsRunning.get, + runningExecutors.size, numExecutorsStarting.get, allocateResponse.getAvailableResources)) @@ -286,7 +286,7 @@ private[yarn] class YarnAllocator( logDebug("Completed %d containers".format(completedContainers.size)) processCompletedContainers(completedContainers.asScala) logDebug("Finished processing %d completed containers. Current running executor count: %d." - .format(completedContainers.size, numExecutorsRunning.get)) + .format(completedContainers.size, runningExecutors.size)) } } @@ -300,9 +300,9 @@ private[yarn] class YarnAllocator( val pendingAllocate = getPendingAllocate val numPendingAllocate = pendingAllocate.size val missing = targetNumExecutors - numPendingAllocate - - numExecutorsStarting.get - numExecutorsRunning.get + numExecutorsStarting.get - runningExecutors.size logDebug(s"Updating resource requests, target: $targetNumExecutors, " + - s"pending: $numPendingAllocate, running: ${numExecutorsRunning.get}, " + + s"pending: $numPendingAllocate, running: ${runningExecutors.size}, " + s"executorsStarting: ${numExecutorsStarting.get}") if (missing > 0) { @@ -502,7 +502,7 @@ private[yarn] class YarnAllocator( s"for executor with ID $executorId") def updateInternalState(): Unit = synchronized { - numExecutorsRunning.incrementAndGet() + runningExecutors.add(executorId) numExecutorsStarting.decrementAndGet() executorIdToContainer(executorId) = container containerIdToExecutorId(container.getId) = executorId @@ -513,7 +513,7 @@ private[yarn] class YarnAllocator( allocatedContainerToHostMap.put(containerId, executorHostname) } - if (numExecutorsRunning.get < targetNumExecutors) { + if (runningExecutors.size() < targetNumExecutors) { numExecutorsStarting.incrementAndGet() if (launchContainers) { launcherPool.execute(new Runnable { @@ -554,7 +554,7 @@ private[yarn] class YarnAllocator( } else { logInfo(("Skip launching executorRunnable as running executors count: %d " + "reached target executors count: %d.").format( - numExecutorsRunning.get, targetNumExecutors)) + runningExecutors.size, targetNumExecutors)) } } } @@ -569,7 +569,11 @@ private[yarn] class YarnAllocator( val exitReason = if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. - numExecutorsRunning.decrementAndGet() + containerIdToExecutorId.get(containerId) match { + case Some(executorId) => runningExecutors.remove(executorId) + case None => logWarning(s"Cannot find executorId for container: ${containerId.toString}") + } + logInfo("Completed container %s%s (state: %s, exit status: %s)".format( containerId, onHostStr, diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index cb1e3c5268510..525abb6f2b350 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -251,11 +251,55 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0) } handler.updateResourceRequests() - handler.processCompletedContainers(statuses.toSeq) + handler.processCompletedContainers(statuses) handler.getNumExecutorsRunning should be (0) handler.getPendingAllocate.size should be (1) } + test("kill same executor multiple times") { + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getPendingAllocate.size should be (2) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + handler.getNumExecutorsRunning should be (2) + handler.getPendingAllocate.size should be (0) + + val executorToKill = handler.executorIdToContainer.keys.head + handler.killExecutor(executorToKill) + handler.getNumExecutorsRunning should be (1) + handler.killExecutor(executorToKill) + handler.killExecutor(executorToKill) + handler.killExecutor(executorToKill) + handler.getNumExecutorsRunning should be (1) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty, Set.empty) + handler.updateResourceRequests() + handler.getPendingAllocate.size should be (1) + } + + test("process same completed container multiple times") { + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getPendingAllocate.size should be (2) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + handler.getNumExecutorsRunning should be (2) + handler.getPendingAllocate.size should be (0) + + val statuses = Seq(container1, container1, container2).map { c => + ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0) + } + handler.processCompletedContainers(statuses) + handler.getNumExecutorsRunning should be (0) + + } + test("lost executor removed from backend") { val handler = createAllocator(4) handler.updateResourceRequests() @@ -272,7 +316,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1) } handler.updateResourceRequests() - handler.processCompletedContainers(statuses.toSeq) + handler.processCompletedContainers(statuses) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) handler.getPendingAllocate.size should be (2) From c5c8b544047a83cb6128a20d31f1d943a15f9260 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Apr 2018 13:39:45 +0200 Subject: [PATCH 564/774] [SPARK-23593][SQL] Add interpreted execution for InitializeJavaBean expression ## What changes were proposed in this pull request? Add interpreted execution for `InitializeJavaBean` expression. ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20756 from viirya/SPARK-23593. --- .../expressions/objects/objects.scala | 47 ++++++++++++++++- .../expressions/ExpressionEvalHelper.scala | 9 ++-- .../expressions/ObjectExpressionsSuite.scala | 52 +++++++++++++++++++ 3 files changed, 103 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index a455c1c821a26..20c4f4c7324fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1410,8 +1410,47 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def children: Seq[Expression] = beanInstance +: setters.values.toSeq override def dataType: DataType = beanInstance.dataType - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + private lazy val resolvedSetters = { + assert(beanInstance.dataType.isInstanceOf[ObjectType]) + + val ObjectType(beanClass) = beanInstance.dataType + setters.map { + case (name, expr) => + // Looking for known type mapping. + // But also looking for general `Object`-type parameter for generic methods. + val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object]) + val methods = paramTypes.flatMap { fieldClass => + try { + Some(beanClass.getDeclaredMethod(name, fieldClass)) + } catch { + case e: NoSuchMethodException => None + } + } + if (methods.isEmpty) { + throw new NoSuchMethodException(s"""A method named "$name" is not declared """ + + "in any enclosing class nor any supertype") + } + methods.head -> expr + } + } + + override def eval(input: InternalRow): Any = { + val instance = beanInstance.eval(input) + if (instance != null) { + val bean = instance.asInstanceOf[Object] + resolvedSetters.foreach { + case (setter, expr) => + val paramVal = expr.eval(input) + if (paramVal == null) { + throw new NullPointerException("The parameter value for setters in " + + "`InitializeJavaBean` can not be null") + } else { + setter.invoke(bean, paramVal.asInstanceOf[AnyRef]) + } + } + } + instance + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) @@ -1424,6 +1463,10 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val fieldGen = fieldValue.genCode(ctx) s""" |${fieldGen.code} + |if (${fieldGen.isNull}) { + | throw new NullPointerException("The parameter value for setters in " + + | "`InitializeJavaBean` can not be null"); + |} |$javaBeanInstance.$setterMethod(${fieldGen.value}); """.stripMargin } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 3828f172a15cf..a5ecd1b68fac4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -55,7 +55,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val expr = prepareEvaluation(expression) + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -111,12 +112,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val errMsg = intercept[T] { eval }.getMessage - if (errMsg != expectedErrMsg) { + if (!errMsg.contains(expectedErrMsg)) { fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") } } } - val expr = prepareEvaluation(expression) + + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") if (GenerateUnsafeProjection.canSupport(expr.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 9bfe2916b0820..44fecd602e854 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -128,6 +128,50 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) } + test("SPARK-23593: InitializeJavaBean should support interpreted execution") { + val list = new java.util.LinkedList[Int]() + list.add(1) + + val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), + Map("add" -> Literal(1))) + checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq())) + + val initializeWithNonexistingMethod = InitializeJavaBean( + Literal.fromObject(new java.util.LinkedList[Int]), + Map("nonexisting" -> Literal(1))) + checkExceptionInExpression[Exception](initializeWithNonexistingMethod, + InternalRow.fromSeq(Seq()), + """A method named "nonexisting" is not declared in any enclosing class """ + + "nor any supertype") + + val initializeWithWrongParamType = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setX" -> Literal("1"))) + intercept[Exception] { + evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq())) + }.getMessage.contains( + """A method named "setX" is not declared in any enclosing class """ + + "nor any supertype") + } + + test("Can not pass in null into setters in InitializeJavaBean") { + val initializeBean = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal(null))) + intercept[NullPointerException] { + evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq())) + }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") + intercept[NullPointerException] { + evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) + }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") + + val initializeBean2 = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal("string"))) + evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq())) + evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) + } + test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -278,3 +322,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + +class TestBean extends Serializable { + private var x: Int = 0 + + def setX(i: Int): Unit = x = i + def setNonPrimitive(i: AnyRef): Unit = + assert(i != null, "this setter should not be called with null.") +} From 1822ecda51cc9e14bb18050e0b8c270fee47ced7 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 5 Apr 2018 13:47:06 +0200 Subject: [PATCH 565/774] [SPARK-23582][SQL] StaticInvoke should support interpreted execution ## What changes were proposed in this pull request? This pr added interpreted execution for `StaticInvoke`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Kazuaki Ishizaki Closes #20753 from kiszk/SPARK-23582. --- .../expressions/objects/objects.scala | 14 +++- .../expressions/ObjectExpressionsSuite.scala | 66 ++++++++++++++++++- 2 files changed, 77 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 20c4f4c7324fd..9ca0b6137679e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]]. @@ -217,12 +218,21 @@ case class StaticInvoke( returnNullable: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") + val cls = if (staticObject.getName == objectName) { + staticObject + } else { + Utils.classForName(objectName) + } override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) + @transient lazy val method = cls.getDeclaredMethod(functionName, argClasses : _*) + + override def eval(input: InternalRow): Any = { + invoke(null, method, arguments, input, dataType) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 44fecd602e854..eb89e01b5ff9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} + import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -28,9 +30,11 @@ import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class InvokeTargetClass extends Serializable { def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 @@ -93,6 +97,66 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed } + test("SPARK-23582: StaticInvoke should support interpreted execution") { + Seq((classOf[java.lang.Boolean], "true", true), + (classOf[java.lang.Byte], "1", 1.toByte), + (classOf[java.lang.Short], "257", 257.toShort), + (classOf[java.lang.Integer], "12345", 12345), + (classOf[java.lang.Long], "12345678", 12345678.toLong), + (classOf[java.lang.Float], "12.34", 12.34.toFloat), + (classOf[java.lang.Double], "1.2345678", 1.2345678) + ).foreach { case (cls, arg, expected) => + checkObjectExprEvaluation(StaticInvoke(cls, ObjectType(cls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[java.lang.String]), true))), + expected, InternalRow.fromSeq(Seq(arg))) + } + + // Return null when null argument is passed with propagateNull = true + val stringCls = classOf[java.lang.String] + checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = true), + null, InternalRow.fromSeq(Seq(null))) + checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = false), + "null", InternalRow.fromSeq(Seq(null))) + + // test no argument + val clCls = classOf[java.lang.ClassLoader] + checkObjectExprEvaluation(StaticInvoke(clCls, ObjectType(clCls), "getSystemClassLoader", Nil), + ClassLoader.getSystemClassLoader, InternalRow.empty) + // test more than one argument + val intCls = classOf[java.lang.Integer] + checkObjectExprEvaluation(StaticInvoke(intCls, ObjectType(intCls), "compare", + Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, false))), + 0, InternalRow.fromSeq(Seq(7, 7))) + + Seq((DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", ObjectType(classOf[Timestamp]), + new Timestamp(77777), DateTimeUtils.fromJavaTimestamp(new Timestamp(77777))), + (DateTimeUtils.getClass, DateType, "fromJavaDate", ObjectType(classOf[Date]), + new Date(88888888), DateTimeUtils.fromJavaDate(new Date(88888888))), + (classOf[UTF8String], StringType, "fromString", ObjectType(classOf[String]), + "abc", UTF8String.fromString("abc")), + (Decimal.getClass, DecimalType(38, 0), "fromDecimal", ObjectType(classOf[Any]), + BigInt(88888888), Decimal.fromDecimal(BigInt(88888888))), + (Decimal.getClass, DecimalType.SYSTEM_DEFAULT, + "apply", ObjectType(classOf[java.math.BigInteger]), + new java.math.BigInteger("88888888"), Decimal.apply(new java.math.BigInteger("88888888"))), + (classOf[ArrayData], ArrayType(IntegerType), "toArrayData", ObjectType(classOf[Any]), + Array[Int](1, 2, 3), ArrayData.toArrayData(Array[Int](1, 2, 3))), + (classOf[UnsafeArrayData], ArrayType(IntegerType, false), + "fromPrimitiveArray", ObjectType(classOf[Array[Int]]), + Array[Int](1, 2, 3), UnsafeArrayData.fromPrimitiveArray(Array[Int](1, 2, 3))), + (DateTimeUtils.getClass, ObjectType(classOf[Date]), + "toJavaDate", ObjectType(classOf[SQLDate]), 77777, DateTimeUtils.toJavaDate(77777)), + (DateTimeUtils.getClass, ObjectType(classOf[Timestamp]), + "toJavaTimestamp", ObjectType(classOf[SQLTimestamp]), + 88888888.toLong, DateTimeUtils.toJavaTimestamp(88888888)) + ).foreach { case (cls, dataType, methodName, argType, arg, expected) => + checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName, + Seq(BoundReference(0, argType, true))), expected, InternalRow.fromSeq(Seq(arg))) + } + } + test("SPARK-23583: Invoke should support interpreted execution") { val targetObject = new InvokeTargetClass val funcClass = classOf[InvokeTargetClass] From b2329fb1fcdc0e93c4bdc39d574cde7328ef6094 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 5 Apr 2018 13:57:41 +0200 Subject: [PATCH 566/774] Revert "[SPARK-23593][SQL] Add interpreted execution for InitializeJavaBean expression" This reverts commit c5c8b544047a83cb6128a20d31f1d943a15f9260. --- .../expressions/objects/objects.scala | 47 +---------------- .../expressions/ExpressionEvalHelper.scala | 9 ++-- .../expressions/ObjectExpressionsSuite.scala | 52 ------------------- 3 files changed, 5 insertions(+), 103 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9ca0b6137679e..3fa91bd36bb60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1420,47 +1420,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def children: Seq[Expression] = beanInstance +: setters.values.toSeq override def dataType: DataType = beanInstance.dataType - private lazy val resolvedSetters = { - assert(beanInstance.dataType.isInstanceOf[ObjectType]) - - val ObjectType(beanClass) = beanInstance.dataType - setters.map { - case (name, expr) => - // Looking for known type mapping. - // But also looking for general `Object`-type parameter for generic methods. - val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object]) - val methods = paramTypes.flatMap { fieldClass => - try { - Some(beanClass.getDeclaredMethod(name, fieldClass)) - } catch { - case e: NoSuchMethodException => None - } - } - if (methods.isEmpty) { - throw new NoSuchMethodException(s"""A method named "$name" is not declared """ + - "in any enclosing class nor any supertype") - } - methods.head -> expr - } - } - - override def eval(input: InternalRow): Any = { - val instance = beanInstance.eval(input) - if (instance != null) { - val bean = instance.asInstanceOf[Object] - resolvedSetters.foreach { - case (setter, expr) => - val paramVal = expr.eval(input) - if (paramVal == null) { - throw new NullPointerException("The parameter value for setters in " + - "`InitializeJavaBean` can not be null") - } else { - setter.invoke(bean, paramVal.asInstanceOf[AnyRef]) - } - } - } - instance - } + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) @@ -1473,10 +1434,6 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val fieldGen = fieldValue.genCode(ctx) s""" |${fieldGen.code} - |if (${fieldGen.isNull}) { - | throw new NullPointerException("The parameter value for setters in " + - | "`InitializeJavaBean` can not be null"); - |} |$javaBeanInstance.$setterMethod(${fieldGen.value}); """.stripMargin } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a5ecd1b68fac4..3828f172a15cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -55,8 +55,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - // Make it as method to obtain fresh expression everytime. - def expr = prepareEvaluation(expression) + val expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -112,14 +111,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val errMsg = intercept[T] { eval }.getMessage - if (!errMsg.contains(expectedErrMsg)) { + if (errMsg != expectedErrMsg) { fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") } } } - - // Make it as method to obtain fresh expression everytime. - def expr = prepareEvaluation(expression) + val expr = prepareEvaluation(expression) checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") if (GenerateUnsafeProjection.canSupport(expr.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index eb89e01b5ff9d..1d59b20077fa9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -192,50 +192,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) } - test("SPARK-23593: InitializeJavaBean should support interpreted execution") { - val list = new java.util.LinkedList[Int]() - list.add(1) - - val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), - Map("add" -> Literal(1))) - checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq())) - - val initializeWithNonexistingMethod = InitializeJavaBean( - Literal.fromObject(new java.util.LinkedList[Int]), - Map("nonexisting" -> Literal(1))) - checkExceptionInExpression[Exception](initializeWithNonexistingMethod, - InternalRow.fromSeq(Seq()), - """A method named "nonexisting" is not declared in any enclosing class """ + - "nor any supertype") - - val initializeWithWrongParamType = InitializeJavaBean( - Literal.fromObject(new TestBean), - Map("setX" -> Literal("1"))) - intercept[Exception] { - evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq())) - }.getMessage.contains( - """A method named "setX" is not declared in any enclosing class """ + - "nor any supertype") - } - - test("Can not pass in null into setters in InitializeJavaBean") { - val initializeBean = InitializeJavaBean( - Literal.fromObject(new TestBean), - Map("setNonPrimitive" -> Literal(null))) - intercept[NullPointerException] { - evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq())) - }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") - intercept[NullPointerException] { - evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) - }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") - - val initializeBean2 = InitializeJavaBean( - Literal.fromObject(new TestBean), - Map("setNonPrimitive" -> Literal("string"))) - evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq())) - evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) - } - test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -386,11 +342,3 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } - -class TestBean extends Serializable { - private var x: Int = 0 - - def setX(i: Int): Unit = x = i - def setNonPrimitive(i: AnyRef): Unit = - assert(i != null, "this setter should not be called with null.") -} From d9ca1c906bd0571802f2297c36b407e660fcdb64 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Apr 2018 20:43:05 +0200 Subject: [PATCH 567/774] [SPARK-23593][SQL] Add interpreted execution for InitializeJavaBean expression ## What changes were proposed in this pull request? Add interpreted execution for `InitializeJavaBean` expression. ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20985 from viirya/SPARK-23593-2. --- .../expressions/objects/objects.scala | 45 +++++++++++++++-- .../expressions/ExpressionEvalHelper.scala | 9 ++-- .../expressions/ObjectExpressionsSuite.scala | 48 +++++++++++++++++++ 3 files changed, 96 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 3fa91bd36bb60..9252425f86473 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1420,8 +1420,45 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def children: Seq[Expression] = beanInstance +: setters.values.toSeq override def dataType: DataType = beanInstance.dataType - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + private lazy val resolvedSetters = { + assert(beanInstance.dataType.isInstanceOf[ObjectType]) + + val ObjectType(beanClass) = beanInstance.dataType + setters.map { + case (name, expr) => + // Looking for known type mapping. + // But also looking for general `Object`-type parameter for generic methods. + val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object]) + val methods = paramTypes.flatMap { fieldClass => + try { + Some(beanClass.getDeclaredMethod(name, fieldClass)) + } catch { + case e: NoSuchMethodException => None + } + } + if (methods.isEmpty) { + throw new NoSuchMethodException(s"""A method named "$name" is not declared """ + + "in any enclosing class nor any supertype") + } + methods.head -> expr + } + } + + override def eval(input: InternalRow): Any = { + val instance = beanInstance.eval(input) + if (instance != null) { + val bean = instance.asInstanceOf[Object] + resolvedSetters.foreach { + case (setter, expr) => + val paramVal = expr.eval(input) + // We don't call setter if input value is null. + if (paramVal != null) { + setter.invoke(bean, paramVal.asInstanceOf[AnyRef]) + } + } + } + instance + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) @@ -1434,7 +1471,9 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val fieldGen = fieldValue.genCode(ctx) s""" |${fieldGen.code} - |$javaBeanInstance.$setterMethod(${fieldGen.value}); + |if (!${fieldGen.isNull}) { + | $javaBeanInstance.$setterMethod(${fieldGen.value}); + |} """.stripMargin } val initializeCode = ctx.splitExpressionsWithCurrentInputs( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 3828f172a15cf..a5ecd1b68fac4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -55,7 +55,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val expr = prepareEvaluation(expression) + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -111,12 +112,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val errMsg = intercept[T] { eval }.getMessage - if (errMsg != expectedErrMsg) { + if (!errMsg.contains(expectedErrMsg)) { fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") } } } - val expr = prepareEvaluation(expression) + + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") if (GenerateUnsafeProjection.canSupport(expr.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 1d59b20077fa9..b1bc67dfac1b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -192,6 +192,46 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) } + test("SPARK-23593: InitializeJavaBean should support interpreted execution") { + val list = new java.util.LinkedList[Int]() + list.add(1) + + val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), + Map("add" -> Literal(1))) + checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq())) + + val initializeWithNonexistingMethod = InitializeJavaBean( + Literal.fromObject(new java.util.LinkedList[Int]), + Map("nonexisting" -> Literal(1))) + checkExceptionInExpression[Exception](initializeWithNonexistingMethod, + InternalRow.fromSeq(Seq()), + """A method named "nonexisting" is not declared in any enclosing class """ + + "nor any supertype") + + val initializeWithWrongParamType = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setX" -> Literal("1"))) + intercept[Exception] { + evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq())) + }.getMessage.contains( + """A method named "setX" is not declared in any enclosing class """ + + "nor any supertype") + } + + test("InitializeJavaBean doesn't call setters if input in null") { + val initializeBean = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal(null))) + evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq())) + evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) + + val initializeBean2 = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal("string"))) + evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq())) + evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) + } + test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -342,3 +382,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + +class TestBean extends Serializable { + private var x: Int = 0 + + def setX(i: Int): Unit = x = i + def setNonPrimitive(i: AnyRef): Unit = + assert(i != null, "this setter should not be called with null.") +} From 4807d381bb113a5c61e6dad88202f23a8b6dd141 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 6 Apr 2018 10:13:59 +0800 Subject: [PATCH 568/774] [SPARK-10399][CORE][SQL] Introduce multiple MemoryBlocks to choose several types of memory block ## What changes were proposed in this pull request? This PR allows us to use one of several types of `MemoryBlock`, such as byte array, int array, long array, or `java.nio.DirectByteBuffer`. To use `java.nio.DirectByteBuffer` allows to have off heap memory which is automatically deallocated by JVM. `MemoryBlock` class has primitive accessors like `Platform.getInt()`, `Platform.putint()`, or `Platform.copyMemory()`. This PR uses `MemoryBlock` for `OffHeapColumnVector`, `UTF8String`, and other places. This PR can improve performance of operations involving memory accesses (e.g. `UTF8String.trim`) by 1.8x. For now, this PR does not use `MemoryBlock` for `BufferHolder` based on cloud-fan's [suggestion](https://github.com/apache/spark/pull/11494#issuecomment-309694290). Since this PR is a successor of #11494, close #11494. Many codes were ported from #11494. Many efforts were put here. **I think this PR should credit to yzotov.** This PR can achieve **1.1-1.4x performance improvements** for operations in `UTF8String` or `Murmur3_x86_32`. Other operations are almost comparable performances. Without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Hash byte arrays with length 268435487: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Murmur3_x86_32 526 / 536 0.0 131399881.5 1.0X UTF8String benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ hashCode 525 / 552 1022.6 1.0 1.0X substring 414 / 423 1298.0 0.8 1.3X ``` With this PR ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Hash byte arrays with length 268435487: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Murmur3_x86_32 474 / 488 0.0 118552232.0 1.0X UTF8String benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ hashCode 476 / 480 1127.3 0.9 1.0X substring 287 / 291 1869.9 0.5 1.7X ``` Benchmark program ``` test("benchmark Murmur3_x86_32") { val length = 8192 * 32768 + 31 val seed = 42L val iters = 1 << 2 val random = new Random(seed) val arrays = Array.fill[MemoryBlock](numArrays) { val bytes = new Array[Byte](length) random.nextBytes(bytes) new ByteArrayMemoryBlock(bytes, Platform.BYTE_ARRAY_OFFSET, length) } val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays, minNumIters = 20) benchmark.addCase("HiveHasher") { _: Int => var sum = 0L for (_ <- 0L until iters) { sum += HiveHasher.hashUnsafeBytesBlock( arrays(i), Platform.BYTE_ARRAY_OFFSET, length) } } benchmark.run() } test("benchmark UTF8String") { val N = 512 * 1024 * 1024 val iters = 2 val benchmark = new Benchmark("UTF8String benchmark", N, minNumIters = 20) val str0 = new java.io.StringWriter() { { for (i <- 0 until N) { write(" ") } } }.toString val s0 = UTF8String.fromString(str0) benchmark.addCase("hashCode") { _: Int => var h: Int = 0 for (_ <- 0L until iters) { h += s0.hashCode } } benchmark.addCase("substring") { _: Int => var s: UTF8String = null for (_ <- 0L until iters) { s = s0.substring(N / 2 - 5, N / 2 + 5) } } benchmark.run() } ``` I run [this benchmark program](https://gist.github.com/kiszk/94f75b506c93a663bbbc372ffe8f05de) using [the commit](https://github.com/apache/spark/pull/19222/commits/ee5a79861c18725fb1cd9b518cdfd2489c05b81d6). I got the following results: ``` OpenJDK 64-Bit Server VM 1.8.0_151-8u151-b12-0ubuntu0.16.04.2-b12 on Linux 4.4.0-66-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Memory access benchmarks: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ByteArrayMemoryBlock get/putInt() 220 / 221 609.3 1.6 1.0X Platform get/putInt(byte[]) 220 / 236 610.9 1.6 1.0X Platform get/putInt(Object) 492 / 494 272.8 3.7 0.4X OnHeapMemoryBlock get/putLong() 322 / 323 416.5 2.4 0.7X long[] 221 / 221 608.0 1.6 1.0X Platform get/putLong(long[]) 321 / 321 418.7 2.4 0.7X Platform get/putLong(Object) 561 / 563 239.2 4.2 0.4X ``` I also run [this benchmark program](https://gist.github.com/kiszk/5fdb4e03733a5d110421177e289d1fb5) for comparing performance of `Platform.copyMemory()`. ``` OpenJDK 64-Bit Server VM 1.8.0_151-8u151-b12-0ubuntu0.16.04.2-b12 on Linux 4.4.0-66-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Platform copyMemory: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Object to Object 1961 / 1967 8.6 116.9 1.0X System.arraycopy Object to Object 1917 / 1921 8.8 114.3 1.0X byte array to byte array 1961 / 1968 8.6 116.9 1.0X System.arraycopy byte array to byte array 1909 / 1937 8.8 113.8 1.0X int array to int array 1921 / 1990 8.7 114.5 1.0X double array to double array 1918 / 1923 8.7 114.3 1.0X Object to byte array 1961 / 1967 8.6 116.9 1.0X Object to short array 1965 / 1972 8.5 117.1 1.0X Object to int array 1910 / 1915 8.8 113.9 1.0X Object to float array 1971 / 1978 8.5 117.5 1.0X Object to double array 1919 / 1944 8.7 114.4 1.0X byte array to Object 1959 / 1967 8.6 116.8 1.0X int array to Object 1961 / 1970 8.6 116.9 1.0X double array to Object 1917 / 1924 8.8 114.3 1.0X ``` These results show three facts: 1. According to the second/third or sixth/seventh results in the first experiment, if we use `Platform.get/putInt(Object)`, we achieve more than 2x worse performance than `Platform.get/putInt(byte[])` with concrete type (i.e. `byte[]`). 2. According to the second/third or fourth/fifth/sixth results in the first experiment, the fastest way to access an array element on Java heap is `array[]`. **Cons of `array[]` is that it is not possible to support unaligned-8byte access.** 3. According to the first/second/third or fourth/sixth/seventh results in the first experiment, `getInt()/putInt() or getLong()/putLong()` in subclasses of `MemoryBlock` can achieve comparable performance to `Platform.get/putInt()` or `Platform.get/putLong()` with concrete type (second or sixth result). There is no overhead regarding virtual call. 4. According to results in the second experiment, for `Platform.copy()`, to pass `Object` can achieve the same performance as to pass any type of primitive array as source or destination. 5. According to second/fourth results in the second experiment, `Platform.copy()` can achieve the same performance as `System.arrayCopy`. **It would be good to use `Platform.copy()` since `Platform.copy()` can take any types for src and dst.** We are incrementally replace `Platform.get/putXXX` with `MemoryBlock.get/putXXX`. This is because we have two advantages. 1) Achieve better performance due to having a concrete type for an array. 2) Use simple OO design instead of passing `Object` It is easy to use `MemoryBlock` in `InternalRow`, `BufferHolder`, `TaskMemoryManager`, and others that are already abstracted. It is not easy to use `MemoryBlock` in utility classes related to hashing or others. Other candidates are - UnsafeRow, UnsafeArrayData, UnsafeMapData, SpecificUnsafeRowJoiner - UTF8StringBuffer - BufferHolder - TaskMemoryManager - OnHeapColumnVector - BytesToBytesMap - CachedBatch - classes for hash - others. ## How was this patch tested? Added `UnsafeMemoryAllocator` Author: Kazuaki Ishizaki Closes #19222 from kiszk/SPARK-10399. --- .../sql/catalyst/expressions/HiveHasher.java | 12 +- .../org/apache/spark/unsafe/Platform.java | 2 +- .../spark/unsafe/array/ByteArrayMethods.java | 13 +- .../apache/spark/unsafe/array/LongArray.java | 17 +- .../spark/unsafe/hash/Murmur3_x86_32.java | 45 +++-- .../unsafe/memory/ByteArrayMemoryBlock.java | 128 +++++++++++++ .../unsafe/memory/HeapMemoryAllocator.java | 19 +- .../spark/unsafe/memory/MemoryAllocator.java | 4 +- .../spark/unsafe/memory/MemoryBlock.java | 157 ++++++++++++++-- .../spark/unsafe/memory/MemoryLocation.java | 54 ------ .../unsafe/memory/OffHeapMemoryBlock.java | 105 +++++++++++ .../unsafe/memory/OnHeapMemoryBlock.java | 132 +++++++++++++ .../unsafe/memory/UnsafeMemoryAllocator.java | 21 ++- .../apache/spark/unsafe/types/UTF8String.java | 148 +++++++-------- .../spark/unsafe/PlatformUtilSuite.java | 4 +- .../spark/unsafe/array/LongArraySuite.java | 5 +- .../unsafe/hash/Murmur3_x86_32Suite.java | 18 ++ .../spark/unsafe/memory/MemoryBlockSuite.java | 175 ++++++++++++++++++ .../spark/unsafe/types/UTF8StringSuite.java | 29 +-- .../spark/memory/TaskMemoryManager.java | 22 +-- .../shuffle/sort/ShuffleInMemorySorter.java | 14 +- .../shuffle/sort/ShuffleSortDataFormat.java | 11 +- .../unsafe/sort/UnsafeExternalSorter.java | 2 +- .../unsafe/sort/UnsafeInMemorySorter.java | 13 +- .../spark/memory/TaskMemoryManagerSuite.java | 2 +- .../util/collection/ExternalSorterSuite.scala | 7 +- .../unsafe/sort/RadixSortSuite.scala | 10 +- .../spark/ml/feature/FeatureHasher.scala | 5 +- .../spark/mllib/feature/HashingTF.scala | 2 +- .../catalyst/expressions/UnsafeArrayData.java | 4 +- .../sql/catalyst/expressions/UnsafeRow.java | 4 +- .../spark/sql/catalyst/expressions/XXH64.java | 46 +++-- .../spark/sql/catalyst/expressions/hash.scala | 39 ++-- .../catalyst/expressions/HiveHasherSuite.java | 20 +- .../sql/catalyst/expressions/XXH64Suite.java | 18 +- .../vectorized/OffHeapColumnVector.java | 3 +- .../sql/vectorized/ArrowColumnVector.java | 6 +- .../execution/benchmark/SortBenchmark.scala | 16 +- .../sql/execution/python/RowQueueSuite.scala | 4 +- 39 files changed, 1002 insertions(+), 334 deletions(-) create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java delete mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java create mode 100644 common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index 73577437ac506..5d905943a3aa7 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; /** * Simulates Hive's hashing function from Hive v1.2.1 @@ -38,12 +39,17 @@ public static int hashLong(long input) { return (int) ((input >>> 32) ^ input); } - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { + public static int hashUnsafeBytesBlock(MemoryBlock mb) { + long lengthInBytes = mb.size(); assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int result = 0; - for (int i = 0; i < lengthInBytes; i++) { - result = (result * 31) + (int) Platform.getByte(base, offset + i); + for (long i = 0; i < lengthInBytes; i++) { + result = (result * 31) + (int) mb.getByte(i); } return result; } + + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes)); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index aca6fca00c48b..54dcadf3a7754 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -187,7 +187,7 @@ public static void setMemory(long address, byte value, long size) { } public static void copyMemory( - Object src, long srcOffset, Object dst, long dstOffset, long length) { + Object src, long srcOffset, Object dst, long dstOffset, long length) { // Check if dstOffset is before or after srcOffset to determine if we should copy // forward or backwards. This is necessary in case src and dst overlap. if (dstOffset < srcOffset) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index a6b1f7a16d605..c334c9651cf6b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe.array; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; public class ByteArrayMethods { @@ -48,6 +49,16 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { public static int MAX_ROUNDED_ARRAY_LENGTH = Integer.MAX_VALUE - 15; private static final boolean unaligned = Platform.unaligned(); + /** + * MemoryBlock equality check for MemoryBlocks. + * @return true if the arrays are equal, false otherwise + */ + public static boolean arrayEqualsBlock( + MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, final long length) { + return arrayEquals(leftBase.getBaseObject(), leftBase.getBaseOffset() + leftOffset, + rightBase.getBaseObject(), rightBase.getBaseOffset() + rightOffset, length); + } + /** * Optimized byte array equality check for byte arrays. * @return true if the arrays are equal, false otherwise @@ -56,7 +67,7 @@ public static boolean arrayEquals( Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { int i = 0; - // check if stars align and we can get both offsets to be aligned + // check if starts align and we can get both offsets to be aligned if ((leftOffset % 8) == (rightOffset % 8)) { while ((leftOffset + i) % 8 != 0 && i < length) { if (Platform.getByte(leftBase, leftOffset + i) != diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 2cd39bd60c2ac..b74d2de0691d5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -17,7 +17,6 @@ package org.apache.spark.unsafe.array; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** @@ -33,16 +32,12 @@ public final class LongArray { private static final long WIDTH = 8; private final MemoryBlock memory; - private final Object baseObj; - private final long baseOffset; private final long length; public LongArray(MemoryBlock memory) { assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size >= Integer.MAX_VALUE elements"; this.memory = memory; - this.baseObj = memory.getBaseObject(); - this.baseOffset = memory.getBaseOffset(); this.length = memory.size() / WIDTH; } @@ -51,11 +46,11 @@ public MemoryBlock memoryBlock() { } public Object getBaseObject() { - return baseObj; + return memory.getBaseObject(); } public long getBaseOffset() { - return baseOffset; + return memory.getBaseOffset(); } /** @@ -69,8 +64,8 @@ public long size() { * Fill this all with 0L. */ public void zeroOut() { - for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) { - Platform.putLong(baseObj, off, 0); + for (long off = 0; off < length * WIDTH; off += WIDTH) { + memory.putLong(off, 0); } } @@ -80,7 +75,7 @@ public void zeroOut() { public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - Platform.putLong(baseObj, baseOffset + index * WIDTH, value); + memory.putLong(index * WIDTH, value); } /** @@ -89,6 +84,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return Platform.getLong(baseObj, baseOffset + index * WIDTH); + return memory.getLong(index * WIDTH); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index d239de6083ad0..f372b19fac119 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -17,7 +17,9 @@ package org.apache.spark.unsafe.hash; -import org.apache.spark.unsafe.Platform; +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.memory.MemoryBlock; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -49,49 +51,66 @@ public static int hashInt(int input, int seed) { } public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { - return hashUnsafeWords(base, offset, lengthInBytes, seed); + return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); } - public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + public static int hashUnsafeWordsBlock(MemoryBlock base, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; - int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); + int h1 = hashBytesByIntBlock(base, seed); return fmix(h1, lengthInBytes); } - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + + public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) { // This is not compatible with original and another implementations. // But remain it for backward compatibility for the components existing before 2.3. + int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); for (int i = lengthAligned; i < lengthInBytes; i++) { - int halfWord = Platform.getByte(base, offset + i); + int halfWord = base.getByte(i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } return fmix(h1, lengthInBytes); } + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + return hashUnsafeBytes2Block(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + + public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) { // This is compatible with original and another implementations. // Use this method for new components after Spark 2.3. - assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthInBytes = Ints.checkedCast(base.size()); + assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); int k1 = 0; for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { - k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; + k1 ^= (base.getByte(i) & 0xFF) << shift; } h1 ^= mixK1(k1); return fmix(h1, lengthInBytes); } - private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { + private static int hashBytesByIntBlock(MemoryBlock base, int seed) { + long lengthInBytes = base.size(); assert (lengthInBytes % 4 == 0); int h1 = seed; - for (int i = 0; i < lengthInBytes; i += 4) { - int halfWord = Platform.getInt(base, offset + i); + for (long i = 0; i < lengthInBytes; i += 4) { + int halfWord = base.getInt(i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java new file mode 100644 index 0000000000000..99a9868a49a79 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory with a byte array on Java heap. + */ +public final class ByteArrayMemoryBlock extends MemoryBlock { + + private final byte[] array; + + public ByteArrayMemoryBlock(byte[] obj, long offset, long size) { + super(obj, offset, size); + this.array = obj; + assert(offset + size <= Platform.BYTE_ARRAY_OFFSET + obj.length) : + "The sum of size " + size + " and offset " + offset + " should not be larger than " + + "the size of the given memory space " + (obj.length + Platform.BYTE_ARRAY_OFFSET); + } + + public ByteArrayMemoryBlock(long length) { + this(new byte[Ints.checkedCast(length)], Platform.BYTE_ARRAY_OFFSET, length); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new ByteArrayMemoryBlock(array, this.offset + offset, size); + } + + public byte[] getByteArray() { return array; } + + /** + * Creates a memory block pointing to the memory used by the byte array. + */ + public static ByteArrayMemoryBlock fromArray(final byte[] array) { + return new ByteArrayMemoryBlock(array, Platform.BYTE_ARRAY_OFFSET, array.length); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(array, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(array, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(array, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(array, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)]; + } + + @Override + public final void putByte(long offset, byte value) { + array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)] = value; + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(array, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(array, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(array, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(array, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(array, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(array, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(array, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(array, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 2733760dd19ef..acf28fd7ee59b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -58,7 +58,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { final long[] array = arrayReference.get(); if (array != null) { assert (array.length * 8L >= size); - MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -70,7 +70,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { } } long[] array = new long[numWords]; - MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -79,12 +79,13 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert (memory.obj != null) : + assert(memory instanceof OnHeapMemoryBlock); + assert (memory.getBaseObject() != null) : "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?"; - assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; - assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) - || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) + || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " + "free()"; @@ -94,12 +95,12 @@ public void free(MemoryBlock memory) { } // Mark the page as freed (so we can detect double-frees). - memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; + memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to null out its reference to the long[] array. - long[] array = (long[]) memory.obj; - memory.setObjAndOffset(null, 0); + long[] array = ((OnHeapMemoryBlock)memory).getLongArray(); + memory.resetObjAndOffset(); long alignedSize = ((size + 7) / 8) * 8; if (shouldPool(alignedSize)) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java index 7b588681d9790..38315fb97b46a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -38,7 +38,7 @@ public interface MemoryAllocator { void free(MemoryBlock memory); - MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + UnsafeMemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); - MemoryAllocator HEAP = new HeapMemoryAllocator(); + HeapMemoryAllocator HEAP = new HeapMemoryAllocator(); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index c333857358d30..b086941108522 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -22,10 +22,10 @@ import org.apache.spark.unsafe.Platform; /** - * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + * A representation of a consecutive memory block in Spark. It defines the common interfaces + * for memory accessing and mutating. */ -public class MemoryBlock extends MemoryLocation { - +public abstract class MemoryBlock { /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */ public static final int NO_PAGE_NUMBER = -1; @@ -45,38 +45,163 @@ public class MemoryBlock extends MemoryLocation { */ public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3; - private final long length; + @Nullable + protected Object obj; + + protected long offset; + + protected long length; /** * Optional page number; used when this MemoryBlock represents a page allocated by a - * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, - * which lives in a different package. + * TaskMemoryManager. This field can be updated using setPageNumber method so that + * this can be modified by the TaskMemoryManager, which lives in a different package. */ - public int pageNumber = NO_PAGE_NUMBER; + private int pageNumber = NO_PAGE_NUMBER; - public MemoryBlock(@Nullable Object obj, long offset, long length) { - super(obj, offset); + protected MemoryBlock(@Nullable Object obj, long offset, long length) { + if (offset < 0 || length < 0) { + throw new IllegalArgumentException( + "Length " + length + " and offset " + offset + "must be non-negative"); + } + this.obj = obj; + this.offset = offset; this.length = length; } + protected MemoryBlock() { + this(null, 0, 0); + } + + public final Object getBaseObject() { + return obj; + } + + public final long getBaseOffset() { + return offset; + } + + public void resetObjAndOffset() { + this.obj = null; + this.offset = 0; + } + /** * Returns the size of the memory block. */ - public long size() { + public final long size() { return length; } - /** - * Creates a memory block pointing to the memory used by the long array. - */ - public static MemoryBlock fromLongArray(final long[] array) { - return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); + public final void setPageNumber(int pageNum) { + pageNumber = pageNum; + } + + public final int getPageNumber() { + return pageNumber; } /** * Fills the memory block with the specified byte value. */ - public void fill(byte value) { + public final void fill(byte value) { Platform.setMemory(obj, offset, length, value); } + + /** + * Instantiate MemoryBlock for given object type with new offset + */ + public final static MemoryBlock allocateFromObject(Object obj, long offset, long length) { + MemoryBlock mb = null; + if (obj instanceof byte[]) { + byte[] array = (byte[])obj; + mb = new ByteArrayMemoryBlock(array, offset, length); + } else if (obj instanceof long[]) { + long[] array = (long[])obj; + mb = new OnHeapMemoryBlock(array, offset, length); + } else if (obj == null) { + // we assume that to pass null pointer means off-heap + mb = new OffHeapMemoryBlock(offset, length); + } else { + throw new UnsupportedOperationException( + "Instantiate MemoryBlock for type " + obj.getClass() + " is not supported now"); + } + return mb; + } + + /** + * Just instantiate the sub-block with the same type of MemoryBlock with the new size and relative + * offset from the original offset. The data is not copied. + * If parameters are invalid, an exception is thrown. + */ + public abstract MemoryBlock subBlock(long offset, long size); + + protected void checkSubBlockRange(long offset, long size) { + if (offset < 0 || size < 0) { + throw new ArrayIndexOutOfBoundsException( + "Size " + size + " and offset " + offset + " must be non-negative"); + } + if (offset + size > length) { + throw new ArrayIndexOutOfBoundsException("The sum of size " + size + " and offset " + + offset + " should not be larger than the length " + length + " in the MemoryBlock"); + } + } + + /** + * getXXX/putXXX does not ensure guarantee behavior if the offset is invalid. e.g cause illegal + * memory access, throw an exception, or etc. + * getXXX/putXXX uses an index based on this.offset that includes the size of metadata such as + * JVM object header. The offset is 0-based and is expected as an logical offset in the memory + * block. + */ + public abstract int getInt(long offset); + + public abstract void putInt(long offset, int value); + + public abstract boolean getBoolean(long offset); + + public abstract void putBoolean(long offset, boolean value); + + public abstract byte getByte(long offset); + + public abstract void putByte(long offset, byte value); + + public abstract short getShort(long offset); + + public abstract void putShort(long offset, short value); + + public abstract long getLong(long offset); + + public abstract void putLong(long offset, long value); + + public abstract float getFloat(long offset); + + public abstract void putFloat(long offset, float value); + + public abstract double getDouble(long offset); + + public abstract void putDouble(long offset, double value); + + public static final void copyMemory( + MemoryBlock src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { + assert(srcOffset + length <= src.length && dstOffset + length <= dst.length); + Platform.copyMemory(src.getBaseObject(), src.getBaseOffset() + srcOffset, + dst.getBaseObject(), dst.getBaseOffset() + dstOffset, length); + } + + public static final void copyMemory(MemoryBlock src, MemoryBlock dst, long length) { + assert(length <= src.length && length <= dst.length); + Platform.copyMemory(src.getBaseObject(), src.getBaseOffset(), + dst.getBaseObject(), dst.getBaseOffset(), length); + } + + public final void copyFrom(Object src, long srcOffset, long dstOffset, long length) { + assert(length <= this.length - srcOffset); + Platform.copyMemory(src, srcOffset, obj, offset + dstOffset, length); + } + + public final void writeTo(long srcOffset, Object dst, long dstOffset, long length) { + assert(length <= this.length - srcOffset); + Platform.copyMemory(obj, offset + srcOffset, dst, dstOffset, length); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java deleted file mode 100644 index 74ebc87dc978c..0000000000000 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import javax.annotation.Nullable; - -/** - * A memory location. Tracked either by a memory address (with off-heap allocation), - * or by an offset from a JVM object (in-heap allocation). - */ -public class MemoryLocation { - - @Nullable - Object obj; - - long offset; - - public MemoryLocation(@Nullable Object obj, long offset) { - this.obj = obj; - this.offset = offset; - } - - public MemoryLocation() { - this(null, 0); - } - - public void setObjAndOffset(Object newObj, long newOffset) { - this.obj = newObj; - this.offset = newOffset; - } - - public final Object getBaseObject() { - return obj; - } - - public final long getBaseOffset() { - return offset; - } -} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java new file mode 100644 index 0000000000000..f90f62bf21dcb --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; + +public class OffHeapMemoryBlock extends MemoryBlock { + static public final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); + + public OffHeapMemoryBlock(long address, long size) { + super(null, address, size); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new OffHeapMemoryBlock(this.offset + offset, size); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(null, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(null, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(null, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(null, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return Platform.getByte(null, this.offset + offset); + } + + @Override + public final void putByte(long offset, byte value) { + Platform.putByte(null, this.offset + offset, value); + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(null, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(null, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(null, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(null, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(null, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(null, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(null, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(null, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java new file mode 100644 index 0000000000000..12f67c7bd593e --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory with a long array on Java heap. + */ +public final class OnHeapMemoryBlock extends MemoryBlock { + + private final long[] array; + + public OnHeapMemoryBlock(long[] obj, long offset, long size) { + super(obj, offset, size); + this.array = obj; + assert(offset + size <= obj.length * 8L + Platform.LONG_ARRAY_OFFSET) : + "The sum of size " + size + " and offset " + offset + " should not be larger than " + + "the size of the given memory space " + (obj.length * 8L + Platform.LONG_ARRAY_OFFSET); + } + + public OnHeapMemoryBlock(long size) { + this(new long[Ints.checkedCast((size + 7) / 8)], Platform.LONG_ARRAY_OFFSET, size); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new OnHeapMemoryBlock(array, this.offset + offset, size); + } + + public long[] getLongArray() { return array; } + + /** + * Creates a memory block pointing to the memory used by the long array. + */ + public static OnHeapMemoryBlock fromArray(final long[] array) { + return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); + } + + public static OnHeapMemoryBlock fromArray(final long[] array, long size) { + return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(array, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(array, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(array, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(array, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return Platform.getByte(array, this.offset + offset); + } + + @Override + public final void putByte(long offset, byte value) { + Platform.putByte(array, this.offset + offset, value); + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(array, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(array, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(array, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(array, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(array, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(array, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(array, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(array, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 4368fb615ba1e..5310bdf2779a9 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -25,9 +25,9 @@ public class UnsafeMemoryAllocator implements MemoryAllocator { @Override - public MemoryBlock allocate(long size) throws OutOfMemoryError { + public OffHeapMemoryBlock allocate(long size) throws OutOfMemoryError { long address = Platform.allocateMemory(size); - MemoryBlock memory = new MemoryBlock(null, address, size); + OffHeapMemoryBlock memory = new OffHeapMemoryBlock(address, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -36,22 +36,25 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert (memory.obj == null) : - "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; - assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert(memory instanceof OffHeapMemoryBlock) : + "UnsafeMemoryAllocator can only free OffHeapMemoryBlock."; + if (memory == OffHeapMemoryBlock.NULL) return; + assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; - assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) - || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) + || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()"; if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } + Platform.freeMemory(memory.offset); + // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to reset its pointer. - memory.offset = 0; + memory.resetObjAndOffset(); // Mark the page as freed (so we can detect double-frees). - memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; + memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 5d468aed42337..e9b3d9b045af5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -30,9 +30,12 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import com.google.common.primitives.Ints; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import static org.apache.spark.unsafe.Platform.*; @@ -50,12 +53,13 @@ public final class UTF8String implements Comparable, Externalizable, // These are only updated by readExternal() or read() @Nonnull - private Object base; - private long offset; + private MemoryBlock base; + // While numBytes has the same value as base.size(), to keep as int avoids cast from long to int private int numBytes; - public Object getBaseObject() { return base; } - public long getBaseOffset() { return offset; } + public MemoryBlock getMemoryBlock() { return base; } + public Object getBaseObject() { return base.getBaseObject(); } + public long getBaseOffset() { return base.getBaseOffset(); } /** * A char in UTF-8 encoding can take 1-4 bytes depending on the first byte which @@ -108,7 +112,8 @@ public final class UTF8String implements Comparable, Externalizable, */ public static UTF8String fromBytes(byte[] bytes) { if (bytes != null) { - return new UTF8String(bytes, BYTE_ARRAY_OFFSET, bytes.length); + return new UTF8String( + new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET, bytes.length)); } else { return null; } @@ -121,19 +126,13 @@ public static UTF8String fromBytes(byte[] bytes) { */ public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { if (bytes != null) { - return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes); + return new UTF8String( + new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET + offset, numBytes)); } else { return null; } } - /** - * Creates an UTF8String from given address (base and offset) and length. - */ - public static UTF8String fromAddress(Object base, long offset, int numBytes) { - return new UTF8String(base, offset, numBytes); - } - /** * Creates an UTF8String from String. */ @@ -150,16 +149,13 @@ public static UTF8String blankString(int length) { return fromBytes(spaces); } - protected UTF8String(Object base, long offset, int numBytes) { + public UTF8String(MemoryBlock base) { this.base = base; - this.offset = offset; - this.numBytes = numBytes; + this.numBytes = Ints.checkedCast(base.size()); } // for serialization - public UTF8String() { - this(null, 0, 0); - } + public UTF8String() {} /** * Writes the content of this string into a memory address, identified by an object and an offset. @@ -167,7 +163,7 @@ public UTF8String() { * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - Platform.copyMemory(base, offset, target, targetOffset, numBytes); + base.writeTo(0, target, targetOffset, numBytes); } public void writeTo(ByteBuffer buffer) { @@ -187,8 +183,9 @@ public void writeTo(ByteBuffer buffer) { */ @Nonnull public ByteBuffer getByteBuffer() { - if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) { - final byte[] bytes = (byte[]) base; + long offset = base.getBaseOffset(); + if (base instanceof ByteArrayMemoryBlock && offset >= BYTE_ARRAY_OFFSET) { + final byte[] bytes = ((ByteArrayMemoryBlock) base).getByteArray(); // the offset includes an object header... this is only needed for unsafe copies final long arrayOffset = offset - BYTE_ARRAY_OFFSET; @@ -255,12 +252,12 @@ public long getPrefix() { long mask = 0; if (IS_LITTLE_ENDIAN) { if (numBytes >= 8) { - p = Platform.getLong(base, offset); + p = base.getLong(0); } else if (numBytes > 4) { - p = Platform.getLong(base, offset); + p = base.getLong(0); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = (long) Platform.getInt(base, offset); + p = (long) base.getInt(0); mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -269,12 +266,12 @@ public long getPrefix() { } else { // byteOrder == ByteOrder.BIG_ENDIAN if (numBytes >= 8) { - p = Platform.getLong(base, offset); + p = base.getLong(0); } else if (numBytes > 4) { - p = Platform.getLong(base, offset); + p = base.getLong(0); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = ((long) Platform.getInt(base, offset)) << 32; + p = ((long) base.getInt(0)) << 32; mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -289,12 +286,13 @@ public long getPrefix() { */ public byte[] getBytes() { // avoid copy if `base` is `byte[]` - if (offset == BYTE_ARRAY_OFFSET && base instanceof byte[] - && ((byte[]) base).length == numBytes) { - return (byte[]) base; + long offset = base.getBaseOffset(); + if (offset == BYTE_ARRAY_OFFSET && base instanceof ByteArrayMemoryBlock + && (((ByteArrayMemoryBlock) base).getByteArray()).length == numBytes) { + return ((ByteArrayMemoryBlock) base).getByteArray(); } else { byte[] bytes = new byte[numBytes]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); return bytes; } } @@ -324,7 +322,7 @@ public UTF8String substring(final int start, final int until) { if (i > j) { byte[] bytes = new byte[i - j]; - copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); + base.writeTo(j, bytes, BYTE_ARRAY_OFFSET, i - j); return fromBytes(bytes); } else { return EMPTY_UTF8; @@ -365,14 +363,14 @@ public boolean contains(final UTF8String substring) { * Returns the byte at position `i`. */ private byte getByte(int i) { - return Platform.getByte(base, offset + i); + return base.getByte(i); } private boolean matchAt(final UTF8String s, int pos) { if (s.numBytes + pos > numBytes || pos < 0) { return false; } - return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes); + return ByteArrayMethods.arrayEqualsBlock(base, pos, s.base, 0, s.numBytes); } public boolean startsWith(final UTF8String prefix) { @@ -499,8 +497,7 @@ public int findInSet(UTF8String match) { for (int i = 0; i < numBytes; i++) { if (getByte(i) == (byte) ',') { if (i - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, - match.numBytes)) { + ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { return n; } lastComma = i; @@ -508,8 +505,7 @@ public int findInSet(UTF8String match) { } } if (numBytes - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, - match.numBytes)) { + ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { return n; } return 0; @@ -524,7 +520,7 @@ public int findInSet(UTF8String match) { private UTF8String copyUTF8String(int start, int end) { int len = end - start + 1; byte[] newBytes = new byte[len]; - copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); + base.writeTo(start, newBytes, BYTE_ARRAY_OFFSET, len); return UTF8String.fromBytes(newBytes); } @@ -671,8 +667,7 @@ public UTF8String reverse() { int i = 0; // position in byte while (i < numBytes) { int len = numBytesForFirstByte(getByte(i)); - copyMemory(this.base, this.offset + i, result, - BYTE_ARRAY_OFFSET + result.length - i - len, len); + base.writeTo(i, result, BYTE_ARRAY_OFFSET + result.length - i - len, len); i += len; } @@ -686,7 +681,7 @@ public UTF8String repeat(int times) { } byte[] newBytes = new byte[numBytes * times]; - copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, newBytes, BYTE_ARRAY_OFFSET, numBytes); int copied = 1; while (copied < times) { @@ -723,7 +718,7 @@ public int indexOf(UTF8String v, int start) { if (i + v.numBytes > numBytes) { return -1; } - if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, i, v.base, 0, v.numBytes)) { return c; } i += numBytesForFirstByte(getByte(i)); @@ -739,7 +734,7 @@ public int indexOf(UTF8String v, int start) { private int find(UTF8String str, int start) { assert (str.numBytes > 0); while (start <= numBytes - str.numBytes) { - if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { return start; } start += 1; @@ -753,7 +748,7 @@ private int find(UTF8String str, int start) { private int rfind(UTF8String str, int start) { assert (str.numBytes > 0); while (start >= 0) { - if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { return start; } start -= 1; @@ -786,7 +781,7 @@ public UTF8String subStringIndex(UTF8String delim, int count) { return EMPTY_UTF8; } byte[] bytes = new byte[idx]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, idx); return fromBytes(bytes); } else { @@ -806,7 +801,7 @@ public UTF8String subStringIndex(UTF8String delim, int count) { } int size = numBytes - delim.numBytes - idx; byte[] bytes = new byte[size]; - copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); + base.writeTo(idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); return fromBytes(bytes); } } @@ -829,15 +824,15 @@ public UTF8String rpad(int len, UTF8String pad) { UTF8String remain = pad.substring(0, spaces - padChars * count); byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; - copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes); + base.writeTo(0, data, BYTE_ARRAY_OFFSET, this.numBytes); int offset = this.numBytes; int idx = 0; while (idx < count) { - copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); return UTF8String.fromBytes(data); } @@ -865,13 +860,13 @@ public UTF8String lpad(int len, UTF8String pad) { int offset = 0; int idx = 0; while (idx < count) { - copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); offset += remain.numBytes; - copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes()); + base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, numBytes()); return UTF8String.fromBytes(data); } @@ -896,8 +891,8 @@ public static UTF8String concat(UTF8String... inputs) { int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; - copyMemory( - inputs[i].base, inputs[i].offset, + inputs[i].base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -936,8 +931,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { for (int i = 0, j = 0; i < inputs.length; i++) { if (inputs[i] != null) { int len = inputs[i].numBytes; - copyMemory( - inputs[i].base, inputs[i].offset, + inputs[i].base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -945,8 +940,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { j++; // Add separator if this is not the last input. if (j < numInputs) { - copyMemory( - separator.base, separator.offset, + separator.base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, separator.numBytes); offset += separator.numBytes; @@ -1220,7 +1215,7 @@ public UTF8String clone() { public UTF8String copy() { byte[] bytes = new byte[numBytes]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); return fromBytes(bytes); } @@ -1228,11 +1223,10 @@ public UTF8String copy() { public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); int wordMax = (len / 8) * 8; - long roffset = other.offset; - Object rbase = other.base; + MemoryBlock rbase = other.base; for (int i = 0; i < wordMax; i += 8) { - long left = getLong(base, offset + i); - long right = getLong(rbase, roffset + i); + long left = base.getLong(i); + long right = rbase.getLong(i); if (left != right) { if (IS_LITTLE_ENDIAN) { return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right)); @@ -1243,7 +1237,7 @@ public int compareTo(@Nonnull final UTF8String other) { } for (int i = wordMax; i < len; i++) { // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. - int res = (getByte(i) & 0xFF) - (Platform.getByte(rbase, roffset + i) & 0xFF); + int res = (getByte(i) & 0xFF) - (rbase.getByte(i) & 0xFF); if (res != 0) { return res; } @@ -1262,7 +1256,7 @@ public boolean equals(final Object other) { if (numBytes != o.numBytes) { return false; } - return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes); + return ByteArrayMethods.arrayEqualsBlock(base, 0, o.base, 0, numBytes); } else { return false; } @@ -1318,8 +1312,8 @@ public int levenshteinDistance(UTF8String other) { num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) { cost = 1; } else { - cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base, - s.offset + i_bytes, num_bytes_j)) ? 0 : 1; + cost = (ByteArrayMethods.arrayEqualsBlock(t.base, j_bytes, s.base, + i_bytes, num_bytes_j)) ? 0 : 1; } d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); } @@ -1334,7 +1328,7 @@ public int levenshteinDistance(UTF8String other) { @Override public int hashCode() { - return Murmur3_x86_32.hashUnsafeBytes(base, offset, numBytes, 42); + return Murmur3_x86_32.hashUnsafeBytesBlock(base,42); } /** @@ -1397,10 +1391,10 @@ public void writeExternal(ObjectOutput out) throws IOException { } public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - offset = BYTE_ARRAY_OFFSET; numBytes = in.readInt(); - base = new byte[numBytes]; - in.readFully((byte[]) base); + byte[] bytes = new byte[numBytes]; + in.readFully(bytes); + base = ByteArrayMemoryBlock.fromArray(bytes); } @Override @@ -1412,10 +1406,10 @@ public void write(Kryo kryo, Output out) { @Override public void read(Kryo kryo, Input in) { - this.offset = BYTE_ARRAY_OFFSET; - this.numBytes = in.readInt(); - this.base = new byte[numBytes]; - in.read((byte[]) base); + numBytes = in.readInt(); + byte[] bytes = new byte[numBytes]; + in.read(bytes); + base = ByteArrayMemoryBlock.fromArray(bytes); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 3ad9ac7b4de9c..583a148b3845d 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -81,7 +81,7 @@ public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() { MemoryAllocator.HEAP.free(block); Assert.assertNull(block.getBaseObject()); Assert.assertEquals(0, block.getBaseOffset()); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); } @Test @@ -92,7 +92,7 @@ public void freeingOffHeapMemoryBlockResetsOffset() { MemoryAllocator.UNSAFE.free(block); Assert.assertNull(block.getBaseObject()); Assert.assertEquals(0, block.getBaseOffset()); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); } @Test(expected = AssertionError.class) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java index fb8e53b3348f3..8c2e98c2bfc54 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -20,14 +20,13 @@ import org.junit.Assert; import org.junit.Test; -import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; public class LongArraySuite { @Test public void basicTest() { - long[] bytes = new long[2]; - LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes)); + LongArray arr = new LongArray(new OnHeapMemoryBlock(16)); arr.set(0, 1L); arr.set(1, 2L); arr.set(1, 3L); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index 6348a73bf3895..d7ed005db1891 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -70,6 +70,24 @@ public void testKnownBytesInputs() { Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0)); } + @Test + public void testKnownWordsInputs() { + byte[] bytes = new byte[16]; + long offset = Platform.BYTE_ARRAY_OFFSET; + for (int i = 0; i < 16; i++) { + bytes[i] = 0; + } + Assert.assertEquals(-300363099, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + for (int i = 0; i < 16; i++) { + bytes[i] = -1; + } + Assert.assertEquals(-1210324667, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + for (int i = 0; i < 16; i++) { + bytes[i] = (byte)i; + } + Assert.assertEquals(-634919701, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + } + @Test public void randomizedStressTest() { int size = 65536; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java new file mode 100644 index 0000000000000..47f05c928f2e5 --- /dev/null +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.ByteOrder; + +import static org.hamcrest.core.StringContains.containsString; + +public class MemoryBlockSuite { + private static final boolean bigEndianPlatform = + ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + + private void check(MemoryBlock memory, Object obj, long offset, int length) { + memory.setPageNumber(1); + memory.fill((byte)-1); + memory.putBoolean(0, true); + memory.putByte(1, (byte)127); + memory.putShort(2, (short)257); + memory.putInt(4, 0x20000002); + memory.putLong(8, 0x1234567089ABCDEFL); + memory.putFloat(16, 1.0F); + memory.putLong(20, 0x1234567089ABCDEFL); + memory.putDouble(28, 2.0); + MemoryBlock.copyMemory(memory, 0L, memory, 36, 4); + int[] a = new int[2]; + a[0] = 0x12345678; + a[1] = 0x13579BDF; + memory.copyFrom(a, Platform.INT_ARRAY_OFFSET, 40, 8); + byte[] b = new byte[8]; + memory.writeTo(40, b, Platform.BYTE_ARRAY_OFFSET, 8); + + Assert.assertEquals(obj, memory.getBaseObject()); + Assert.assertEquals(offset, memory.getBaseOffset()); + Assert.assertEquals(length, memory.size()); + Assert.assertEquals(1, memory.getPageNumber()); + Assert.assertEquals(true, memory.getBoolean(0)); + Assert.assertEquals((byte)127, memory.getByte(1 )); + Assert.assertEquals((short)257, memory.getShort(2)); + Assert.assertEquals(0x20000002, memory.getInt(4)); + Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(8)); + Assert.assertEquals(1.0F, memory.getFloat(16), 0); + Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(20)); + Assert.assertEquals(2.0, memory.getDouble(28), 0); + Assert.assertEquals(true, memory.getBoolean(36)); + Assert.assertEquals((byte)127, memory.getByte(37 )); + Assert.assertEquals((short)257, memory.getShort(38)); + Assert.assertEquals(a[0], memory.getInt(40)); + Assert.assertEquals(a[1], memory.getInt(44)); + if (bigEndianPlatform) { + Assert.assertEquals(a[0], + ((int)b[0] & 0xff) << 24 | ((int)b[1] & 0xff) << 16 | + ((int)b[2] & 0xff) << 8 | ((int)b[3] & 0xff)); + Assert.assertEquals(a[1], + ((int)b[4] & 0xff) << 24 | ((int)b[5] & 0xff) << 16 | + ((int)b[6] & 0xff) << 8 | ((int)b[7] & 0xff)); + } else { + Assert.assertEquals(a[0], + ((int)b[3] & 0xff) << 24 | ((int)b[2] & 0xff) << 16 | + ((int)b[1] & 0xff) << 8 | ((int)b[0] & 0xff)); + Assert.assertEquals(a[1], + ((int)b[7] & 0xff) << 24 | ((int)b[6] & 0xff) << 16 | + ((int)b[5] & 0xff) << 8 | ((int)b[4] & 0xff)); + } + for (int i = 48; i < memory.size(); i++) { + Assert.assertEquals((byte) -1, memory.getByte(i)); + } + + assert(memory.subBlock(0, memory.size()) == memory); + + try { + memory.subBlock(-8, 8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("non-negative")); + } + + try { + memory.subBlock(0, -8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("non-negative")); + } + + try { + memory.subBlock(0, length + 8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); + } + + try { + memory.subBlock(8, length - 4); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); + } + + try { + memory.subBlock(length + 8, 4); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); + } + } + + @Test + public void ByteArrayMemoryBlockTest() { + byte[] obj = new byte[56]; + long offset = Platform.BYTE_ARRAY_OFFSET; + int length = obj.length; + + MemoryBlock memory = new ByteArrayMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + + memory = ByteArrayMemoryBlock.fromArray(obj); + check(memory, obj, offset, length); + + obj = new byte[112]; + memory = new ByteArrayMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + } + + @Test + public void OnHeapMemoryBlockTest() { + long[] obj = new long[7]; + long offset = Platform.LONG_ARRAY_OFFSET; + int length = obj.length * 8; + + MemoryBlock memory = new OnHeapMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + + memory = OnHeapMemoryBlock.fromArray(obj); + check(memory, obj, offset, length); + + obj = new long[14]; + memory = new OnHeapMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + } + + @Test + public void OffHeapArrayMemoryBlockTest() { + MemoryAllocator memoryAllocator = new UnsafeMemoryAllocator(); + MemoryBlock memory = memoryAllocator.allocate(56); + Object obj = memory.getBaseObject(); + long offset = memory.getBaseOffset(); + int length = 56; + + check(memory, obj, offset, length); + + long address = Platform.allocateMemory(112); + memory = new OffHeapMemoryBlock(address, length); + obj = memory.getBaseObject(); + offset = memory.getBaseOffset(); + check(memory, obj, offset, length); + } +} diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 7c34d419574ef..bad908fcaf136 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -26,6 +26,9 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; import org.junit.Test; import static org.junit.Assert.*; @@ -519,7 +522,8 @@ public void writeToOutputStreamUnderflow() throws IOException { final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { - UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i) + new UTF8String( + new ByteArrayMemoryBlock(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)) .writeTo(outputStream); final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); @@ -534,7 +538,7 @@ public void writeToOutputStreamSlice() throws IOException { for (int i = 0; i < test.length; ++i) { for (int j = 0; j < test.length - i; ++j) { - UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET + i, j) + new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(i, j)) .writeTo(outputStream); assertArrayEquals(Arrays.copyOfRange(test, i, i + j), outputStream.toByteArray()); @@ -565,7 +569,7 @@ public void writeToOutputStreamOverflow() throws IOException { for (final long offset : offsets) { try { - fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length) + new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(offset, test.length)) .writeTo(outputStream); throw new IllegalStateException(Long.toString(offset)); @@ -592,26 +596,25 @@ public void writeToOutputStream() throws IOException { } @Test - public void writeToOutputStreamIntArray() throws IOException { + public void writeToOutputStreamLongArray() throws IOException { // verify that writes work on objects that are not byte arrays - final ByteBuffer buffer = StandardCharsets.UTF_8.encode("大千世界"); + final ByteBuffer buffer = StandardCharsets.UTF_8.encode("3千大千世界"); buffer.position(0); buffer.order(ByteOrder.nativeOrder()); final int length = buffer.limit(); - assertEquals(12, length); + assertEquals(16, length); - final int ints = length / 4; - final int[] array = new int[ints]; + final int longs = length / 8; + final long[] array = new long[longs]; - for (int i = 0; i < ints; ++i) { - array[i] = buffer.getInt(); + for (int i = 0; i < longs; ++i) { + array[i] = buffer.getLong(); } final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - fromAddress(array, Platform.INT_ARRAY_OFFSET, length) - .writeTo(outputStream); - assertEquals("大千世界", outputStream.toString("UTF-8")); + new UTF8String(OnHeapMemoryBlock.fromArray(array)).writeTo(outputStream); + assertEquals("3千大千世界", outputStream.toString("UTF-8")); } @Test diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index d07faf1da1248..8651a639c07f7 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -311,7 +311,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { // this could trigger spilling to free some pages. return allocatePage(size, consumer); } - page.pageNumber = pageNumber; + page.setPageNumber(pageNumber); pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired); @@ -323,25 +323,25 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ public void freePage(MemoryBlock page, MemoryConsumer consumer) { - assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) : + assert (page.getPageNumber() != MemoryBlock.NO_PAGE_NUMBER) : "Called freePage() on memory that wasn't allocated with allocatePage()"; - assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (page.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "Called freePage() on a memory block that has already been freed"; - assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : + assert (page.getPageNumber() != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : "Called freePage() on a memory block that has already been freed"; - assert(allocatedPages.get(page.pageNumber)); - pageTable[page.pageNumber] = null; + assert(allocatedPages.get(page.getPageNumber())); + pageTable[page.getPageNumber()] = null; synchronized (this) { - allocatedPages.clear(page.pageNumber); + allocatedPages.clear(page.getPageNumber()); } if (logger.isTraceEnabled()) { - logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + logger.trace("Freed page number {} ({} bytes)", page.getPageNumber(), page.size()); } long pageSize = page.size(); // Clear the page number before passing the block to the MemoryAllocator's free(). // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed // page has been inappropriately directly freed without calling TMM.freePage(). - page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER); memoryManager.tungstenMemoryAllocator().free(page); releaseExecutionMemory(pageSize, consumer); } @@ -363,7 +363,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { // relative to the page's base offset; this relative offset will fit in 51 bits. offsetInPage -= page.getBaseOffset(); } - return encodePageNumberAndOffset(page.pageNumber, offsetInPage); + return encodePageNumberAndOffset(page.getPageNumber(), offsetInPage); } @VisibleForTesting @@ -434,7 +434,7 @@ public long cleanUpAllAllocatedMemory() { for (MemoryBlock page : pageTable) { if (page != null) { logger.debug("unreleased page: " + page + " in task " + taskAttemptId); - page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER); memoryManager.tungstenMemoryAllocator().free(page); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index dc36809d8911f..8f49859746b89 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -20,7 +20,6 @@ import java.util.Comparator; import org.apache.spark.memory.MemoryConsumer; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.Sorter; @@ -105,13 +104,7 @@ public void reset() { public void expandPointerArray(LongArray newArray) { assert(newArray.size() > array.size()); - Platform.copyMemory( - array.getBaseObject(), - array.getBaseOffset(), - newArray.getBaseObject(), - newArray.getBaseOffset(), - pos * 8L - ); + MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L); consumer.freeArray(array); array = newArray; usableCapacity = getUsableCapacity(); @@ -180,10 +173,7 @@ public ShuffleSorterIterator getSortedIterator() { PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX, PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false); } else { - MemoryBlock unused = new MemoryBlock( - array.getBaseObject(), - array.getBaseOffset() + pos * 8L, - (array.size() - pos) * 8L); + MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L); LongArray buffer = new LongArray(unused); Sorter sorter = new Sorter<>(new ShuffleSortDataFormat(buffer)); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index 717bdd79d47ef..254449e95443e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -17,8 +17,8 @@ package org.apache.spark.shuffle.sort; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; final class ShuffleSortDataFormat extends SortDataFormat { @@ -60,13 +60,8 @@ public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { @Override public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { - Platform.copyMemory( - src.getBaseObject(), - src.getBaseOffset() + srcPos * 8L, - dst.getBaseObject(), - dst.getBaseOffset() + dstPos * 8L, - length * 8L - ); + MemoryBlock.copyMemory(src.memoryBlock(), srcPos * 8L, + dst.memoryBlock(),dstPos * 8L,length * 8L); } @Override diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 66118f454159b..4fc19b1721518 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -544,7 +544,7 @@ public long spill() throws IOException { // is accessing the current record. We free this page in that caller's next loadNext() // call. for (MemoryBlock page : allocatedPages) { - if (!loaded || page.pageNumber != + if (!loaded || page.getPageNumber() != ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) { released += page.size(); freePage(page); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index b3c27d83da172..20a7a8b267438 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -26,7 +26,6 @@ import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -216,12 +215,7 @@ public void expandPointerArray(LongArray newArray) { if (newArray.size() < array.size()) { throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); } - Platform.copyMemory( - array.getBaseObject(), - array.getBaseOffset(), - newArray.getBaseObject(), - newArray.getBaseOffset(), - pos * 8L); + MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L); consumer.freeArray(array); array = newArray; usableCapacity = getUsableCapacity(); @@ -348,10 +342,7 @@ public UnsafeSorterIterator getSortedIterator() { array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { - MemoryBlock unused = new MemoryBlock( - array.getBaseObject(), - array.getBaseOffset() + pos * 8L, - (array.size() - pos) * 8L); + MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L); LongArray buffer = new LongArray(unused); Sorter sorter = new Sorter<>(new UnsafeSortDataFormat(buffer)); diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index a0664b30d6cc2..d7d2d0b012bd3 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -76,7 +76,7 @@ public void freeingPageSetsPageNumberToSpecialConstant() { final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); final MemoryBlock dataPage = manager.allocatePage(256, c); c.freePage(dataPage); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.getPageNumber()); } @Test(expected = AssertionError.class) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 47173b89e91e2..3e56db5ea116a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark._ import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordPointerAndKeyPrefix, UnsafeSortDataFormat} class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { @@ -105,9 +105,8 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { // the form [150000000, 150000001, 150000002, ...., 300000000, 0, 1, 2, ..., 149999999] // that can trigger copyRange() in TimSort.mergeLo() or TimSort.mergeHi() val ref = Array.tabulate[Long](size) { i => if (i < size / 2) size / 2 + i else i } - val buf = new LongArray(MemoryBlock.fromLongArray(ref)) - val tmp = new Array[Long](size/2) - val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(ref)) + val tmpBuf = new LongArray(new OnHeapMemoryBlock((size/2) * 8L)) new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort( buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] { diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index d5956ea32096a..ddf3740e76a7a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -27,7 +27,7 @@ import com.google.common.primitives.Ints import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.collection.Sorter import org.apache.spark.util.random.XORShiftRandom @@ -78,14 +78,14 @@ class RadixSortSuite extends SparkFunSuite with Logging { private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) - (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) + (ref.map(i => new JLong(i)), new LongArray(OnHeapMemoryBlock.fromArray(extended))) } private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0) - (new LongArray(MemoryBlock.fromLongArray(ref)), - new LongArray(MemoryBlock.fromLongArray(extended))) + (new LongArray(OnHeapMemoryBlock.fromArray(ref)), + new LongArray(OnHeapMemoryBlock.fromArray(extended))) } private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = { @@ -110,7 +110,7 @@ class RadixSortSuite extends SparkFunSuite with Logging { } private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) { - val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) + val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L)) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] { override def compare( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index c78f61ac3ef71..d67e4819b161a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -29,7 +29,7 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2} +import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2Block} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashMap @@ -243,8 +243,7 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] { case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case s: String => - val utf8 = UTF8String.fromString(s) - hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) + hashUnsafeBytes2Block(UTF8String.fromString(s).getMemoryBlock, seed) case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " + s"support type ${term.getClass.getCanonicalName} of input data.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 8935c8496cdbb..7b73b286fb91c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -160,7 +160,7 @@ object HashingTF { case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case s: String => val utf8 = UTF8String.fromString(s) - hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) + hashUnsafeBytesBlock(utf8.getMemoryBlock(), seed) case _ => throw new SparkException("HashingTF with murmur3 algorithm does not " + s"support type ${term.getClass.getCanonicalName} of input data.") } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index d18542b188f71..8546c28335536 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -27,6 +27,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -230,7 +231,8 @@ public UTF8String getUTF8String(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - return UTF8String.fromAddress(baseObject, baseOffset + offset, size); + MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size); + return new UTF8String(mb); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 71c086029cc5b..29a1411241cf6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -37,6 +37,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -414,7 +415,8 @@ public UTF8String getUTF8String(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - return UTF8String.fromAddress(baseObject, baseOffset + offset, size); + MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size); + return new UTF8String(mb); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index f37ef83ad92b4..883748932ad33 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -16,7 +16,10 @@ */ package org.apache.spark.sql.catalyst.expressions; +import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; // scalastyle: off /** @@ -71,13 +74,13 @@ public static long hashLong(long input, long seed) { return fmix(hash); } - public long hashUnsafeWords(Object base, long offset, int length) { - return hashUnsafeWords(base, offset, length, seed); + public long hashUnsafeWordsBlock(MemoryBlock mb) { + return hashUnsafeWordsBlock(mb, seed); } - public static long hashUnsafeWords(Object base, long offset, int length, long seed) { - assert (length % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; - long hash = hashBytesByWords(base, offset, length, seed); + public static long hashUnsafeWordsBlock(MemoryBlock mb, long seed) { + assert (mb.size() % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; + long hash = hashBytesByWordsBlock(mb, seed); return fmix(hash); } @@ -85,26 +88,32 @@ public long hashUnsafeBytes(Object base, long offset, int length) { return hashUnsafeBytes(base, offset, length, seed); } - public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { + public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) { + long offset = 0; + long length = mb.size(); assert (length >= 0) : "lengthInBytes cannot be negative"; - long hash = hashBytesByWords(base, offset, length, seed); + long hash = hashBytesByWordsBlock(mb, seed); long end = offset + length; offset += length & -8; if (offset + 4L <= end) { - hash ^= (Platform.getInt(base, offset) & 0xFFFFFFFFL) * PRIME64_1; + hash ^= (mb.getInt(offset) & 0xFFFFFFFFL) * PRIME64_1; hash = Long.rotateLeft(hash, 23) * PRIME64_2 + PRIME64_3; offset += 4L; } while (offset < end) { - hash ^= (Platform.getByte(base, offset) & 0xFFL) * PRIME64_5; + hash ^= (mb.getByte(offset) & 0xFFL) * PRIME64_5; hash = Long.rotateLeft(hash, 11) * PRIME64_1; offset++; } return fmix(hash); } + public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, length), seed); + } + private static long fmix(long hash) { hash ^= hash >>> 33; hash *= PRIME64_2; @@ -114,30 +123,31 @@ private static long fmix(long hash) { return hash; } - private static long hashBytesByWords(Object base, long offset, int length, long seed) { - long end = offset + length; + private static long hashBytesByWordsBlock(MemoryBlock mb, long seed) { + long offset = 0; + long length = mb.size(); long hash; if (length >= 32) { - long limit = end - 32; + long limit = length - 32; long v1 = seed + PRIME64_1 + PRIME64_2; long v2 = seed + PRIME64_2; long v3 = seed; long v4 = seed - PRIME64_1; do { - v1 += Platform.getLong(base, offset) * PRIME64_2; + v1 += mb.getLong(offset) * PRIME64_2; v1 = Long.rotateLeft(v1, 31); v1 *= PRIME64_1; - v2 += Platform.getLong(base, offset + 8) * PRIME64_2; + v2 += mb.getLong(offset + 8) * PRIME64_2; v2 = Long.rotateLeft(v2, 31); v2 *= PRIME64_1; - v3 += Platform.getLong(base, offset + 16) * PRIME64_2; + v3 += mb.getLong(offset + 16) * PRIME64_2; v3 = Long.rotateLeft(v3, 31); v3 *= PRIME64_1; - v4 += Platform.getLong(base, offset + 24) * PRIME64_2; + v4 += mb.getLong(offset + 24) * PRIME64_2; v4 = Long.rotateLeft(v4, 31); v4 *= PRIME64_1; @@ -178,9 +188,9 @@ private static long hashBytesByWords(Object base, long offset, int length, long hash += length; - long limit = end - 8; + long limit = length - 8; while (offset <= limit) { - long k1 = Platform.getLong(base, offset); + long k1 = mb.getLong(offset); hash ^= Long.rotateLeft(k1 * PRIME64_2, 31) * PRIME64_1; hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4; offset += 8L; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index b702422ed7a1d..b76b64ab5096f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -360,10 +361,8 @@ abstract class HashExpression[E] extends Expression { } protected def genHashString(input: String, result: String): String = { - val baseObject = s"$input.getBaseObject()" - val baseOffset = s"$input.getBaseOffset()" - val numBytes = s"$input.numBytes()" - s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" + val mb = s"$input.getMemoryBlock()" + s"$result = $hasherClassName.hashUnsafeBytesBlock($mb, $result);" } protected def genHashForMap( @@ -465,6 +464,8 @@ abstract class InterpretedHashFunction { protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long + protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long + /** * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity * of input `value`. @@ -490,8 +491,7 @@ abstract class InterpretedHashFunction { case c: CalendarInterval => hashInt(c.months, hashLong(c.microseconds, seed)) case a: Array[Byte] => hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) - case s: UTF8String => - hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) + case s: UTF8String => hashUnsafeBytesBlock(s.getMemoryBlock(), seed) case array: ArrayData => val elementType = dataType match { @@ -578,9 +578,15 @@ object Murmur3HashFunction extends InterpretedHashFunction { Murmur3_x86_32.hashLong(l, seed.toInt) } - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt) } + + override protected def hashUnsafeBytesBlock( + base: MemoryBlock, seed: Long): Long = { + Murmur3_x86_32.hashUnsafeBytesBlock(base, seed.toInt) + } } /** @@ -605,9 +611,14 @@ object XxHash64Function extends InterpretedHashFunction { override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed) - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { XXH64.hashUnsafeBytes(base, offset, len, seed) } + + override protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long = { + XXH64.hashUnsafeBytesBlock(base, seed) + } } /** @@ -714,10 +725,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """ override protected def genHashString(input: String, result: String): String = { - val baseObject = s"$input.getBaseObject()" - val baseOffset = s"$input.getBaseOffset()" - val numBytes = s"$input.numBytes()" - s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);" + val mb = s"$input.getMemoryBlock()" + s"$result = $hasherClassName.hashUnsafeBytesBlock($mb);" } override protected def genHashForArray( @@ -805,10 +814,14 @@ object HiveHashFunction extends InterpretedHashFunction { HiveHasher.hashLong(l) } - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { HiveHasher.hashUnsafeBytes(base, offset, len) } + override protected def hashUnsafeBytesBlock( + base: MemoryBlock, seed: Long): Long = HiveHasher.hashUnsafeBytesBlock(base) + private val HIVE_DECIMAL_MAX_PRECISION = 38 private val HIVE_DECIMAL_MAX_SCALE = 38 diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java index b67c6f3e6e85e..8ffc1d7c24d61 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; import org.junit.Assert; import org.junit.Test; @@ -53,7 +55,7 @@ public void testKnownStringAndIntInputs() { for (int i = 0; i < inputs.length; i++) { UTF8String s = UTF8String.fromString("val_" + inputs[i]); - int hash = HiveHasher.hashUnsafeBytes(s.getBaseObject(), s.getBaseOffset(), s.numBytes()); + int hash = HiveHasher.hashUnsafeBytesBlock(s.getMemoryBlock()); Assert.assertEquals(expected[i], ((31 * inputs[i]) + hash)); } } @@ -89,13 +91,13 @@ public void randomizedStressTestBytes() { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes); Assert.assertEquals( - HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + HiveHasher.hashUnsafeBytesBlock(mb), + HiveHasher.hashUnsafeBytesBlock(mb)); - hashcodes.add(HiveHasher.hashUnsafeBytes( - bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb)); } // A very loose bound. @@ -112,13 +114,13 @@ public void randomizedStressTestPaddedStrings() { byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes); Assert.assertEquals( - HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + HiveHasher.hashUnsafeBytesBlock(mb), + HiveHasher.hashUnsafeBytesBlock(mb)); - hashcodes.add(HiveHasher.hashUnsafeBytes( - paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb)); } // A very loose bound. diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java index 1baee91b3439c..cd8bce623c5df 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java @@ -24,6 +24,8 @@ import java.util.Set; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.junit.Assert; import org.junit.Test; @@ -142,13 +144,13 @@ public void randomizedStressTestBytes() { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes); Assert.assertEquals( - hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWordsBlock(mb), + hasher.hashUnsafeWordsBlock(mb)); - hashcodes.add(hasher.hashUnsafeWords( - bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(hasher.hashUnsafeWordsBlock(mb)); } // A very loose bound. @@ -165,13 +167,13 @@ public void randomizedStressTestPaddedStrings() { byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes); Assert.assertEquals( - hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWordsBlock(mb), + hasher.hashUnsafeWordsBlock(mb)); - hashcodes.add(hasher.hashUnsafeWords( - paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(hasher.hashUnsafeWordsBlock(mb)); } // A very loose bound. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 754c26579ff08..4733f36174f42 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -206,7 +207,7 @@ public byte[] getBytes(int rowId, int count) { @Override protected UTF8String getBytesAsUTF8String(int rowId, int count) { - return UTF8String.fromAddress(null, data + rowId, count); + return new UTF8String(new OffHeapMemoryBlock(data + rowId, count)); } // diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index f8e37e995a17f..227a16f7e69e9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.arrow.ArrowUtils; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -377,9 +378,10 @@ final UTF8String getUTF8String(int rowId) { if (stringResult.isSet == 0) { return null; } else { - return UTF8String.fromAddress(null, + return new UTF8String(new OffHeapMemoryBlock( stringResult.buffer.memoryAddress() + stringResult.start, - stringResult.end - stringResult.start); + stringResult.end - stringResult.start + )); } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala index 50ae26a3ff9d9..470b93efd1974 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.benchmark import java.util.{Arrays, Comparator} import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.Sorter import org.apache.spark.util.collection.unsafe.sort._ @@ -36,7 +36,7 @@ import org.apache.spark.util.random.XORShiftRandom class SortBenchmark extends BenchmarkBase { private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { - val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) + val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L)) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { override def compare( @@ -50,8 +50,8 @@ class SortBenchmark extends BenchmarkBase { private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = { val ref = Array.tabulate[Long](size * 2) { i => rand } val extended = ref ++ Array.fill[Long](size * 2)(0) - (new LongArray(MemoryBlock.fromLongArray(ref)), - new LongArray(MemoryBlock.fromLongArray(extended))) + (new LongArray(OnHeapMemoryBlock.fromArray(ref)), + new LongArray(OnHeapMemoryBlock.fromArray(extended))) } ignore("sort") { @@ -60,7 +60,7 @@ class SortBenchmark extends BenchmarkBase { val benchmark = new Benchmark("radix sort " + size, size) benchmark.addTimerCase("reference TimSort key prefix array") { timer => val array = Array.tabulate[Long](size * 2) { i => rand.nextLong } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() referenceKeyPrefixSort(buf, 0, size, PrefixComparators.BINARY) timer.stopTiming() @@ -78,7 +78,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong & 0xff i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() @@ -90,7 +90,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong & 0xffff i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() @@ -102,7 +102,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala index ffda33cf906c5..25ee95daa034c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala @@ -22,13 +22,13 @@ import java.io.File import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.Utils class RowQueueSuite extends SparkFunSuite { test("in-memory queue") { - val page = MemoryBlock.fromLongArray(new Array[Long](1<<10)) + val page = new OnHeapMemoryBlock((1<<10) * 8L) val queue = new InMemoryRowQueue(page, 1) { override def close() {} } From f2ac0879561cde63ed4eb759f5efa0a5ce393a22 Mon Sep 17 00:00:00 2001 From: Yogesh Garg Date: Thu, 5 Apr 2018 19:55:42 -0700 Subject: [PATCH 569/774] [SPARK-23870][ML] Forward RFormula handleInvalid Param to VectorAssembler to handle invalid values in non-string columns ## What changes were proposed in this pull request? `handleInvalid` Param was forwarded to the VectorAssembler used by RFormula. ## How was this patch tested? added a test and ran all tests for RFormula and VectorAssembler Author: Yogesh Garg Closes #20970 from yogeshg/spark_23562. --- .../apache/spark/ml/feature/RFormula.scala | 1 + .../spark/ml/feature/RFormulaSuite.scala | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 22e7b8bbf1ff5..e214765e3307f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -278,6 +278,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) encoderStages += new VectorAssembler(uid) .setInputCols(encodedTerms.toArray) .setOutputCol($(featuresCol)) + .setHandleInvalid($(handleInvalid)) encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap) encoderStages += new ColumnPruner(tempColumns.toSet) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 27d570f0b68ad..a250331efeb1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite @@ -592,4 +593,26 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { assert(features.toArray === a +: b.toArray) } } + + test("SPARK-23562 RFormula handleInvalid should handle invalid values in non-string columns.") { + val d1 = Seq( + (1001L, "a"), + (1002L, "b")).toDF("id1", "c1") + val d2 = Seq[(java.lang.Long, String)]( + (20001L, "x"), + (20002L, "y"), + (null, null)).toDF("id2", "c2") + val dataset = d1.crossJoin(d2) + + def get_output(mode: String): DataFrame = { + val formula = new RFormula().setFormula("c1 ~ id2").setHandleInvalid(mode) + formula.fit(dataset).transform(dataset).select("features", "label") + } + + assert(intercept[SparkException](get_output("error").collect()) + .getMessage.contains("Encountered null while assembling a row")) + assert(get_output("skip").count() == 4) + assert(get_output("keep").count() == 6) + } + } From d65e531b44a388fed25d3cbf28fdce5a2d0598e6 Mon Sep 17 00:00:00 2001 From: JiahuiJiang Date: Thu, 5 Apr 2018 20:06:08 -0700 Subject: [PATCH 570/774] [SPARK-23823][SQL] Keep origin in transformExpression Fixes https://issues.apache.org/jira/browse/SPARK-23823 Keep origin for all the methods using transformExpression ## What changes were proposed in this pull request? Keep origin in transformExpression ## How was this patch tested? Manually tested that this fixes https://issues.apache.org/jira/browse/SPARK-23823 and columns have correct origins after Analyzer.analyze Author: JiahuiJiang Author: Jiahui Jiang Closes #20961 from JiahuiJiang/jj/keep-origin. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 6 ++- .../sql/catalyst/plans/QueryPlanSuite.scala | 42 +++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index ddf2cbf2ab911..64cb8c726772f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} @@ -103,7 +103,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT var changed = false @inline def transformExpression(e: Expression): Expression = { - val newE = f(e) + val newE = CurrentOrigin.withOrigin(e.origin) { + f(e) + } if (newE.fastEquals(e)) { e } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala new file mode 100644 index 0000000000000..27914ef5565c0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.types.IntegerType + +class QueryPlanSuite extends SparkFunSuite { + + test("origin remains the same after mapExpressions (SPARK-23823)") { + CurrentOrigin.setPosition(0, 0) + val column = AttributeReference("column", IntegerType)(NamedExpression.newExprId) + val query = plans.DslLogicalPlan(plans.table("table")).select(column) + CurrentOrigin.reset() + + val mappedQuery = query mapExpressions { + case _: Expression => Literal(1) + } + + val mappedOrigin = mappedQuery.expressions.apply(0).origin + assert(mappedOrigin == Origin.apply(Some(0), Some(0))) + } + +} From 249007e37f51f00d14e596692aeac87fbc10b520 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 5 Apr 2018 20:19:25 -0700 Subject: [PATCH 571/774] [SPARK-19724][SQL] create a managed table with an existed default table should throw an exception ## What changes were proposed in this pull request? This PR is to finish https://github.com/apache/spark/pull/17272 This JIRA is a follow up work after SPARK-19583 As we discussed in that PR The following DDL for a managed table with an existed default location should throw an exception: CREATE TABLE ... (PARTITIONED BY ...) AS SELECT ... CREATE TABLE ... (PARTITIONED BY ...) Currently there are some situations which are not consist with above logic: CREATE TABLE ... (PARTITIONED BY ...) succeed with an existed default location situation: for both hive/datasource(with HiveExternalCatalog/InMemoryCatalog) CREATE TABLE ... (PARTITIONED BY ...) AS SELECT ... situation: hive table succeed with an existed default location This PR is going to make above two situations consist with the logic that it should throw an exception with an existed default location. ## How was this patch tested? unit test added Author: Gengliang Wang Closes #20886 from gengliangwang/pr-17272. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/catalog/SessionCatalog.scala | 23 ++++++- .../apache/spark/sql/internal/SQLConf.scala | 11 +++ .../command/createDataSourceTables.scala | 5 +- .../spark/sql/StatisticsCollectionSuite.scala | 7 ++ .../sql/execution/command/DDLSuite.scala | 67 +++++++++++++++++++ 6 files changed, 110 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2b393f30d1435..9822d669050d5 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1809,6 +1809,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 64e7ca11270b4..52ed89ef8d781 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -289,6 +289,7 @@ class SessionCatalog( def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableDefinition.identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) validateName(table) val newTableDefinition = if (tableDefinition.storage.locationUri.isDefined @@ -298,15 +299,33 @@ class SessionCatalog( makeQualifiedPath(tableDefinition.storage.locationUri.get) tableDefinition.copy( storage = tableDefinition.storage.copy(locationUri = Some(qualifiedTableLocation)), - identifier = TableIdentifier(table, Some(db))) + identifier = tableIdentifier) } else { - tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + tableDefinition.copy(identifier = tableIdentifier) } requireDbExists(db) + if (!ignoreIfExists) { + validateTableLocation(newTableDefinition) + } externalCatalog.createTable(newTableDefinition, ignoreIfExists) } + def validateTableLocation(table: CatalogTable): Unit = { + // SPARK-19724: the default location of a managed table should be non-existent or empty. + if (table.tableType == CatalogTableType.MANAGED && + !conf.allowCreatingManagedTableUsingNonemptyLocation) { + val tableLocation = + new Path(table.storage.locationUri.getOrElse(defaultTablePath(table.identifier))) + val fs = tableLocation.getFileSystem(hadoopConf) + + if (fs.exists(tableLocation) && fs.listStatus(tableLocation).nonEmpty) { + throw new AnalysisException(s"Can not create the managed table('${table.identifier}')" + + s". The associated location('${tableLocation.toString}') already exists.") + } + } + } + /** * Alter the metadata of an existing metastore table identified by `tableDefinition`. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 13f31a6b2eb93..1c8ab9c62623e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1159,6 +1159,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION = + buildConf("spark.sql.allowCreatingManagedTableUsingNonemptyLocation") + .internal() + .doc("When this option is set to true, creating managed tables with nonempty location " + + "is allowed. Otherwise, an analysis exception is thrown. ") + .booleanConf + .createWithDefault(false) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -1581,6 +1589,9 @@ class SQLConf extends Serializable with Logging { def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + def allowCreatingManagedTableUsingNonemptyLocation: Boolean = + getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION) + def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index e9747769dfcfc..f7c3e9b019258 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -167,7 +167,7 @@ case class CreateDataSourceTableAsSelectCommand( sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true) } else { assert(table.schema.isEmpty) - + sparkSession.sessionState.catalog.validateTableLocation(table) val tableLocation = if (table.tableType == CatalogTableType.MANAGED) { Some(sessionState.catalog.defaultTablePath(table.identifier)) } else { @@ -181,7 +181,8 @@ case class CreateDataSourceTableAsSelectCommand( // the schema of df). It is important since the nullability may be changed by the relation // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). schema = result.schema) - sessionState.catalog.createTable(newTable, ignoreIfExists = false) + // Table location is already validated. No need to check it again during table creation. + sessionState.catalog.createTable(newTable, ignoreIfExists = true) result match { case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index ed4ea0231f1a7..14a565863d66c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.io.File + import scala.collection.mutable import org.apache.spark.sql.catalyst.TableIdentifier @@ -26,6 +28,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.ArrayData import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** @@ -242,6 +245,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("change stats after set location command") { val table = "change_stats_set_location_table" + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier(table))) Seq(false, true).foreach { autoUpdate => withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { @@ -269,6 +273,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(fetched3.get.sizeInBytes == fetched1.get.sizeInBytes) } else { checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + // SPARK-19724: clean up the previous table location. + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4df8fbfe1c0db..4304d0b6f6b16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -180,6 +180,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { private val escapedIdentifier = "`(.+)`".r + private def dataSource: String = { + if (isUsingHiveMetastore) { + "HIVE" + } else { + "PARQUET" + } + } protected def normalizeCatalogTable(table: CatalogTable): CatalogTable = table private def normalizeSerdeProp(props: Map[String, String]): Map[String, String] = { @@ -365,6 +372,66 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("CTAS a managed table with the existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + tableLoc.mkdir() + withTable("tab1") { + sql(s"CREATE TABLE tab1 USING ${dataSource} AS SELECT 1, 'a'") + checkAnswer(spark.table("tab1"), Row(1, "a")) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + + test("create a managed table with the existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + tableLoc.mkdir() + withTable("tab1") { + sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") + sql("INSERT INTO tab1 VALUES (1, 'a')") + checkAnswer(spark.table("tab1"), Row(1, "a")) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + + test("create a managed table with the existing non-empty directory") { + withTable("tab1") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + // create an empty hidden file + tableLoc.mkdir() + val hiddenGarbageFile = new File(tableLoc.getCanonicalPath, ".garbage") + hiddenGarbageFile.createNewFile() + val exMsg = "Can not create the managed table('`tab1`'). The associated location" + val exMsgWithDefaultDB = + "Can not create the managed table('`default`.`tab1`'). The associated location" + var ex = intercept[AnalysisException] { + sql(s"CREATE TABLE tab1 USING ${dataSource} AS SELECT 1, 'a'") + }.getMessage + if (isUsingHiveMetastore) { + assert(ex.contains(exMsgWithDefaultDB)) + } else { + assert(ex.contains(exMsg)) + } + + ex = intercept[AnalysisException] { + sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") + }.getMessage + assert(ex.contains(exMsgWithDefaultDB)) + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + } + private def checkSchemaInCreatedDataSourceTable( path: File, userSpecifiedSchema: Option[String], From 6ade5cbb498f6c6ea38779b97f2325d5cf5013f2 Mon Sep 17 00:00:00 2001 From: Daniel Sakuma Date: Fri, 6 Apr 2018 13:37:08 +0800 Subject: [PATCH 572/774] [MINOR][DOC] Fix some typos and grammar issues ## What changes were proposed in this pull request? Easy fix in the documentation. ## How was this patch tested? N/A Closes #20948 Author: Daniel Sakuma Closes #20928 from dsakuma/fix_typo_configuration_docs. --- docs/README.md | 2 +- docs/_plugins/include_example.rb | 2 +- docs/building-spark.md | 2 +- docs/cloud-integration.md | 4 +-- docs/configuration.md | 20 ++++++------ docs/css/pygments-default.css | 2 +- docs/graphx-programming-guide.md | 4 +-- docs/job-scheduling.md | 4 +-- docs/ml-advanced.md | 2 +- docs/ml-classification-regression.md | 6 ++-- docs/ml-collaborative-filtering.md | 2 +- docs/ml-features.md | 2 +- docs/ml-migration-guides.md | 2 +- docs/ml-tuning.md | 2 +- docs/mllib-clustering.md | 2 +- docs/mllib-collaborative-filtering.md | 4 +-- docs/mllib-data-types.md | 2 +- docs/mllib-dimensionality-reduction.md | 2 +- docs/mllib-evaluation-metrics.md | 2 +- docs/mllib-feature-extraction.md | 2 +- docs/mllib-isotonic-regression.md | 4 +-- docs/mllib-linear-methods.md | 2 +- docs/mllib-optimization.md | 4 +-- docs/monitoring.md | 4 +-- docs/quick-start.md | 6 ++-- docs/rdd-programming-guide.md | 2 +- docs/running-on-kubernetes.md | 4 +-- docs/running-on-mesos.md | 12 +++---- docs/running-on-yarn.md | 2 +- docs/security.md | 2 +- docs/spark-standalone.md | 2 +- docs/sparkr.md | 6 ++-- docs/sql-programming-guide.md | 32 +++++++++---------- docs/storage-openstack-swift.md | 2 +- docs/streaming-flume-integration.md | 6 ++-- docs/streaming-kafka-0-8-integration.md | 10 +++--- docs/streaming-programming-guide.md | 26 +++++++-------- .../structured-streaming-kafka-integration.md | 2 +- .../structured-streaming-programming-guide.md | 8 ++--- docs/submitting-applications.md | 2 +- docs/tuning.md | 2 +- python/README.md | 2 +- sql/README.md | 2 +- 43 files changed, 107 insertions(+), 107 deletions(-) diff --git a/docs/README.md b/docs/README.md index 225bb1b2040de..9eac4ba35c458 100644 --- a/docs/README.md +++ b/docs/README.md @@ -5,7 +5,7 @@ here with the Spark source code. You can also find documentation specific to rel Spark at http://spark.apache.org/documentation.html. Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the -documentation yourself. Why build it yourself? So that you have the docs that corresponds to +documentation yourself. Why build it yourself? So that you have the docs that correspond to whichever version of Spark you currently have checked out of revision control. ## Prerequisites diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 6ea1d438f529e..1e91f12518e0b 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -48,7 +48,7 @@ def render(context) begin code = File.open(@file).read.encode("UTF-8") rescue => e - # We need to explicitly exit on execptions here because Jekyll will silently swallow + # We need to explicitly exit on exceptions here because Jekyll will silently swallow # them, leading to silent build failures (see https://github.com/jekyll/jekyll/issues/5104) puts(e) puts(e.backtrace) diff --git a/docs/building-spark.md b/docs/building-spark.md index c391255a91596..0236bb05849ad 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -113,7 +113,7 @@ Note: Flume support is deprecated as of Spark 2.3.0. ## Building submodules individually -It's possible to build Spark sub-modules using the `mvn -pl` option. +It's possible to build Spark submodules using the `mvn -pl` option. For instance, you can build the Spark Streaming module using: diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index c150d9efc06ff..ac1c336988930 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -27,13 +27,13 @@ description: Introduction to cloud storage support in Apache Spark SPARK_VERSION All major cloud providers offer persistent data storage in *object stores*. These are not classic "POSIX" file systems. In order to store hundreds of petabytes of data without any single points of failure, -object stores replace the classic filesystem directory tree +object stores replace the classic file system directory tree with a simpler model of `object-name => data`. To enable remote access, operations on objects are usually offered as (slow) HTTP REST operations. Spark can read and write data in object stores through filesystem connectors implemented in Hadoop or provided by the infrastructure suppliers themselves. -These connectors make the object stores look *almost* like filesystems, with directories and files +These connectors make the object stores look *almost* like file systems, with directories and files and the classic operations on them such as list, delete and rename. diff --git a/docs/configuration.md b/docs/configuration.md index 2eb6a77434ea6..4d4d0c58dd07d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -558,7 +558,7 @@ Apart from these, the following properties are also available, and may be useful @@ -1288,7 +1288,7 @@ Apart from these, the following properties are also available, and may be useful @@ -1513,7 +1513,7 @@ Apart from these, the following properties are also available, and may be useful @@ -1722,7 +1722,7 @@ Apart from these, the following properties are also available, and may be useful When spark.task.reaper.enabled = true, this setting specifies a timeout after which the executor JVM will kill itself if a killed task has not stopped running. The default value, -1, disables this mechanism and prevents the executor from self-destructing. The purpose - of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering + of this setting is to act as a safety-net to prevent runaway noncancellable tasks from rendering an executor unusable. @@ -1915,8 +1915,8 @@ showDF(properties, numRows = 200, truncate = FALSE) @@ -1971,7 +1971,7 @@ showDF(properties, numRows = 200, truncate = FALSE) @@ -1980,7 +1980,7 @@ showDF(properties, numRows = 200, truncate = FALSE) @@ -2178,7 +2178,7 @@ Spark's classpath for each application. In a Spark cluster running on YARN, thes files are set cluster-wide, and cannot safely be changed by the application. The better choice is to use spark hadoop properties in the form of `spark.hadoop.*`. -They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-defalut.conf` +They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-default.conf` In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, Spark allows you to simply create an empty conf and set spark/spark hadoop properties. diff --git a/docs/css/pygments-default.css b/docs/css/pygments-default.css index 6247cd8396cf1..a4d583b366603 100644 --- a/docs/css/pygments-default.css +++ b/docs/css/pygments-default.css @@ -5,7 +5,7 @@ To generate this, I had to run But first I had to install pygments via easy_install pygments I had to override the conflicting bootstrap style rules by linking to -this stylesheet lower in the html than the bootstap css. +this stylesheet lower in the html than the bootstrap css. Also, I was thrown off for a while at first when I was using markdown code block inside my {% highlight scala %} ... {% endhighlight %} tags diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 5c97a248df4bc..35293348e3f3d 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -491,7 +491,7 @@ val joinedGraph = graph.joinVertices(uniqueCosts)( The more general [`outerJoinVertices`][Graph.outerJoinVertices] behaves similarly to `joinVertices` except that the user defined `map` function is applied to all vertices and can change the vertex property type. Because not all vertices may have a matching value in the input RDD the `map` -function takes an `Option` type. For example, we can setup a graph for PageRank by initializing +function takes an `Option` type. For example, we can set up a graph for PageRank by initializing vertex properties with their `outDegree`. @@ -969,7 +969,7 @@ A vertex is part of a triangle when it has two adjacent vertices with an edge be # Examples Suppose I want to build a graph from some text files, restrict the graph -to important relationships and users, run page-rank on the sub-graph, and +to important relationships and users, run page-rank on the subgraph, and then finally return attributes associated with the top users. I can do all of this in just a few lines with GraphX: diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index e6d881639a13b..da90342406c84 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -23,7 +23,7 @@ run tasks and store data for that application. If multiple users need to share y different options to manage allocation, depending on the cluster manager. The simplest option, available on all cluster managers, is _static partitioning_ of resources. With -this approach, each application is given a maximum amount of resources it can use, and holds onto them +this approach, each application is given a maximum amount of resources it can use and holds onto them for its whole duration. This is the approach used in Spark's [standalone](spark-standalone.html) and [YARN](running-on-yarn.html) modes, as well as the [coarse-grained Mesos mode](running-on-mesos.html#mesos-run-modes). @@ -230,7 +230,7 @@ properties: * `minShare`: Apart from an overall weight, each pool can be given a _minimum shares_ (as a number of CPU cores) that the administrator would like it to have. The fair scheduler always attempts to meet all active pools' minimum shares before redistributing extra resources according to the weights. - The `minShare` property can therefore be another way to ensure that a pool can always get up to a + The `minShare` property can, therefore, be another way to ensure that a pool can always get up to a certain number of resources (e.g. 10 cores) quickly without giving it a high priority for the rest of the cluster. By default, each pool's `minShare` is 0. diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md index 2747f2df7cb10..375957e92cc4c 100644 --- a/docs/ml-advanced.md +++ b/docs/ml-advanced.md @@ -77,7 +77,7 @@ Quasi-Newton methods in this case. This fallback is currently always enabled for L1 regularization is applied (i.e. $\alpha = 0$), there exists an analytical solution and either Cholesky or Quasi-Newton solver may be used. When $\alpha > 0$ no analytical solution exists and we instead use the Quasi-Newton solver to find the coefficients iteratively. -In order to make the normal equation approach efficient, `WeightedLeastSquares` requires that the number of features be no more than 4096. For larger problems, use L-BFGS instead. +In order to make the normal equation approach efficient, `WeightedLeastSquares` requires that the number of features is no more than 4096. For larger problems, use L-BFGS instead. ## Iteratively reweighted least squares (IRLS) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index ddd2f4b49ca07..d660655e193eb 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -420,7 +420,7 @@ Refer to the [R API docs](api/R/spark.svmLinear.html) for more details. [OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as "One-vs-All." -`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. +`OneVsRest` is implemented as an `Estimator`. For the base classifier, it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. @@ -908,7 +908,7 @@ Refer to the [R API docs](api/R/spark.survreg.html) for more details. belongs to the family of regression algorithms. Formally isotonic regression is a problem where given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted -finding a function that minimises +finding a function that minimizes `\begin{equation} f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 @@ -927,7 +927,7 @@ We implement a which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). The training input is a DataFrame which contains three columns -label, features and weight. Additionally IsotonicRegression algorithm has one +label, features and weight. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is isotonic (monotonically increasing) or antitonic (monotonically decreasing). diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index 58f2d4b531e70..8b0f287dc39ad 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -35,7 +35,7 @@ but the ids must be within the integer value range. ### Explicit vs. implicit feedback -The standard approach to matrix factorization based collaborative filtering treats +The standard approach to matrix factorization-based collaborative filtering treats the entries in the user-item matrix as *explicit* preferences given by the user to the item, for example, users giving ratings to movies. diff --git a/docs/ml-features.md b/docs/ml-features.md index 3370eb3893272..7aed2341584fc 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1174,7 +1174,7 @@ for more details on the API. ## SQLTransformer `SQLTransformer` implements the transformations which are defined by SQL statement. -Currently we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` +Currently, we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` where `"__THIS__"` represents the underlying table of the input dataset. The select clause specifies the fields, constants, and expressions to display in the output, and can be any select clause that Spark SQL supports. Users can also diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md index f4b0df58cf63b..e4736411fb5fe 100644 --- a/docs/ml-migration-guides.md +++ b/docs/ml-migration-guides.md @@ -347,7 +347,7 @@ rather than using the old parameter class `Strategy`. These new training method separate classification and regression, and they replace specialized parameter types with simple `String` types. -Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the +Examples of the new recommended `trainClassifier` and `trainRegressor` are given in the [Decision Trees Guide](mllib-decision-tree.html#examples). ## From 0.9 to 1.0 diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md index 54d9cd21909df..028bfec465bab 100644 --- a/docs/ml-tuning.md +++ b/docs/ml-tuning.md @@ -103,7 +103,7 @@ Refer to the [`CrossValidator` Python docs](api/python/pyspark.ml.html#pyspark.m In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. `TrainValidationSplit` only evaluates each combination of parameters once, as opposed to k times in - the case of `CrossValidator`. It is therefore less expensive, + the case of `CrossValidator`. It is, therefore, less expensive, but will not produce as reliable results when the training dataset is not sufficiently large. Unlike `CrossValidator`, `TrainValidationSplit` creates a single (training, test) dataset pair. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index df2be92d860e4..dc6b095f5d59b 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -42,7 +42,7 @@ The following code snippets can be executed in `spark-shell`. In the following example after loading and parsing data, we use the [`KMeans`](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) object to cluster the data into two clusters. The number of desired clusters is passed to the algorithm. We then compute Within -Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasing *k*. In fact the +Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasing *k*. In fact, the optimal *k* is usually one where there is an "elbow" in the WSSSE graph. Refer to the [`KMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) and [`KMeansModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeansModel) for details on the API. diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 76a00f18b3b90..b2300028e151b 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -31,7 +31,7 @@ following parameters: ### Explicit vs. implicit feedback -The standard approach to matrix factorization based collaborative filtering treats +The standard approach to matrix factorization-based collaborative filtering treats the entries in the user-item matrix as *explicit* preferences given by the user to the item, for example, users giving ratings to movies. @@ -60,7 +60,7 @@ best parameter learned from a sampled subset to the full dataset and expect simi
    -In the following example we load rating data. Each row consists of a user, a product and a rating. +In the following example, we load rating data. Each row consists of a user, a product and a rating. We use the default [ALS.train()](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS$) method which assumes ratings are explicit. We evaluate the recommendation model by measuring the Mean Squared Error of rating prediction. diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 35cee3275e3b5..5066bb29387dc 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -350,7 +350,7 @@ which is a tuple of `(Int, Int, Matrix)`. ***Note*** The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. -In general the use of non-deterministic RDDs can lead to errors. +In general, the use of non-deterministic RDDs can lead to errors. ### RowMatrix diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index a72680d52a26c..4e6b4530942f1 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -91,7 +91,7 @@ The same code applies to `IndexedRowMatrix` if `U` is defined as an [Principal component analysis (PCA)](http://en.wikipedia.org/wiki/Principal_component_analysis) is a statistical method to find a rotation such that the first coordinate has the largest variance -possible, and each succeeding coordinate in turn has the largest variance possible. The columns of +possible, and each succeeding coordinate, in turn, has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. `spark.mllib` supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors. diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 7f277543d2e9a..d9dbbab4840a3 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -13,7 +13,7 @@ of the model on some criteria, which depends on the application and its requirem suite of metrics for the purpose of evaluating the performance of machine learning models. Specific machine learning algorithms fall under broader types of machine learning applications like classification, -regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those +regression, clustering, etc. Each of these types have well-established metrics for performance evaluation and those metrics that are currently available in `spark.mllib` are detailed in this section. ## Classification model evaluation diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 8b89296b14cdd..bb29f65c0322f 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -105,7 +105,7 @@ p(w_i | w_j ) = \frac{\exp(u_{w_i}^{\top}v_{w_j})}{\sum_{l=1}^{V} \exp(u_l^{\top \]` where $V$ is the vocabulary size. -The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$ +The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$ is proportional to $V$, which can be easily in order of millions. To speed up training of Word2Vec, we used hierarchical softmax, which reduced the complexity of computing of $\log p(w_i | w_j)$ to $O(\log(V))$ diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index ca84551506b2b..99cab98c690c6 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -9,7 +9,7 @@ displayTitle: Regression - RDD-based API belongs to the family of regression algorithms. Formally isotonic regression is a problem where given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted -finding a function that minimises +finding a function that minimizes `\begin{equation} f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 @@ -28,7 +28,7 @@ best fitting the original data points. which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). The training input is an RDD of tuples of three double values that represent -label, feature and weight in this order. Additionally IsotonicRegression algorithm has one +label, feature and weight in this order. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is isotonic (monotonically increasing) or antitonic (monotonically decreasing). diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 034e89e25000e..73f6e206ca543 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -425,7 +425,7 @@ We create our model by initializing the weights to zero and register the streams testing then start the job. Printing predictions alongside true labels lets us easily see the result. -Finally we can save text files with data to the training or testing folders. +Finally, we can save text files with data to the training or testing folders. Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label and `x1,x2,x3` are the features. Anytime a text file is placed in `args(0)` the model will update. Anytime a text file is placed in `args(1)` you will see predictions. diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index 14d76a6e41e23..04758903da89c 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -121,7 +121,7 @@ computation of the sum of the partial results from each worker machine is perfor standard spark routines. If the fraction of points `miniBatchFraction` is set to 1 (default), then the resulting step in -each iteration is exact (sub)gradient descent. In this case there is no randomness and no +each iteration is exact (sub)gradient descent. In this case, there is no randomness and no variance in the used step directions. On the other extreme, if `miniBatchFraction` is chosen very small, such that only a single point is sampled, i.e. `$|S|=$ miniBatchFraction $\cdot n = 1$`, then the algorithm is equivalent to @@ -135,7 +135,7 @@ algorithm in the family of quasi-Newton methods to solve the optimization proble quadratic without evaluating the second partial derivatives of the objective function to construct the Hessian matrix. The Hessian matrix is approximated by previous gradient evaluations, so there is no vertical scalability issue (the number of training features) when computing the Hessian matrix -explicitly in Newton's method. As a result, L-BFGS often achieves rapider convergence compared with +explicitly in Newton's method. As a result, L-BFGS often achieves more rapid convergence compared with other first-order optimization. ### Choosing an Optimization Method diff --git a/docs/monitoring.md b/docs/monitoring.md index 01736c77b0979..6eaf33135744d 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -214,7 +214,7 @@ incomplete attempt or the final successful attempt. 2. Incomplete applications are only updated intermittently. The time between updates is defined by the interval between checks for changed files (`spark.history.fs.update.interval`). -On larger clusters the update interval may be set to large values. +On larger clusters, the update interval may be set to large values. The way to view a running application is actually to view its own web UI. 3. Applications which exited without registering themselves as completed will be listed @@ -422,7 +422,7 @@ configuration property. If, say, users wanted to set the metrics namespace to the name of the application, they can set the `spark.metrics.namespace` property to a value like `${spark.app.name}`. This value is then expanded appropriately by Spark and is used as the root namespace of the metrics system. -Non driver and executor metrics are never prefixed with `spark.app.id`, nor does the +Non-driver and executor metrics are never prefixed with `spark.app.id`, nor does the `spark.metrics.namespace` property have any such affect on such metrics. Spark's metrics are decoupled into different diff --git a/docs/quick-start.md b/docs/quick-start.md index 07c520cbee6be..f1a2096cd4dbd 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -11,11 +11,11 @@ This tutorial provides a quick introduction to using Spark. We will first introd interactive shell (in Python or Scala), then show how to write applications in Java, Scala, and Python. -To follow along with this guide, first download a packaged release of Spark from the +To follow along with this guide, first, download a packaged release of Spark from the [Spark website](http://spark.apache.org/downloads.html). Since we won't be using HDFS, you can download a package for any version of Hadoop. -Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more complete reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. +Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more detailed reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. # Interactive Analysis with the Spark Shell @@ -47,7 +47,7 @@ scala> textFile.first() // First item in this Dataset res1: String = # Apache Spark {% endhighlight %} -Now let's transform this Dataset to a new one. We call `filter` to return a new Dataset with a subset of the items in the file. +Now let's transform this Dataset into a new one. We call `filter` to return a new Dataset with a subset of the items in the file. {% highlight scala %} scala> val linesWithSpark = textFile.filter(line => line.contains("Spark")) diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index 2e29aef7f21a2..b6424090d2fea 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -818,7 +818,7 @@ The behavior of the above code is undefined, and may not work as intended. To ex The variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only see the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure. -In local mode, in some circumstances the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. +In local mode, in some circumstances, the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 9c4644947c911..e9e1f3e280609 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -17,7 +17,7 @@ container images and entrypoints.** * A runnable distribution of Spark 2.3 or above. * A running Kubernetes cluster at version >= 1.6 with access configured to it using [kubectl](https://kubernetes.io/docs/user-guide/prereqs/). If you do not already have a working Kubernetes cluster, -you may setup a test cluster on your local machine using +you may set up a test cluster on your local machine using [minikube](https://kubernetes.io/docs/getting-started-guides/minikube/). * We recommend using the latest release of minikube with the DNS addon enabled. * Be aware that the default minikube configuration is not enough for running Spark applications. @@ -221,7 +221,7 @@ that allows driver pods to create pods and services under the default Kubernetes [RBAC](https://kubernetes.io/docs/admin/authorization/rbac/) policies. Sometimes users may need to specify a custom service account that has the right role granted. Spark on Kubernetes supports specifying a custom service account to be used by the driver pod through the configuration property -`spark.kubernetes.authenticate.driver.serviceAccountName=`. For example to make the driver pod +`spark.kubernetes.authenticate.driver.serviceAccountName=`. For example, to make the driver pod use the `spark` service account, a user simply adds the following option to the `spark-submit` command: ``` diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 8e58892e2689f..3c2a1501ca692 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -90,7 +90,7 @@ Depending on your deployment environment you may wish to create a single set of Framework credentials may be specified in a variety of ways depending on your deployment environment and security requirements. The most simple way is to specify the `spark.mesos.principal` and `spark.mesos.secret` values directly in your Spark configuration. Alternatively you may specify these values indirectly by instead specifying `spark.mesos.principal.file` and `spark.mesos.secret.file`, these settings point to files containing the principal and secret. These files must be plaintext files in UTF-8 encoding. Combined with appropriate file ownership and mode/ACLs this provides a more secure way to specify these credentials. -Additionally if you prefer to use environment variables you can specify all of the above via environment variables instead, the environment variable names are simply the configuration settings uppercased with `.` replaced with `_` e.g. `SPARK_MESOS_PRINCIPAL`. +Additionally, if you prefer to use environment variables you can specify all of the above via environment variables instead, the environment variable names are simply the configuration settings uppercased with `.` replaced with `_` e.g. `SPARK_MESOS_PRINCIPAL`. ### Credential Specification Preference Order @@ -225,7 +225,7 @@ details and default values. Executors are brought up eagerly when the application starts, until `spark.cores.max` is reached. If you don't set `spark.cores.max`, the Spark application will consume all resources offered to it by Mesos, -so we of course urge you to set this variable in any sort of +so we, of course, urge you to set this variable in any sort of multi-tenant cluster, including one which runs multiple concurrent Spark applications. @@ -233,14 +233,14 @@ The scheduler will start executors round-robin on the offers Mesos gives it, but there are no spread guarantees, as Mesos does not provide such guarantees on the offer stream. -In this mode spark executors will honor port allocation if such is -provided from the user. Specifically if the user defines +In this mode Spark executors will honor port allocation if such is +provided from the user. Specifically, if the user defines `spark.blockManager.port` in Spark configuration, the mesos scheduler will check the available offers for a valid port range containing the port numbers. If no such range is available it will not launch any task. If no restriction is imposed on port numbers by the user, ephemeral ports are used as usual. This port honouring implementation -implies one task per host if the user defines a port. In the future network +implies one task per host if the user defines a port. In the future network, isolation shall be supported. The benefit of coarse-grained mode is much lower startup overhead, but @@ -486,7 +486,7 @@ See the [configuration page](configuration.html) for information on Spark config
    - + @@ -1797,6 +1798,23 @@ Apart from these, the following properties are also available, and may be useful Lower bound for the number of executors if dynamic allocation is enabled. + + + + + From 2a24c481da3f30b510deb62e5cf21c9463cf250c Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Tue, 24 Apr 2018 09:25:41 -0700 Subject: [PATCH 680/774] [SPARK-23975][ML] Allow Clustering to take Arrays of Double as input features ## What changes were proposed in this pull request? - Multiple possible input types is added in validateAndTransformSchema() and computeCost() while checking column type - Add if statement in transform() to support array type as featuresCol - Add the case statement in fit() while selecting columns from dataset These changes will be applied to KMeans first, then to other clustering method ## How was this patch tested? unit test is added Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21081 from ludatabricks/SPARK-23975. --- .../apache/spark/ml/clustering/KMeans.scala | 32 +++++++--- .../apache/spark/ml/util/DatasetUtils.scala | 63 +++++++++++++++++++ .../spark/ml/clustering/KMeansSuite.scala | 38 +++++++++++ 3 files changed, 126 insertions(+), 7 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 1ad157a695a7d..d475c726e6f08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -33,8 +33,8 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils.majorVersion @@ -86,13 +86,24 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitSteps: Int = $(initSteps) + /** + * Validates the input schema. + * @param schema input schema + */ + private[clustering] def validateSchema(schema: StructType): Unit = { + val typeCandidates = List( new VectorUDT, + new ArrayType(DoubleType, false), + new ArrayType(FloatType, false)) + + SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates) + } /** * Validates and transforms the input schema. * @param schema input schema * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + validateSchema(schema) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -125,8 +136,11 @@ class KMeansModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) + val predictUDF = udf((vector: Vector) => predict(vector)) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + + dataset.withColumn($(predictionCol), + predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("1.5.0") @@ -146,8 +160,10 @@ class KMeansModel private[ml] ( // TODO: Replace the temp fix when we have proper evaluators defined for clustering. @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + validateSchema(dataset.schema) + + val data: RDD[OldVector] = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol)) + .rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } parentModel.computeCost(data) @@ -335,7 +351,9 @@ class KMeans @Since("1.5.0") ( transformSchema(dataset.schema, logging = true) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + val instances: RDD[OldVector] = dataset.select( + DatasetUtils.columnToVector(dataset, getFeaturesCol)) + .rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala new file mode 100644 index 0000000000000..52619cb65489a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.linalg.{Vectors, VectorUDT} +import org.apache.spark.sql.{Column, Dataset} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType} + + +private[spark] object DatasetUtils { + + /** + * Cast a column in a Dataset to Vector type. + * + * The supported data types of the input column are + * - Vector + * - float/double type Array. + * + * Note: The returned column does not have Metadata. + * + * @param dataset input DataFrame + * @param colName column name. + * @return Vector column + */ + def columnToVector(dataset: Dataset[_], colName: String): Column = { + val columnDataType = dataset.schema(colName).dataType + columnDataType match { + case _: VectorUDT => col(colName) + case fdt: ArrayType => + val transferUDF = fdt.elementType match { + case _: FloatType => udf(f = (vector: Seq[Float]) => { + val inputArray = Array.fill[Double](vector.size)(0.0) + vector.indices.foreach(idx => inputArray(idx) = vector(idx).toDouble) + Vectors.dense(inputArray) + }) + case _: DoubleType => udf((vector: Seq[Double]) => { + Vectors.dense(vector.toArray) + }) + case other => + throw new IllegalArgumentException(s"Array[$other] column cannot be cast to Vector") + } + transferUDF(col(colName)) + case other => + throw new IllegalArgumentException(s"$other column cannot be cast to Vector") + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 77c9d482d95b6..5445ebe5c95eb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -30,6 +30,8 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} private[clustering] case class TestRow(features: Vector) @@ -199,6 +201,42 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(e.getCause.getMessage.contains("Cosine distance is not defined")) } + test("KMean with Array input") { + val featuresColNameD = "array_double_features" + val featuresColNameF = "array_float_features" + + val doubleUDF = udf { (features: Vector) => + val featureArray = Array.fill[Double](features.size)(0.0) + features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) + featureArray + } + val floatUDF = udf { (features: Vector) => + val featureArray = Array.fill[Float](features.size)(0.0f) + features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) + featureArray + } + + val newdatasetD = dataset.withColumn(featuresColNameD, doubleUDF(col("features"))) + .drop("features") + val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features"))) + .drop("features") + assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false))) + assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false))) + + val kmeansD = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameD).setSeed(1) + val kmeansF = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameF).setSeed(1) + val modelD = kmeansD.fit(newdatasetD) + val modelF = kmeansF.fit(newdatasetF) + val transformedD = modelD.transform(newdatasetD) + val transformedF = modelF.transform(newdatasetF) + + val predictDifference = transformedD.select("prediction") + .except(transformedF.select("prediction")) + assert(predictDifference.count() == 0) + assert(modelD.computeCost(newdatasetD) == modelF.computeCost(newdatasetF) ) + } + + test("read/write") { def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters) From ce7ba2e98e0a3b038e881c271b5905058c43155b Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Tue, 24 Apr 2018 09:57:09 -0700 Subject: [PATCH 681/774] [SPARK-23807][BUILD] Add Hadoop 3.1 profile with relevant POM fix ups MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? 1. Adds a `hadoop-3.1` profile build depending on the hadoop-3.1 artifacts. 1. In the hadoop-cloud module, adds an explicit hadoop-3.1 profile which switches from explicitly pulling in cloud connectors (hadoop-openstack, hadoop-aws, hadoop-azure) to depending on the hadoop-cloudstorage POM artifact, which pulls these in, has pre-excluded things like hadoop-common, and stays up to date with new connectors (hadoop-azuredatalake, hadoop-allyun). Goal: it becomes the Hadoop projects homework of keeping this clean, and the spark project doesn't need to handle new hadoop releases adding more dependencies. 1. the hadoop-cloud/hadoop-3.1 profile also declares support for jetty-ajax and jetty-util to ensure that these jars get into the distribution jar directory when needed by unshaded libraries. 1. Increases the curator and zookeeper versions to match those in hadoop-3, fixing spark core to build in sbt with the hadoop-3 dependencies. ## How was this patch tested? * Everything this has been built and tested against both ASF Hadoop branch-3.1 and hadoop trunk. * spark-shell was used to create connectors to all the stores and verify that file IO could take place. The spark hive-1.2.1 JAR has problems here, as it's version check logic fails for Hadoop versions > 2. This can be avoided with either of * The hadoop JARs built to declare their version as Hadoop 2.11 `mvn install -DskipTests -DskipShade -Ddeclared.hadoop.version=2.11` . This is safe for local test runs, not for deployment (HDFS is very strict about cross-version deployment). * A modified version of spark hive whose version check switch statement is happy with hadoop 3. I've done both, with maven and SBT. Three issues surfaced 1. A spark-core test failure —fixed in SPARK-23787. 1. SBT only: Zookeeper not being found in spark-core. Somehow curator 2.12.0 triggers some slightly different dependency resolution logic from previous versions, and Ivy was missing zookeeper.jar entirely. This patch adds the explicit declaration for all spark profiles, setting the ZK version = 3.4.9 for hadoop-3.1 1. Marking jetty-utils as provided in spark was stopping hadoop-azure from being able to instantiate the azure wasb:// client; it was using jetty-util-ajax, which could then not find a class in jetty-util. Author: Steve Loughran Closes #20923 from steveloughran/cloud/SPARK-23807-hadoop-31. --- assembly/pom.xml | 8 ++ core/pom.xml | 6 + dev/deps/spark-deps-hadoop-3.1 | 221 +++++++++++++++++++++++++++++++++ dev/test-dependencies.sh | 1 + hadoop-cloud/pom.xml | 83 ++++++++++++- pom.xml | 9 ++ 6 files changed, 327 insertions(+), 1 deletion(-) create mode 100644 dev/deps/spark-deps-hadoop-3.1 diff --git a/assembly/pom.xml b/assembly/pom.xml index a207dae5a74ff..9608c96fd5369 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -254,6 +254,14 @@ spark-hadoop-cloud_${scala.binary.version} ${project.version} + + + org.eclipse.jetty + jetty-util + ${hadoop.deps.scope} + diff --git a/core/pom.xml b/core/pom.xml index 9258a856028a0..093a9869b6dd7 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -95,6 +95,12 @@ org.apache.curator curator-recipes + + + org.apache.zookeeper + zookeeper + diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 new file mode 100644 index 0000000000000..97ad65a4096cb --- /dev/null +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -0,0 +1,221 @@ +HikariCP-java7-2.4.12.jar +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +accessors-smart-1.2.jar +activation-1.1.1.jar +aircompressor-0.8.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +antlr4-runtime-4.7.jar +aopalliance-1.0.jar +aopalliance-repackaged-2.4.0-b34.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +arrow-format-0.8.0.jar +arrow-memory-0.8.0.jar +arrow-vector-0.8.0.jar +automaton-1.11-8.jar +avro-1.7.7.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.58.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.11-0.13.2.jar +breeze_2.11-0.13.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.8.4.jar +chill_2.11-0.8.4.jar +commons-beanutils-1.9.3.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-3.0.8.jar +commons-compress-1.4.1.jar +commons-configuration2-2.1.1.jar +commons-crypto-1.0.0.jar +commons-daemon-1.0.13.jar +commons-dbcp-1.4.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.5.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-3.1.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +core-1.1.2.jar +curator-client-2.12.0.jar +curator-framework-2.12.0.jar +curator-recipes-2.12.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.12.1.1.jar +dnsjava-2.1.7.jar +ehcache-3.3.1.jar +eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar +geronimo-jcache_1.0_spec-1.0-alpha-1.jar +gson-2.2.4.jar +guava-14.0.1.jar +guice-4.0.jar +guice-servlet-4.0.jar +hadoop-annotations-3.1.0.jar +hadoop-auth-3.1.0.jar +hadoop-client-3.1.0.jar +hadoop-common-3.1.0.jar +hadoop-hdfs-client-3.1.0.jar +hadoop-mapreduce-client-common-3.1.0.jar +hadoop-mapreduce-client-core-3.1.0.jar +hadoop-mapreduce-client-jobclient-3.1.0.jar +hadoop-yarn-api-3.1.0.jar +hadoop-yarn-client-3.1.0.jar +hadoop-yarn-common-3.1.0.jar +hadoop-yarn-registry-3.1.0.jar +hadoop-yarn-server-common-3.1.0.jar +hadoop-yarn-server-web-proxy-3.1.0.jar +hk2-api-2.4.0-b34.jar +hk2-locator-2.4.0-b34.jar +hk2-utils-2.4.0-b34.jar +hppc-0.7.2.jar +htrace-core4-4.1.0-incubating.jar +httpclient-4.5.4.jar +httpcore-4.4.8.jar +ivy-2.4.0.jar +jackson-annotations-2.6.7.jar +jackson-core-2.6.7.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar +jackson-jaxrs-base-2.7.8.jar +jackson-jaxrs-json-provider-2.7.8.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar +jackson-module-paranamer-2.7.9.jar +jackson-module-scala_2.11-2.6.7.1.jar +janino-3.0.8.jar +java-xmlbuilder-1.1.jar +javassist-3.18.1-GA.jar +javax.annotation-api-1.2.jar +javax.inject-1.jar +javax.inject-2.4.0-b34.jar +javax.servlet-api-3.1.0.jar +javax.ws.rs-api-2.0.1.jar +javolution-5.5.1.jar +jaxb-api-2.2.11.jar +jcip-annotations-1.0-1.jar +jcl-over-slf4j-1.7.16.jar +jdo-api-3.0.1.jar +jersey-client-2.22.2.jar +jersey-common-2.22.2.jar +jersey-container-servlet-2.22.2.jar +jersey-container-servlet-core-2.22.2.jar +jersey-guava-2.22.2.jar +jersey-media-jaxb-2.22.2.jar +jersey-server-2.22.2.jar +jets3t-0.9.4.jar +jetty-webapp-9.3.20.v20170531.jar +jetty-xml-9.3.20.v20170531.jar +jline-2.12.1.jar +joda-time-2.9.3.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-smart-2.3.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar +jsp-api-2.1.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.16.jar +kerb-admin-1.0.1.jar +kerb-client-1.0.1.jar +kerb-common-1.0.1.jar +kerb-core-1.0.1.jar +kerb-crypto-1.0.1.jar +kerb-identity-1.0.1.jar +kerb-server-1.0.1.jar +kerb-simplekdc-1.0.1.jar +kerb-util-1.0.1.jar +kerby-asn1-1.0.1.jar +kerby-config-1.0.1.jar +kerby-pkix-1.0.1.jar +kerby-util-1.0.1.jar +kerby-xdr-1.0.1.jar +kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar +leveldbjni-all-1.8.jar +libfb303-0.9.3.jar +libthrift-0.9.3.jar +log4j-1.2.17.jar +logging-interceptor-3.8.1.jar +lz4-java-1.4.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar +mesos-1.4.0-shaded-protobuf.jar +metrics-core-3.1.5.jar +metrics-graphite-3.1.5.jar +metrics-json-3.1.5.jar +metrics-jvm-3.1.5.jar +minlog-1.3.0.jar +mssql-jdbc-6.2.1.jre7.jar +netty-3.9.9.Final.jar +netty-all-4.1.17.Final.jar +nimbus-jose-jwt-4.41.1.jar +objenesis-2.1.jar +okhttp-2.7.5.jar +okhttp-3.8.1.jar +okio-1.13.0.jar +opencsv-2.3.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar +oro-2.0.8.jar +osgi-resource-locator-1.0.1.jar +paranamer-2.8.jar +parquet-column-1.8.2.jar +parquet-common-1.8.2.jar +parquet-encoding-1.8.2.jar +parquet-format-2.3.1.jar +parquet-hadoop-1.8.2.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.8.2.jar +protobuf-java-2.5.0.jar +py4j-0.10.6.jar +pyrolite-4.13.jar +re2j-1.1.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.8.jar +scala-xml_2.11-1.0.5.jar +shapeless_2.11-2.3.2.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar +snappy-0.2.jar +snappy-java-1.1.7.1.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar +stax-api-1.0.1.jar +stax2-api-3.1.4.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +token-provider-1.0.1.jar +univocity-parsers-2.5.9.jar +validation-api-1.1.0.Final.jar +woodstox-core-5.0.3.jar +xbean-asm5-shaded-4.4.jar +xz-1.0.jar +zjsonpatch-0.3.0.jar +zookeeper-3.4.9.jar +zstd-jni-1.3.2-2.jar diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 3bf7618e1ea96..2fbd6b5e98f7f 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -34,6 +34,7 @@ MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 hadoop-2.7 + hadoop-3.1 ) # We'll switch the version to a temp. one, publish POMs using that new version, then switch back to diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 8e424b1c50236..2c39a7df0146e 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -38,7 +38,32 @@ hadoop-cloud + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.hadoop + hadoop-client + ${hadoop.version} + provided + + + + hadoop-3.1 + + + + org.apache.hadoop + hadoop-cloud-storage + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + org.codehaus.jackson + jackson-mapper-asl + + + com.fasterxml.jackson.core + jackson-core + + + com.google.guava + guava + + + + + + org.eclipse.jetty + jetty-util + ${hadoop.deps.scope} + + + org.eclipse.jetty + jetty-util-ajax + ${jetty.version} + ${hadoop.deps.scope} + + + + diff --git a/pom.xml b/pom.xml index 0a711f287a53f..88e77ff874748 100644 --- a/pom.xml +++ b/pom.xml @@ -2671,6 +2671,15 @@ + + hadoop-3.1 + + 3.1.0 + 2.12.0 + 3.4.9 + + + yarn From 83013752e3cfcbc3edeef249439ac20b143eeabc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 24 Apr 2018 10:40:25 -0700 Subject: [PATCH 682/774] [SPARK-23455][ML] Default Params in ML should be saved separately in metadata ## What changes were proposed in this pull request? We save ML's user-supplied params and default params as one entity in metadata. During loading the saved models, we set all the loaded params into created ML model instances as user-supplied params. It causes some problems, e.g., if we strictly disallow some params to be set at the same time, a default param can fail the param check because it is treated as user-supplied param after loading. The loaded default params should not be set as user-supplied params. We should save ML default params separately in metadata. For backward compatibility, when loading metadata, if it is a metadata file from previous Spark, we shouldn't raise error if we can't find the default param field. ## How was this patch tested? Pass existing tests and added tests. Author: Liang-Chi Hsieh Closes #20633 from viirya/save-ml-default-params. --- .../DecisionTreeClassifier.scala | 2 +- .../ml/classification/GBTClassifier.scala | 4 +- .../spark/ml/classification/LinearSVC.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../MultilayerPerceptronClassifier.scala | 2 +- .../spark/ml/classification/NaiveBayes.scala | 2 +- .../spark/ml/classification/OneVsRest.scala | 4 +- .../RandomForestClassifier.scala | 4 +- .../spark/ml/clustering/BisectingKMeans.scala | 2 +- .../spark/ml/clustering/GaussianMixture.scala | 2 +- .../apache/spark/ml/clustering/KMeans.scala | 2 +- .../org/apache/spark/ml/clustering/LDA.scala | 4 +- .../feature/BucketedRandomProjectionLSH.scala | 2 +- .../apache/spark/ml/feature/Bucketizer.scala | 24 ---- .../spark/ml/feature/ChiSqSelector.scala | 2 +- .../spark/ml/feature/CountVectorizer.scala | 2 +- .../org/apache/spark/ml/feature/IDF.scala | 2 +- .../org/apache/spark/ml/feature/Imputer.scala | 2 +- .../spark/ml/feature/MaxAbsScaler.scala | 2 +- .../apache/spark/ml/feature/MinHashLSH.scala | 2 +- .../spark/ml/feature/MinMaxScaler.scala | 2 +- .../ml/feature/OneHotEncoderEstimator.scala | 2 +- .../org/apache/spark/ml/feature/PCA.scala | 2 +- .../ml/feature/QuantileDiscretizer.scala | 24 ---- .../apache/spark/ml/feature/RFormula.scala | 6 +- .../spark/ml/feature/StandardScaler.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/feature/VectorIndexer.scala | 2 +- .../apache/spark/ml/feature/Word2Vec.scala | 2 +- .../org/apache/spark/ml/fpm/FPGrowth.scala | 2 +- .../org/apache/spark/ml/param/params.scala | 13 +- .../apache/spark/ml/recommendation/ALS.scala | 2 +- .../ml/regression/AFTSurvivalRegression.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 4 +- .../GeneralizedLinearRegression.scala | 2 +- .../ml/regression/IsotonicRegression.scala | 2 +- .../ml/regression/LinearRegression.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 4 +- .../spark/ml/tuning/CrossValidator.scala | 6 +- .../ml/tuning/TrainValidationSplit.scala | 6 +- .../org/apache/spark/ml/util/ReadWrite.scala | 130 ++++++++++++------ .../spark/ml/util/DefaultReadWriteTest.scala | 73 +++++++++- project/MimaExcludes.scala | 6 + 44 files changed, 223 insertions(+), 147 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 771cd4fe91dcf..57797d1cc4978 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -279,7 +279,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true) val model = new DecisionTreeClassificationModel(metadata.uid, root.asInstanceOf[ClassificationNode], numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c0255103bc313..0aa24f0a3cfcc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -379,14 +379,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 8f950cd28c3aa..80c537e1e0eb2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -377,7 +377,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { val Row(coefficients: Vector, intercept: Double) = data.select("coefficients", "intercept").head() val model = new LinearSVCModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index ee4b01058c75c..e426263910f26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1270,7 +1270,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { numClasses, isMultinomial) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index af2e4699924e5..57ba47e596a97 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -388,7 +388,7 @@ object MultilayerPerceptronClassificationModel val weights = data.getAs[Vector](1) val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 0293e03d47435..45fb585ed2262 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -407,7 +407,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { .head() val model = new NaiveBayesModel(metadata.uid, pi, theta) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 5348d882cfd67..7df53a6b8ad10 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -289,7 +289,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] { DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc) } val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models) - DefaultParamsReader.getAndSetParams(ovrModel, metadata) + metadata.getAndSetParams(ovrModel) ovrModel.set("classifier", classifier) ovrModel } @@ -484,7 +484,7 @@ object OneVsRest extends MLReadable[OneVsRest] { override def load(path: String): OneVsRest = { val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) val ovr = new OneVsRest(metadata.uid) - DefaultParamsReader.getAndSetParams(ovr, metadata) + metadata.getAndSetParams(ovr) ovr.setClassifier(classifier) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bb972e9706fc1..f1ef26a07d3f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -319,14 +319,14 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica case (treeMetadata, root) => val tree = new DecisionTreeClassificationModel(treeMetadata.uid, root.asInstanceOf[ClassificationNode], numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index f7c422dc0faea..addc12ac52ec1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -193,7 +193,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { val dataPath = new Path(path, "data").toString val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath) val model = new BisectingKMeansModel(metadata.uid, mllibModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index f19ad7a5a6938..b5804900c0358 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -233,7 +233,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { } val model = new GaussianMixtureModel(metadata.uid, weights, gaussians) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index d475c726e6f08..de61c9c089a36 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -280,7 +280,7 @@ object KMeansModel extends MLReadable[KMeansModel] { sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters } val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 4bab670cc159f..47077230fac0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -366,7 +366,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM private object LDAParams { /** - * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] + * Equivalent to [[Metadata.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+. * * @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with @@ -391,7 +391,7 @@ private object LDAParams { s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") } case _ => // 2.0+ - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index 41eaaf9679914..a906e954fecd5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -238,7 +238,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject val model = new BucketedRandomProjectionLSHModel(metadata.uid, randUnitVectors.rowIter.toArray) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index f49c410cbcfe2..f99649f7fa164 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -217,8 +217,6 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } - - override def write: MLWriter = new Bucketizer.BucketizerWriter(this) } @Since("1.6.0") @@ -296,28 +294,6 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { } } - - private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // SPARK-23377: The default params will be saved and loaded as user-supplied params. - // Once `inputCols` is set, the default value of `outputCol` param causes the error - // when checking exclusive params. As a temporary to fix it, we skip the default value - // of `outputCol` if `inputCols` is set when saving the metadata. - // TODO: If we modify the persistence mechanism later to better handle default params, - // we can get rid of this. - var paramWithoutOutputCol: Option[JValue] = None - if (instance.isSet(instance.inputCols)) { - val params = instance.extractParamMap().toSeq - val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => - p.name -> parse(p.jsonEncode(v)) - }.toList - paramWithoutOutputCol = Some(render(jsonParams)) - } - DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) - } - } - @Since("1.6.0") override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 16abc4949dea3..dbfb199ccd58f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -334,7 +334,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { val selectedFeatures = data.getAs[Seq[Int]](0).toArray val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) val model = new ChiSqSelectorModel(metadata.uid, oldModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 9e0ed437e7bfc..10c48c3f52085 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -363,7 +363,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { .head() val vocabulary = data.getAs[Seq[String]](0).toArray val model = new CountVectorizerModel(metadata.uid, vocabulary) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 46a0730f5ddb8..58897cca4e5c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -182,7 +182,7 @@ object IDFModel extends MLReadable[IDFModel] { .select("idf") .head() val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf))) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 730ee9fc08db8..1c074e204ad99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -262,7 +262,7 @@ object ImputerModel extends MLReadable[ImputerModel] { val dataPath = new Path(path, "data").toString val surrogateDF = sqlContext.read.parquet(dataPath) val model = new ImputerModel(metadata.uid, surrogateDF) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 85f9732f79f67..90eceb0d61b40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -172,7 +172,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { .select("maxAbs") .head() val model = new MaxAbsScalerModel(metadata.uid, maxAbs) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 556848e45532d..a67a3b0abbc1f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -205,7 +205,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] { .map(tuple => (tuple(0), tuple(1))).toArray val model = new MinHashLSHModel(metadata.uid, randCoefficients) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index f648deced54cd..2e0ae4af66f06 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -243,7 +243,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { .select("originalMin", "originalMax") .head() val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index bd1e3426c8780..4a44f3186538d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -386,7 +386,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { .head() val categorySizes = data.getAs[Seq[Int]](0).toArray val model = new OneHotEncoderModel(metadata.uid, categorySizes) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 4143d864d7930..8172491a517d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -220,7 +220,7 @@ object PCAModel extends MLReadable[PCAModel] { new PCAModel(metadata.uid, pc.asML, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 3b4c25478fb1d..56e2c543d100a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -253,35 +253,11 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) - - override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this) } @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { - private[QuantileDiscretizer] - class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // SPARK-23377: The default params will be saved and loaded as user-supplied params. - // Once `inputCols` is set, the default value of `outputCol` param causes the error - // when checking exclusive params. As a temporary to fix it, we skip the default value - // of `outputCol` if `inputCols` is set when saving the metadata. - // TODO: If we modify the persistence mechanism later to better handle default params, - // we can get rid of this. - var paramWithoutOutputCol: Option[JValue] = None - if (instance.isSet(instance.inputCols)) { - val params = instance.extractParamMap().toSeq - val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => - p.name -> parse(p.jsonEncode(v)) - }.toList - paramWithoutOutputCol = Some(render(jsonParams)) - } - DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) - } - } - @Since("1.6.0") override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index e214765e3307f..55e595eee6ffb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -446,7 +446,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -510,7 +510,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { val columnsToPrune = data.getAs[Seq[String]](0).toSet val pruner = new ColumnPruner(metadata.uid, columnsToPrune) - DefaultParamsReader.getAndSetParams(pruner, metadata) + metadata.getAndSetParams(pruner) pruner } } @@ -602,7 +602,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite val prefixesToRewrite = data.getAs[Map[String, String]](1) val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite) - DefaultParamsReader.getAndSetParams(rewriter, metadata) + metadata.getAndSetParams(rewriter) rewriter } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 8f125d8fd51d2..91b0707dec3f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -212,7 +212,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { .select("std", "mean") .head() val model = new StandardScalerModel(metadata.uid, std, mean) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 67cdb097217a2..a833d8b270cf1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -315,7 +315,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { .head() val labels = data.getAs[Seq[String]](0).toArray val model = new StringIndexerModel(metadata.uid, labels) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index e6ec4e2e36ff0..0e7396a621dbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -537,7 +537,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { val numFeatures = data.getAs[Int](0) val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1) val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index fe3306e1e50d6..fc9996d69ba72 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -410,7 +410,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] { } val model = new Word2VecModel(metadata.uid, oldModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 3d041fc80eb7f..0bf405d9abf9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -335,7 +335,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { val dataPath = new Path(path, "data").toString val frequentItems = sparkSession.read.parquet(dataPath) val model = new FPGrowthModel(metadata.uid, frequentItems) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 9a83a5882ce29..e6c347ed17c15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -865,10 +865,10 @@ trait Params extends Identifiable with Serializable { } /** Internal param map for user-supplied values. */ - private val paramMap: ParamMap = ParamMap.empty + private[ml] val paramMap: ParamMap = ParamMap.empty /** Internal param map for default values. */ - private val defaultParamMap: ParamMap = ParamMap.empty + private[ml] val defaultParamMap: ParamMap = ParamMap.empty /** Validates that the input param belongs to this instance. */ private def shouldOwn(param: Param[_]): Unit = { @@ -905,6 +905,15 @@ trait Params extends Identifiable with Serializable { } } +private[ml] object Params { + /** + * Sets a default param value for a `Params`. + */ + private[ml] final def setDefault[T](params: Params, param: Param[T], value: T): Unit = { + params.defaultParamMap.put(param -> value) + } +} + /** * :: DeveloperApi :: * Java-friendly wrapper for [[Params]]. diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 81a8f50761e0e..a23f9552b9e5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -529,7 +529,7 @@ object ALSModel extends MLReadable[ALSModel] { val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 4b46c3831d75f..7c6ec2a8419fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -423,7 +423,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] .head() val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 5cef5c9f21f1e..8bcf0793a64c1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -282,7 +282,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false) val model = new DecisionTreeRegressionModel(metadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 834aaa0e362d1..8598e808c4946 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -311,7 +311,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } @@ -319,7 +319,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { s" trees based on metadata but found ${trees.length} trees.") val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 4c3f1431d5077..e030a40cb19be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1146,7 +1146,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 8faab52ea474b..b046897ab2b7e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -308,7 +308,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { val model = new IsotonicRegressionModel( metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic)) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 9cdd3a051e719..f1d9a4453deaa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -799,7 +799,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { new LinearRegressionModel(metadata.uid, coefficients, intercept, scale) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 7f77398ba2a22..4509f85aafd12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -276,14 +276,14 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index c2826dcc08634..5e916cc4a9fdd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -234,8 +234,7 @@ object CrossValidator extends MLReadable[CrossValidator] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(cv, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(cv, skipParams = Option(List("estimatorParamMaps"))) cv } } @@ -424,8 +423,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 8d1b9a8ddab59..13369c4df7180 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -228,8 +228,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(tvs, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(tvs, skipParams = Option(List("estimatorParamMaps"))) tvs } } @@ -407,8 +406,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 7edcd498678cc..72a60e04360d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -39,7 +39,7 @@ import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.ml.tuning.ValidatorParams import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.util.{Utils, VersionUtils} /** * Trait for `MLWriter` and `MLReader`. @@ -421,6 +421,7 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid + * - defaultParamMap * - paramMap * - (optionally, extra metadata) * @@ -453,15 +454,20 @@ private[ml] object DefaultParamsWriter { paramMap: Option[JValue] = None): String = { val uid = instance.uid val cls = instance.getClass.getName - val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] + val params = instance.paramMap.toSeq + val defaultParams = instance.defaultParamMap.toSeq val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }.toList)) + val jsonDefaultParams = render(defaultParams.map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList) val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) + ("paramMap" -> jsonParams) ~ + ("defaultParamMap" -> jsonDefaultParams) val metadata = extraMetadata match { case Some(jObject) => basicMetadata ~ jObject @@ -488,7 +494,7 @@ private[ml] class DefaultParamsReader[T] extends MLReader[T] { val cls = Utils.classForName(metadata.className) val instance = cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params] - DefaultParamsReader.getAndSetParams(instance, metadata) + metadata.getAndSetParams(instance) instance.asInstanceOf[T] } } @@ -499,6 +505,8 @@ private[ml] object DefaultParamsReader { * All info from metadata file. * * @param params paramMap, as a `JValue` + * @param defaultParams defaultParamMap, as a `JValue`. For metadata file prior to Spark 2.4, + * this is `JNothing`. * @param metadata All metadata, including the other fields * @param metadataJson Full metadata file String (for debugging) */ @@ -508,27 +516,90 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, + defaultParams: JValue, metadata: JValue, metadataJson: String) { + + private def getValueFromParams(params: JValue): Seq[(String, JValue)] = { + params match { + case JObject(pairs) => pairs + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: $metadataJson.") + } + } + /** * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name. * This can be useful for getting a Param value before an instance of `Params` - * is available. + * is available. This will look up `params` first, if not existing then looking up + * `defaultParams`. */ def getParamValue(paramName: String): JValue = { implicit val format = DefaultFormats - params match { + + // Looking up for `params` first. + var pairs = getValueFromParams(params) + var foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + if (foundPairs.length == 0) { + // Looking up for `defaultParams` then. + pairs = getValueFromParams(defaultParams) + foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + } + assert(foundPairs.length == 1, s"Expected one instance of Param '$paramName' but found" + + s" ${foundPairs.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) + + foundPairs.map(_._2).head + } + + /** + * Extract Params from metadata, and set them in the instance. + * This works if all Params (except params included by `skipParams` list) implement + * [[org.apache.spark.ml.param.Param.jsonDecode()]]. + * + * @param skipParams The params included in `skipParams` won't be set. This is useful if some + * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] + * and need special handling. + */ + def getAndSetParams( + instance: Params, + skipParams: Option[List[String]] = None): Unit = { + setParams(instance, skipParams, isDefault = false) + + // For metadata file prior to Spark 2.4, there is no default section. + val (major, minor) = VersionUtils.majorMinorVersion(sparkVersion) + if (major > 2 || (major == 2 && minor >= 4)) { + setParams(instance, skipParams, isDefault = true) + } + } + + private def setParams( + instance: Params, + skipParams: Option[List[String]], + isDefault: Boolean): Unit = { + implicit val format = DefaultFormats + val paramsToSet = if (isDefault) defaultParams else params + paramsToSet match { case JObject(pairs) => - val values = pairs.filter { case (pName, jsonValue) => - pName == paramName - }.map(_._2) - assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" + - s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) - values.head + pairs.foreach { case (paramName, jsonValue) => + if (skipParams == None || !skipParams.get.contains(paramName)) { + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + if (isDefault) { + Params.setDefault(instance, param, value) + } else { + instance.set(param, value) + } + } + } case _ => throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: $metadataJson.") + s"Cannot recognize JSON metadata: ${metadataJson}.") } } } @@ -561,43 +632,14 @@ private[ml] object DefaultParamsReader { val uid = (metadata \ "uid").extract[String] val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] + val defaultParams = metadata \ "defaultParamMap" val params = metadata \ "paramMap" if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) - } - - /** - * Extract Params from metadata, and set them in the instance. - * This works if all Params (except params included by `skipParams` list) implement - * [[org.apache.spark.ml.param.Param.jsonDecode()]]. - * - * @param skipParams The params included in `skipParams` won't be set. This is useful if some - * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] - * and need special handling. - * TODO: Move to [[Metadata]] method - */ - def getAndSetParams( - instance: Params, - metadata: Metadata, - skipParams: Option[List[String]] = None): Unit = { - implicit val format = DefaultFormats - metadata.params match { - case JObject(pairs) => - pairs.foreach { case (paramName, jsonValue) => - if (skipParams == None || !skipParams.get.contains(paramName)) { - val param = instance.getParam(paramName) - val value = param.jsonDecode(compact(render(jsonValue))) - instance.set(param, value) - } - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") - } + Metadata(className, uid, timestamp, sparkVersion, params, defaultParams, metadata, metadataStr) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 4da95e74434ee..4d9e664850c12 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -19,9 +19,10 @@ package org.apache.spark.ml.util import java.io.{File, IOException} +import org.json4s.JNothing import org.scalatest.Suite -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -129,6 +130,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") + final val shouldNotSetIfSetintParamWithDefault: IntParam = + new IntParam(this, "shouldNotSetIfSetintParamWithDefault", "doc") final val intParam: IntParam = new IntParam(this, "intParam", "doc") final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc") final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc") @@ -150,6 +153,13 @@ class MyParams(override val uid: String) extends Params with MLWritable { set(doubleArrayParam -> Array(8.0, 9.0)) set(stringArrayParam -> Array("10", "11")) + def checkExclusiveParams(): Unit = { + if (isSet(shouldNotSetIfSetintParamWithDefault) && isSet(intParamWithDefault)) { + throw new SparkException("intParamWithDefault and shouldNotSetIfSetintParamWithDefault " + + "shouldn't be set at the same time") + } + } + override def copy(extra: ParamMap): Params = defaultCopy(extra) override def write: MLWriter = new DefaultParamsWriter(this) @@ -169,4 +179,65 @@ class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext val myParams = new MyParams("my_params") testDefaultReadWrite(myParams) } + + test("default param shouldn't become user-supplied param after persistence") { + val myParams = new MyParams("my_params") + myParams.set(myParams.shouldNotSetIfSetintParamWithDefault, 1) + myParams.checkExclusiveParams() + val loadedMyParams = testDefaultReadWrite(myParams) + loadedMyParams.checkExclusiveParams() + assert(loadedMyParams.getDefault(loadedMyParams.intParamWithDefault) == + myParams.getDefault(myParams.intParamWithDefault)) + + loadedMyParams.set(myParams.intParamWithDefault, 1) + intercept[SparkException] { + loadedMyParams.checkExclusiveParams() + } + } + + test("User-supplied value for default param should be kept after persistence") { + val myParams = new MyParams("my_params") + myParams.set(myParams.intParamWithDefault, 100) + val loadedMyParams = testDefaultReadWrite(myParams) + assert(loadedMyParams.get(myParams.intParamWithDefault).get == 100) + } + + test("Read metadata without default field prior to 2.4") { + // default params are saved in `paramMap` field in metadata file prior to Spark 2.4. + val metadata = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"2.3.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata = DefaultParamsReader.parseMetadata(metadata) + val myParams = new MyParams("my_params") + assert(!myParams.isSet(myParams.intParamWithDefault)) + parsedMetadata.getAndSetParams(myParams) + + // The behavior prior to Spark 2.4, default params are set in loaded ML instance. + assert(myParams.isSet(myParams.intParamWithDefault)) + } + + test("Should raise error when read metadata without default field after Spark 2.4") { + val myParams = new MyParams("my_params") + + val metadata1 = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"2.4.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata1 = DefaultParamsReader.parseMetadata(metadata1) + val err1 = intercept[IllegalArgumentException] { + parsedMetadata1.getAndSetParams(myParams) + } + assert(err1.getMessage().contains("Cannot recognize JSON metadata")) + + val metadata2 = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"3.0.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata2 = DefaultParamsReader.parseMetadata(metadata2) + val err2 = intercept[IllegalArgumentException] { + parsedMetadata2.getAndSetParams(myParams) + } + assert(err2.getMessage().contains("Cannot recognize JSON metadata")) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a87fa68422c34..7d0e88ee20c3f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -62,6 +62,12 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"), + // [SPARK-23455][ML] Default Params in ML should be saved separately in metadata + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.paramMap"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$paramMap_="), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="), + // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"), ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), From 379bffa0525a4343f8c10e51ed192031922f9874 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 24 Apr 2018 11:02:22 -0700 Subject: [PATCH 683/774] [SPARK-23990][ML] Instruments logging improvements - ML regression package ## What changes were proposed in this pull request? Instruments logging improvements - ML regression package I add an `OptionalInstrument` class which used in `WeightLeastSquares` and `IterativelyReweightedLeastSquares`. ## How was this patch tested? N/A Author: WeichenXu Closes #21078 from WeichenXu123/inst_reg. --- .../classification/LogisticRegression.scala | 4 +- .../IterativelyReweightedLeastSquares.scala | 18 +++-- .../spark/ml/optim/WeightedLeastSquares.scala | 32 +++++---- .../ml/regression/AFTSurvivalRegression.scala | 2 +- .../GeneralizedLinearRegression.scala | 14 ++-- .../ml/regression/LinearRegression.scala | 22 +++--- .../spark/ml/tree/impl/RandomForest.scala | 2 + .../spark/ml/util/Instrumentation.scala | 68 ++++++++++++++++++- 8 files changed, 125 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index e426263910f26..06ca37bc75146 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -500,7 +500,7 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept) @@ -816,7 +816,7 @@ class LogisticRegression @Since("1.2.0") ( if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index 6961b45f55e4d..572b8cf0051b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.optim -import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.util.OptionalInstrumentation import org.apache.spark.rdd.RDD /** @@ -61,9 +61,12 @@ private[ml] class IterativelyReweightedLeastSquares( val fitIntercept: Boolean, val regParam: Double, val maxIter: Int, - val tol: Double) extends Logging with Serializable { + val tol: Double) extends Serializable { - def fit(instances: RDD[OffsetInstance]): IterativelyReweightedLeastSquaresModel = { + def fit( + instances: RDD[OffsetInstance], + instr: OptionalInstrumentation = OptionalInstrumentation.create( + classOf[IterativelyReweightedLeastSquares])): IterativelyReweightedLeastSquaresModel = { var converged = false var iter = 0 @@ -83,7 +86,8 @@ private[ml] class IterativelyReweightedLeastSquares( // Estimate new model model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, - standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + standardizeFeatures = false, standardizeLabel = false) + .fit(newInstances, instr = instr) // Check convergence val oldCoefficients = oldModel.coefficients @@ -96,14 +100,14 @@ private[ml] class IterativelyReweightedLeastSquares( if (maxTol < tol) { converged = true - logInfo(s"IRLS converged in $iter iterations.") + instr.logInfo(s"IRLS converged in $iter iterations.") } - logInfo(s"Iteration $iter : relative tolerance = $maxTol") + instr.logInfo(s"Iteration $iter : relative tolerance = $maxTol") iter = iter + 1 if (iter == maxIter) { - logInfo(s"IRLS reached the max number of iterations: $maxIter.") + instr.logInfo(s"IRLS reached the max number of iterations: $maxIter.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index c5c9c8eb2bd29..1b7c15f1f0a8c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.optim -import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.util.OptionalInstrumentation import org.apache.spark.rdd.RDD /** @@ -81,13 +81,11 @@ private[ml] class WeightedLeastSquares( val standardizeLabel: Boolean, val solverType: WeightedLeastSquares.Solver = WeightedLeastSquares.Auto, val maxIter: Int = 100, - val tol: Double = 1e-6) extends Logging with Serializable { + val tol: Double = 1e-6 + ) extends Serializable { import WeightedLeastSquares._ require(regParam >= 0.0, s"regParam cannot be negative: $regParam") - if (regParam == 0.0) { - logWarning("regParam is zero, which might cause numerical instability and overfitting.") - } require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0, s"elasticNetParam must be in [0, 1]: $elasticNetParam") require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter") @@ -96,10 +94,17 @@ private[ml] class WeightedLeastSquares( /** * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s. */ - def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = { + def fit( + instances: RDD[Instance], + instr: OptionalInstrumentation = OptionalInstrumentation.create(classOf[WeightedLeastSquares]) + ): WeightedLeastSquaresModel = { + if (regParam == 0.0) { + instr.logWarning("regParam is zero, which might cause numerical instability and overfitting.") + } + val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_)) summary.validate() - logInfo(s"Number of instances: ${summary.count}.") + instr.logInfo(s"Number of instances: ${summary.count}.") val k = if (fitIntercept) summary.k + 1 else summary.k val numFeatures = summary.k val triK = summary.triK @@ -114,11 +119,12 @@ private[ml] class WeightedLeastSquares( if (rawBStd == 0) { if (fitIntercept || rawBBar == 0.0) { if (rawBBar == 0.0) { - logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + - s"and the intercept will all be zero; as a result, training is not needed.") + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") } else { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + s"training is not needed.") } val coefficients = new DenseVector(Array.ofDim(numFeatures)) @@ -128,7 +134,7 @@ private[ml] class WeightedLeastSquares( } else { require(!(regParam > 0.0 && standardizeLabel), "The standard deviation of the label is " + "zero. Model cannot be regularized with standardization=true") - logWarning(s"The standard deviation of the label is zero. Consider setting " + + instr.logWarning(s"The standard deviation of the label is zero. Consider setting " + s"fitIntercept=true.") } } @@ -256,7 +262,7 @@ private[ml] class WeightedLeastSquares( // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to // Quasi-Newton solver. case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto => - logWarning("Cholesky solver failed due to singular covariance matrix. " + + instr.logWarning("Cholesky solver failed due to singular covariance matrix. " + "Retrying with Quasi-Newton solver.") // ab and aa were modified in place, so reconstruct them val _aa = getAtA(aaBarValues, aBarValues) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 7c6ec2a8419fd..e27a96e1f5dfc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -237,7 +237,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) { - logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. This behavior is different from R survival::survreg.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index e030a40cb19be..143c8a3548b1f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -404,7 +404,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) - val wlsModel = optimizer.fit(instances) + val wlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) val model = copyValues( new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) .setParent(this)) @@ -418,10 +418,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val OffsetInstance(label, weight, offset, features) } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). - val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) + val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam), + instr = OptionalInstrumentation.create(instr)) val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc, $(fitIntercept), $(regParam), $(maxIter), $(tol)) - val irlsModel = optimizer.fit(instances) + val irlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) val model = copyValues( new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) .setParent(this)) @@ -492,7 +493,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine def initialize( instances: RDD[OffsetInstance], fitIntercept: Boolean, - regParam: Double): WeightedLeastSquaresModel = { + regParam: Double, + instr: OptionalInstrumentation = OptionalInstrumentation.create( + classOf[GeneralizedLinearRegression]) + ): WeightedLeastSquaresModel = { val newInstances = instances.map { instance => val mu = family.initialize(instance.label, instance.weight) val eta = predict(mu) - instance.offset @@ -501,7 +505,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine // TODO: Make standardizeFeatures and standardizeLabel configurable. val initialModel = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) - .fit(newInstances) + .fit(newInstances, instr) initialModel } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f1d9a4453deaa..c45ade94a4e33 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -339,7 +339,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = $(elasticNetParam), $(standardization), true, solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol)) - val model = optimizer.fit(instances) + val model = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) // When it is trained by WeightedLeastSquares, training summary does not // attach returned model. val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept)) @@ -378,6 +378,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val yMean = ySummarizer.mean(0) val rawYStd = math.sqrt(ySummarizer.variance(0)) + + instr.logNumExamples(ySummarizer.count) + instr.logNamedValue(Instrumentation.loggerTags.meanOfLabels, yMean) + instr.logNamedValue(Instrumentation.loggerTags.varianceOfLabels, rawYStd) + if (rawYStd == 0.0) { if ($(fitIntercept) || yMean == 0.0) { // If the rawYStd==0 and fitIntercept==true, then the intercept is yMean with @@ -385,11 +390,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of // the fitIntercept. if (yMean == 0.0) { - logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + - s"and the intercept will all be zero; as a result, training is not needed.") + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") } else { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + s"training is not needed.") } if (handlePersistence) instances.unpersist() @@ -415,7 +421,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. " + "Model cannot be regularized.") - logWarning(s"The standard deviation of the label is zero. " + + instr.logWarning(s"The standard deviation of the label is zero. " + "Consider setting fitIntercept=true.") } } @@ -430,7 +436,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { - logWarning("Fitting LinearRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting LinearRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. This behavior is the same as R glmnet but different from LIBSVM.") } @@ -522,7 +528,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 056a94b351f79..905870178e549 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -108,9 +108,11 @@ private[spark] object RandomForest extends Logging { case Some(instrumentation) => instrumentation.logNumFeatures(metadata.numFeatures) instrumentation.logNumClasses(metadata.numClasses) + instrumentation.logNumExamples(metadata.numExamples) case None => logInfo("numFeatures: " + metadata.numFeatures) logInfo("numClasses: " + metadata.numClasses) + logInfo("numExamples: " + metadata.numExamples) } // Find the splits and the corresponding bins (interval between the splits) using a sample diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index e694bc27b2f1e..3247c394dfa64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -19,6 +19,8 @@ package org.apache.spark.ml.util import java.util.UUID +import scala.reflect.ClassTag + import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -40,7 +42,8 @@ import org.apache.spark.sql.Dataset * @tparam E the type of the estimator */ private[spark] class Instrumentation[E <: Estimator[_]] private ( - estimator: E, dataset: RDD[_]) extends Logging { + val estimator: E, + val dataset: RDD[_]) extends Logging { private val id = UUID.randomUUID() private val prefix = { @@ -103,6 +106,10 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( logNamedValue(Instrumentation.loggerTags.numClasses, num) } + def logNumExamples(num: Long): Unit = { + logNamedValue(Instrumentation.loggerTags.numExamples, num) + } + /** * Logs the value with customized name field. */ @@ -114,6 +121,10 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( log(compact(render(name -> value))) } + def logNamedValue(name: String, value: Double): Unit = { + log(compact(render(name -> value))) + } + /** * Logs the successful completion of the training session. */ @@ -131,6 +142,8 @@ private[spark] object Instrumentation { val numFeatures = "numFeatures" val numClasses = "numClasses" val numExamples = "numExamples" + val meanOfLabels = "meanOfLabels" + val varianceOfLabels = "varianceOfLabels" } /** @@ -150,3 +163,56 @@ private[spark] object Instrumentation { } } + +/** + * A small wrapper that contains an optional `Instrumentation` object. + * Provide some log methods, if the containing `Instrumentation` object is defined, + * will log via it, otherwise will log via common logger. + */ +private[spark] class OptionalInstrumentation private( + val instrumentation: Option[Instrumentation[_ <: Estimator[_]]], + val className: String) extends Logging { + + protected override def logName: String = className + + override def logInfo(msg: => String) { + instrumentation match { + case Some(instr) => instr.logInfo(msg) + case None => super.logInfo(msg) + } + } + + override def logWarning(msg: => String) { + instrumentation match { + case Some(instr) => instr.logWarning(msg) + case None => super.logWarning(msg) + } + } + + override def logError(msg: => String) { + instrumentation match { + case Some(instr) => instr.logError(msg) + case None => super.logError(msg) + } + } +} + +private[spark] object OptionalInstrumentation { + + /** + * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object. + */ + def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = { + new OptionalInstrumentation(Some(instr), + instr.estimator.getClass.getName.stripSuffix("$")) + } + + /** + * Creates an `OptionalInstrumentation` object from a `Class` object. + * The created `OptionalInstrumentation` object will log messages via common logger and use the + * specified class name as logger name. + */ + def create(clazz: Class[_]): OptionalInstrumentation = { + new OptionalInstrumentation(None, clazz.getName.stripSuffix("$")) + } +} From 7b1e6523af3c96043aa8d2763e5f18b6e2781c3d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 24 Apr 2018 14:33:33 -0700 Subject: [PATCH 684/774] [SPARK-24056][SS] Make consumer creation lazy in Kafka source for Structured streaming ## What changes were proposed in this pull request? Currently, the driver side of the Kafka source (i.e. KafkaMicroBatchReader) eagerly creates a consumer as soon as the Kafk aMicroBatchReader is created. However, we create dummy KafkaMicroBatchReader to get the schema and immediately stop it. Its better to make the consumer creation lazy, it will be created on the first attempt to fetch offsets using the KafkaOffsetReader. ## How was this patch tested? Existing unit tests Author: Tathagata Das Closes #21134 from tdas/SPARK-24056. --- .../sql/kafka010/KafkaOffsetReader.scala | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 551641cfdbca8..82066697cb95a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -75,7 +75,17 @@ private[kafka010] class KafkaOffsetReader( * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the * offsets and never commits them. */ - protected var consumer = createConsumer() + @volatile protected var _consumer: Consumer[Array[Byte], Array[Byte]] = null + + protected def consumer: Consumer[Array[Byte], Array[Byte]] = synchronized { + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + if (_consumer == null) { + val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) + newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) + _consumer = consumerStrategy.createConsumer(newKafkaParams) + } + _consumer + } private val maxOffsetFetchAttempts = readerOptions.getOrElse("fetchOffset.numRetries", "3").toInt @@ -95,9 +105,7 @@ private[kafka010] class KafkaOffsetReader( * Closes the connection to Kafka, and cleans up state. */ def close(): Unit = { - runUninterruptibly { - consumer.close() - } + if (_consumer != null) runUninterruptibly { stopConsumer() } kafkaReaderThread.shutdown() } @@ -304,19 +312,14 @@ private[kafka010] class KafkaOffsetReader( } } - /** - * Create a consumer using the new generated group id. We always use a new consumer to avoid - * just using a broken consumer to retry on Kafka errors, which likely will fail again. - */ - private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized { - val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) - newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) - consumerStrategy.createConsumer(newKafkaParams) + private def stopConsumer(): Unit = synchronized { + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + if (_consumer != null) _consumer.close() } private def resetConsumer(): Unit = synchronized { - consumer.close() - consumer = createConsumer() + stopConsumer() + _consumer = null // will automatically get reinitialized again } } From d6c26d1c9a8f747a3e0d281a27ea9eb4d92102e5 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 24 Apr 2018 17:06:03 -0700 Subject: [PATCH 685/774] [SPARK-24038][SS] Refactor continuous writing to its own class ## What changes were proposed in this pull request? Refactor continuous writing to its own class. See WIP https://github.com/jose-torres/spark/pull/13 for the overall direction this is going, but I think this PR is very isolated and necessary anyway. ## How was this patch tested? existing unit tests - refactoring only Author: Jose Torres Closes #21116 from jose-torres/SPARK-24038. --- .../datasources/v2/DataSourceV2Strategy.scala | 4 + .../datasources/v2/WriteToDataSourceV2.scala | 74 +---------- .../continuous/ContinuousExecution.scala | 2 +- .../WriteToContinuousDataSource.scala | 31 +++++ .../WriteToContinuousDataSourceExec.scala | 124 ++++++++++++++++++ 5 files changed, 165 insertions(+), 70 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1ac9572de6412..c2a31442d2be5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { @@ -32,6 +33,9 @@ object DataSourceV2Strategy extends Strategy { case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case WriteToContinuousDataSource(writer, query) => + WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index e80b44c1cdc66..ea283ed77efda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -65,25 +65,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e s"The input RDD has ${messages.length} partitions.") try { - val runTask = writer match { - // This case means that we're doing continuous processing. In microbatch streaming, the - // StreamWriter is wrapped in a MicroBatchWriter, which is executed as a normal batch. - case w: StreamWriter => - EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - sparkContext.env) - .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) - - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.runContinuous(writeTask, context, iter) - case _ => - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator) - } - sparkContext.runJob( rdd, - runTask, + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator), rdd.partitions.indices, (index, message: WriterCommitMessage) => { messages(index) = message @@ -91,14 +76,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e } ) - if (!writer.isInstanceOf[StreamWriter]) { - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") - } + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") } catch { - case _: InterruptedException if writer.isInstanceOf[StreamWriter] => - // Interruption is how continuous queries are ended, so accept and ignore the exception. case cause: Throwable => logError(s"Data source writer $writer is aborting.") try { @@ -111,8 +92,6 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e } logError(s"Data source writer $writer aborted.") cause match { - // Do not wrap interruption exceptions that will be handled by streaming specially. - case _ if StreamExecution.isInterruptionException(cause) => throw cause // Only wrap non fatal exceptions. case NonFatal(e) => throw new SparkException("Writing job aborted.", e) case _ => throw cause @@ -168,49 +147,6 @@ object DataWritingSparkTask extends Logging { logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") }) } - - def runContinuous( - writeTask: DataWriterFactory[InternalRow], - context: TaskContext, - iter: Iterator[InternalRow]): WriterCommitMessage = { - val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - SparkEnv.get) - val currentMsg: WriterCommitMessage = null - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - do { - var dataWriter: DataWriter[InternalRow] = null - // write the data and commit this writer. - Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - try { - dataWriter = writeTask.createDataWriter( - context.partitionId(), context.attemptNumber(), currentEpoch) - while (iter.hasNext) { - dataWriter.write(iter.next()) - } - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") - epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) - ) - currentEpoch += 1 - } catch { - case _: InterruptedException => - // Continuous shutdown always involves an interrupt. Just finish the task. - } - })(catchBlock = { - // If there is an error, abort this writer. We enter this callback in the middle of - // rethrowing an exception, so runContinuous will stop executing at this point. - logError(s"Writer for partition ${context.partitionId()} is aborting.") - if (dataWriter != null) dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") - }) - } while (!context.isInterrupted()) - - currentMsg - } } class InternalRowDataWriterFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 951d694355ec5..f58146ac42398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -199,7 +199,7 @@ class ContinuousExecution( triggerLogicalPlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) + val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan) val reader = withSink.collect { case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala new file mode 100644 index 0000000000000..943c731a70529 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter + +/** + * The logical plan for writing data in a continuous stream. + */ +case class WriteToContinuousDataSource( + writer: StreamWriter, query: LogicalPlan) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala new file mode 100644 index 0000000000000..ba88ae1af469a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import scala.util.control.NonFatal + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory} +import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.util.Utils + +/** + * The physical plan for writing data into a continuous processing [[StreamWriter]]. + */ +case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan) + extends SparkPlan with Logging { + override def children: Seq[SparkPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val writerFactory = writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) + } + + val rdd = query.execute() + + logInfo(s"Start processing data source writer: $writer. " + + s"The input RDD has ${rdd.getNumPartitions} partitions.") + // Let the epoch coordinator know how many partitions the write RDD has. + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) + .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) + + try { + // Force the RDD to run so continuous processing starts; no data is actually being collected + // to the driver, as ContinuousWriteRDD outputs nothing. + sparkContext.runJob( + rdd, + (context: TaskContext, iter: Iterator[InternalRow]) => + WriteToContinuousDataSourceExec.run(writerFactory, context, iter), + rdd.partitions.indices) + } catch { + case _: InterruptedException => + // Interruption is how continuous queries are ended, so accept and ignore the exception. + case cause: Throwable => + cause match { + // Do not wrap interruption exceptions that will be handled by streaming specially. + case _ if StreamExecution.isInterruptionException(cause) => throw cause + // Only wrap non fatal exceptions. + case NonFatal(e) => throw new SparkException("Writing job aborted.", e) + case _ => throw cause + } + } + + sparkContext.emptyRDD + } +} + +object WriteToContinuousDataSourceExec extends Logging { + def run( + writeTask: DataWriterFactory[InternalRow], + context: TaskContext, + iter: Iterator[InternalRow]): Unit = { + val epochCoordinator = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + SparkEnv.get) + var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + do { + var dataWriter: DataWriter[InternalRow] = null + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + try { + dataWriter = writeTask.createDataWriter( + context.partitionId(), context.attemptNumber(), currentEpoch) + while (iter.hasNext) { + dataWriter.write(iter.next()) + } + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + val msg = dataWriter.commit() + logInfo(s"Writer for partition ${context.partitionId()} committed.") + epochCoordinator.send( + CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + ) + currentEpoch += 1 + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt. Just finish the task. + } + })(catchBlock = { + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so runContinuous will stop executing at this point. + logError(s"Writer for partition ${context.partitionId()} is aborting.") + if (dataWriter != null) dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } while (!context.isInterrupted()) + } +} From 5fea17b3befc50aef59b799711d03b9552f21b19 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Wed, 25 Apr 2018 11:19:08 +0900 Subject: [PATCH 686/774] [SPARK-23821][SQL] Collection function: flatten ## What changes were proposed in this pull request? This PR adds a new collection function that transforms an array of arrays into a single array. The PR comprises: - An expression for flattening array structure - Flatten function - A wrapper for PySpark ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite ## Codegen examples ### Primitive type ``` val df = Seq( Seq(Seq(1, 2), Seq(4, 5)), Seq(null, Seq(1)) ).toDF("i") df.filter($"i".isNotNull || $"i".isNull).select(flatten($"i")).debugCodegen ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ boolean filter_value = true; /* 038 */ /* 039 */ if (!(!inputadapter_isNull)) { /* 040 */ filter_value = inputadapter_isNull; /* 041 */ } /* 042 */ if (!filter_value) continue; /* 043 */ /* 044 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 045 */ /* 046 */ boolean project_isNull = inputadapter_isNull; /* 047 */ ArrayData project_value = null; /* 048 */ /* 049 */ if (!inputadapter_isNull) { /* 050 */ for (int z = 0; !project_isNull && z < inputadapter_value.numElements(); z++) { /* 051 */ project_isNull |= inputadapter_value.isNullAt(z); /* 052 */ } /* 053 */ if (!project_isNull) { /* 054 */ long project_numElements = 0; /* 055 */ for (int z = 0; z < inputadapter_value.numElements(); z++) { /* 056 */ project_numElements += inputadapter_value.getArray(z).numElements(); /* 057 */ } /* 058 */ if (project_numElements > 2147483632) { /* 059 */ throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + /* 060 */ project_numElements + " elements due to exceeding the array size limit 2147483632."); /* 061 */ } /* 062 */ /* 063 */ long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( /* 064 */ project_numElements, /* 065 */ 4); /* 066 */ if (project_size > 2147483632) { /* 067 */ throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + /* 068 */ project_size + " bytes of data due to exceeding the limit 2147483632" + /* 069 */ " bytes for UnsafeArrayData."); /* 070 */ } /* 071 */ /* 072 */ byte[] project_array = new byte[(int)project_size]; /* 073 */ UnsafeArrayData project_tempArrayData = new UnsafeArrayData(); /* 074 */ Platform.putLong(project_array, 16, project_numElements); /* 075 */ project_tempArrayData.pointTo(project_array, 16, (int)project_size); /* 076 */ int project_counter = 0; /* 077 */ for (int k = 0; k < inputadapter_value.numElements(); k++) { /* 078 */ ArrayData arr = inputadapter_value.getArray(k); /* 079 */ for (int l = 0; l < arr.numElements(); l++) { /* 080 */ if (arr.isNullAt(l)) { /* 081 */ project_tempArrayData.setNullAt(project_counter); /* 082 */ } else { /* 083 */ project_tempArrayData.setInt( /* 084 */ project_counter, /* 085 */ arr.getInt(l) /* 086 */ ); /* 087 */ } /* 088 */ project_counter++; /* 089 */ } /* 090 */ } /* 091 */ project_value = project_tempArrayData; /* 092 */ /* 093 */ } /* 094 */ /* 095 */ } ``` ### Non-primitive type ``` val df = Seq( Seq(Seq("a", "b"), Seq(null, "d")), Seq(null, Seq("a")) ).toDF("s") df.filter($"s".isNotNull || $"s".isNull).select(flatten($"s")).debugCodegen ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ boolean filter_value = true; /* 038 */ /* 039 */ if (!(!inputadapter_isNull)) { /* 040 */ filter_value = inputadapter_isNull; /* 041 */ } /* 042 */ if (!filter_value) continue; /* 043 */ /* 044 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 045 */ /* 046 */ boolean project_isNull = inputadapter_isNull; /* 047 */ ArrayData project_value = null; /* 048 */ /* 049 */ if (!inputadapter_isNull) { /* 050 */ for (int z = 0; !project_isNull && z < inputadapter_value.numElements(); z++) { /* 051 */ project_isNull |= inputadapter_value.isNullAt(z); /* 052 */ } /* 053 */ if (!project_isNull) { /* 054 */ long project_numElements = 0; /* 055 */ for (int z = 0; z < inputadapter_value.numElements(); z++) { /* 056 */ project_numElements += inputadapter_value.getArray(z).numElements(); /* 057 */ } /* 058 */ if (project_numElements > 2147483632) { /* 059 */ throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + /* 060 */ project_numElements + " elements due to exceeding the array size limit 2147483632."); /* 061 */ } /* 062 */ /* 063 */ Object[] project_arrayObject = new Object[(int)project_numElements]; /* 064 */ int project_counter = 0; /* 065 */ for (int k = 0; k < inputadapter_value.numElements(); k++) { /* 066 */ ArrayData arr = inputadapter_value.getArray(k); /* 067 */ for (int l = 0; l < arr.numElements(); l++) { /* 068 */ project_arrayObject[project_counter] = arr.getUTF8String(l); /* 069 */ project_counter++; /* 070 */ } /* 071 */ } /* 072 */ project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObject); /* 073 */ /* 074 */ } /* 075 */ /* 076 */ } ``` Author: mn-mikke Closes #20938 from mn-mikke/feature/array-api-flatten-to-master. --- python/pyspark/sql/functions.py | 17 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 176 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 95 ++++++++++ .../org/apache/spark/sql/functions.scala | 8 + .../spark/sql/DataFrameFunctionsSuite.scala | 79 ++++++++ 6 files changed, 376 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index da32ab25cad0c..de53b48b6f3b4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2191,6 +2191,23 @@ def reverse(col): return Column(sc._jvm.functions.reverse(_to_java_column(col))) +@since(2.4) +def flatten(col): + """ + Collection function: creates a single array from an array of arrays. + If a structure of nested arrays is deeper than two levels, + only one level of nesting is removed. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) + >>> df.select(flatten(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.flatten(_to_java_column(col))) + + @since(2.3) def map_keys(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c41f16c61d7a2..6afcf309bd690 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -413,6 +413,7 @@ object FunctionRegistry { expression[ArrayMax]("array_max"), expression[Reverse]("reverse"), expression[Concat]("concat"), + expression[Flatten]("flatten"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c16793bda028e..bc71b5f34ce4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -883,3 +883,179 @@ case class Concat(children: Seq[Expression]) extends Expression { override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" } + +/** + * Transforms an array of arrays into a single array. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.", + examples = """ + Examples: + > SELECT _FUNC_(array(array(1, 2), array(3, 4)); + [1,2,3,4] + """, + since = "2.4.0") +case class Flatten(child: Expression) extends UnaryExpression { + + private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] + + override def nullable: Boolean = child.nullable || childDataType.containsNull + + override def dataType: DataType = childDataType.elementType + + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(_: ArrayType, _) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + s"The argument should be an array of arrays, " + + s"but '${child.sql}' is of ${child.dataType.simpleString} type." + ) + } + + override def nullSafeEval(child: Any): Any = { + val elements = child.asInstanceOf[ArrayData].toObjectArray(dataType) + + if (elements.contains(null)) { + null + } else { + val arrayData = elements.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) + if (numberOfElements > MAX_ARRAY_LENGTH) { + throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + val flattenedData = new Array(numberOfElements.toInt) + var position = 0 + for (ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, flattenedData, position, arr.length) + position += arr.length + } + new GenericArrayData(flattenedData) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val code = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value) + } else { + genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) + } + if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code + }) + } + + private def nullElementsProtection( + ev: ExprCode, + childVariableName: String, + coreLogic: String): String = { + s""" + |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { + | ${ev.isNull} |= $childVariableName.isNullAt(z); + |} + |if (!${ev.isNull}) { + | $coreLogic + |} + """.stripMargin + } + + private def genCodeForNumberOfElements( + ctx: CodegenContext, + childVariableName: String) : (String, String) = { + val variableName = ctx.freshName("numElements") + val code = s""" + |long $variableName = 0; + |for (int z = 0; z < $childVariableName.numElements(); z++) { + | $variableName += $childVariableName.getArray(z).numElements(); + |} + |if ($variableName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + (code, variableName) + } + + private def genCodeForFlattenOfPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val arrayName = ctx.freshName("array") + val arraySizeName = ctx.freshName("size") + val counter = ctx.freshName("counter") + val tempArrayDataName = ctx.freshName("tempArrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + val unsafeArraySizeInBytes = s""" + |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElemName, + | ${elementType.defaultSize}); + |if ($arraySizeName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + | $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" + + | " bytes for UnsafeArrayData."); + |} + """.stripMargin + val baseOffset = Platform.BYTE_ARRAY_OFFSET + + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |$numElemCode + |$unsafeArraySizeInBytes + |byte[] $arrayName = new byte[(int)$arraySizeName]; + |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); + |Platform.putLong($arrayName, $baseOffset, $numElemName); + |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + |int $counter = 0; + |for (int k = 0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for (int l = 0; l < arr.numElements(); l++) { + | if (arr.isNullAt(l)) { + | $tempArrayDataName.setNullAt($counter); + | } else { + | $tempArrayDataName.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue("arr", elementType, "l")} + | ); + | } + | $counter++; + | } + |} + |$arrayDataName = $tempArrayDataName; + """.stripMargin + } + + private def genCodeForFlattenOfNonPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val counter = ctx.freshName("counter") + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + s""" + |$numElemCode + |Object[] $arrayName = new Object[(int)$numElemName]; + |int $counter = 0; + |for (int k = 0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for (int l = 0; l < arr.numElements(); l++) { + | $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")}; + | $counter++; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + + override def prettyName: String = "flatten" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 43c5dda2e4a48..b49fa76b2a781 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -280,4 +280,99 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) } + + test("Flatten") { + // Primitive-type test cases + val intArrayType = ArrayType(ArrayType(IntegerType)) + + // Main test cases (primitive type) + val aim1 = Literal.create(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), intArrayType) + val aim2 = Literal.create(Seq(Seq(1, 2, 3)), intArrayType) + + checkEvaluation(Flatten(aim1), Seq(1, 2, 3, 4, 5, 6)) + checkEvaluation(Flatten(aim2), Seq(1, 2, 3)) + + // Test cases with an empty array (primitive type) + val aie1 = Literal.create(Seq(Seq.empty, Seq(1, 2), Seq(3, 4)), intArrayType) + val aie2 = Literal.create(Seq(Seq(1, 2), Seq.empty, Seq(3, 4)), intArrayType) + val aie3 = Literal.create(Seq(Seq(1, 2), Seq(3, 4), Seq.empty), intArrayType) + val aie4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), intArrayType) + val aie5 = Literal.create(Seq(Seq.empty), intArrayType) + val aie6 = Literal.create(Seq.empty, intArrayType) + + checkEvaluation(Flatten(aie1), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie2), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie3), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie4), Seq.empty) + checkEvaluation(Flatten(aie5), Seq.empty) + checkEvaluation(Flatten(aie6), Seq.empty) + + // Test cases with null elements (primitive type) + val ain1 = Literal.create(Seq(Seq(null, null, null), Seq(4, null)), intArrayType) + val ain2 = Literal.create(Seq(Seq(null, 2, null), Seq(null, null)), intArrayType) + val ain3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), intArrayType) + + checkEvaluation(Flatten(ain1), Seq(null, null, null, 4, null)) + checkEvaluation(Flatten(ain2), Seq(null, 2, null, null, null)) + checkEvaluation(Flatten(ain3), Seq(null, null, null, null)) + + // Test cases with a null array (primitive type) + val aia1 = Literal.create(Seq(null, Seq(1, 2)), intArrayType) + val aia2 = Literal.create(Seq(Seq(1, 2), null), intArrayType) + val aia3 = Literal.create(Seq(null), intArrayType) + val aia4 = Literal.create(null, intArrayType) + + checkEvaluation(Flatten(aia1), null) + checkEvaluation(Flatten(aia2), null) + checkEvaluation(Flatten(aia3), null) + checkEvaluation(Flatten(aia4), null) + + // Non-primitive-type test cases + val strArrayType = ArrayType(ArrayType(StringType)) + val arrArrayType = ArrayType(ArrayType(ArrayType(StringType))) + + // Main test cases (non-primitive type) + val asm1 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strArrayType) + val asm2 = Literal.create(Seq(Seq("a", "b")), strArrayType) + val asm3 = Literal.create(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e"))), arrArrayType) + + checkEvaluation(Flatten(asm1), Seq("a", "b", "c", "d", "e", "f")) + checkEvaluation(Flatten(asm2), Seq("a", "b")) + checkEvaluation(Flatten(asm3), Seq(Seq("a", "b"), Seq("c"), Seq("d", "e"))) + + // Test cases with an empty array (non-primitive type) + val ase1 = Literal.create(Seq(Seq.empty, Seq("a", "b"), Seq("c", "d")), strArrayType) + val ase2 = Literal.create(Seq(Seq("a", "b"), Seq.empty, Seq("c", "d")), strArrayType) + val ase3 = Literal.create(Seq(Seq("a", "b"), Seq("c", "d"), Seq.empty), strArrayType) + val ase4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), strArrayType) + val ase5 = Literal.create(Seq(Seq.empty), strArrayType) + val ase6 = Literal.create(Seq.empty, strArrayType) + + checkEvaluation(Flatten(ase1), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase2), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase3), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase4), Seq.empty) + checkEvaluation(Flatten(ase5), Seq.empty) + checkEvaluation(Flatten(ase6), Seq.empty) + + // Test cases with null elements (non-primitive type) + val asn1 = Literal.create(Seq(Seq(null, null, "c"), Seq(null, null)), strArrayType) + val asn2 = Literal.create(Seq(Seq(null, null, null), Seq("d", null)), strArrayType) + val asn3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), strArrayType) + + checkEvaluation(Flatten(asn1), Seq(null, null, "c", null, null)) + checkEvaluation(Flatten(asn2), Seq(null, null, null, "d", null)) + checkEvaluation(Flatten(asn3), Seq(null, null, null, null)) + + // Test cases with a null array (non-primitive type) + val asa1 = Literal.create(Seq(null, Seq("a", "b")), strArrayType) + val asa2 = Literal.create(Seq(Seq("a", "b"), null), strArrayType) + val asa3 = Literal.create(Seq(null), strArrayType) + val asa4 = Literal.create(null, strArrayType) + + checkEvaluation(Flatten(asa1), null) + checkEvaluation(Flatten(asa2), null) + checkEvaluation(Flatten(asa3), null) + checkEvaluation(Flatten(asa4), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bea8c0e445002..d2f057310f89b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3340,6 +3340,14 @@ object functions { */ def reverse(e: Column): Column = withExpr { Reverse(e.expr) } + /** + * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than + * two levels, only one level of nesting is removed. + * @group collection_funcs + * @since 2.4.0 + */ + def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 25e5cd60dd236..03605c30036a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -691,6 +691,85 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("flatten function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on + val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + // Test cases with a primitive type + val intDF = Seq( + (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), + (Seq(Seq(1, 2))), + (Seq(Seq(1), Seq.empty)), + (Seq(Seq.empty, Seq(1))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq(1), null)), + (Seq(null, Seq(1))), + (Seq(null, null)) + ).toDF("i") + + val intDFResult = Seq( + Row(Seq(1, 2, 3, 4, 5, 6)), + Row(Seq(1, 2)), + Row(Seq(1)), + Row(Seq(1)), + Row(Seq.empty), + Row(null), + Row(null), + Row(null)) + + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.filter(dummyFilter($"i"))select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + checkAnswer( + oneRowDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), + Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + + // Test cases with non-primitive types + val strDF = Seq( + (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))), + (Seq(Seq("a", "b"))), + (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))), + (Seq(Seq("a"), Seq.empty)), + (Seq(Seq.empty, Seq("a"))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq("a"), null)), + (Seq(null, Seq("a"))), + (Seq(null, null)) + ).toDF("s") + + val strDFResult = Seq( + Row(Seq("a", "b", "c", "d", "e", "f")), + Row(Seq("a", "b")), + Row(Seq("a", null, null, "b", null, null)), + Row(Seq("a")), + Row(Seq("a")), + Row(Seq.empty), + Row(null), + Row(null), + Row(null)) + + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.filter(dummyFilter($"s")).select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + checkAnswer( + oneRowDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.select(flatten($"arr")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"i")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"s")) + } + intercept[AnalysisException] { + oneRowDF.selectExpr("flatten(null)") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 64e8408e6fa2d74929601b01a29771738f6d8c65 Mon Sep 17 00:00:00 2001 From: liutang123 Date: Wed, 25 Apr 2018 18:10:51 +0800 Subject: [PATCH 687/774] [SPARK-24012][SQL] Union of map and other compatible column ## What changes were proposed in this pull request? Union of map and other compatible column result in unresolved operator 'Union; exception Reproduction `spark-sql>select map(1,2), 'str' union all select map(1,2,3,null), 1` Output: ``` Error in query: unresolved operator 'Union;; 'Union :- Project [map(1, 2) AS map(1, 2)#106, str AS str#107] : +- OneRowRelation$ +- Project [map(1, cast(2 as int), 3, cast(null as int)) AS map(1, CAST(2 AS INT), 3, CAST(NULL AS INT))#109, 1 AS 1#108] +- OneRowRelation$ ``` So, we should cast part of columns to be compatible when appropriate. ## How was this patch tested? Added a test (query union of map and other columns) to SQLQueryTestSuite's union.sql. Author: liutang123 Closes #21100 from liutang123/SPARK-24012. --- .../sql/catalyst/analysis/TypeCoercion.scala | 8 ++++ .../test/resources/sql-tests/inputs/union.sql | 11 +++++ .../resources/sql-tests/results/union.sql.out | 42 ++++++++++++++----- 3 files changed, 51 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index cfcbd8db559a3..25bad28a2a209 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -112,6 +112,14 @@ object TypeCoercion { StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) })) + case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => + findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) + + case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => + val keyType = findTightestCommonType(kt1, kt2) + val valueType = findTightestCommonType(vt1, vt2) + Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) + case _ => None } diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql index e57d69eaad033..6da1b9b49b226 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/union.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql @@ -35,6 +35,17 @@ FROM (SELECT col AS col SELECT col FROM p3) T1) T2; +-- SPARK-24012 Union of map and other compatible columns. +SELECT map(1, 2), 'str' +UNION ALL +SELECT map(1, 2, 3, NULL), 1; + +-- SPARK-24012 Union of array and other compatible columns. +SELECT array(1, 2), 'str' +UNION ALL +SELECT array(1, 2, 3, NULL), 1; + + -- Clean-up DROP VIEW IF EXISTS t1; DROP VIEW IF EXISTS t2; diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out index d123b7fdbe0cf..b023df825d814 100644 --- a/sql/core/src/test/resources/sql-tests/results/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 16 -- !query 0 @@ -105,23 +105,29 @@ struct -- !query 9 -DROP VIEW IF EXISTS t1 +SELECT map(1, 2), 'str' +UNION ALL +SELECT map(1, 2, 3, NULL), 1 -- !query 9 schema -struct<> +struct,str:string> -- !query 9 output - +{1:2,3:null} 1 +{1:2} str -- !query 10 -DROP VIEW IF EXISTS t2 +SELECT array(1, 2), 'str' +UNION ALL +SELECT array(1, 2, 3, NULL), 1 -- !query 10 schema -struct<> +struct,str:string> -- !query 10 output - +[1,2,3,null] 1 +[1,2] str -- !query 11 -DROP VIEW IF EXISTS p1 +DROP VIEW IF EXISTS t1 -- !query 11 schema struct<> -- !query 11 output @@ -129,7 +135,7 @@ struct<> -- !query 12 -DROP VIEW IF EXISTS p2 +DROP VIEW IF EXISTS t2 -- !query 12 schema struct<> -- !query 12 output @@ -137,8 +143,24 @@ struct<> -- !query 13 -DROP VIEW IF EXISTS p3 +DROP VIEW IF EXISTS p1 -- !query 13 schema struct<> -- !query 13 output + + +-- !query 14 +DROP VIEW IF EXISTS p2 +-- !query 14 schema +struct<> +-- !query 14 output + + + +-- !query 15 +DROP VIEW IF EXISTS p3 +-- !query 15 schema +struct<> +-- !query 15 output + From 20ca208bcda6f22fe7d9fb54144de435b4237536 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 25 Apr 2018 19:06:18 +0800 Subject: [PATCH 688/774] [SPARK-23880][SQL] Do not trigger any jobs for caching data ## What changes were proposed in this pull request? This pr fixed code so that `cache` could prevent any jobs from being triggered. For example, in the current master, an operation below triggers a actual job; ``` val df = spark.range(10000000000L) .filter('id > 1000) .orderBy('id.desc) .cache() ``` This triggers a job while the cache should be lazy. The problem is that, when creating `InMemoryRelation`, we build the RDD, which calls `SparkPlan.execute` and may trigger jobs, like sampling job for range partitioner, or broadcast job. This pr removed the code to build a cached `RDD` in the constructor of `InMemoryRelation` and added `CachedRDDBuilder` to lazily build the `RDD` in `InMemoryRelation`. Then, the first call of `CachedRDDBuilder.cachedColumnBuffers` triggers a job to materialize the cache in `InMemoryTableScanExec` . ## How was this patch tested? Added tests in `CachedTableSuite`. Author: Takeshi Yamamuro Closes #21018 from maropu/SPARK-23880. --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../spark/sql/execution/CacheManager.scala | 14 +- .../execution/columnar/InMemoryRelation.scala | 155 ++++++++++-------- .../columnar/InMemoryTableScanExec.scala | 10 +- .../apache/spark/sql/CachedTableSuite.scala | 36 +++- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../columnar/InMemoryColumnarQuerySuite.scala | 6 +- .../spark/sql/hive/CachedTableSuite.scala | 2 +- 8 files changed, 133 insertions(+), 94 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 917168162b236..cd4def71e6f3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2933,7 +2933,7 @@ class Dataset[T] private[sql]( */ def storageLevel: StorageLevel = { sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData => - cachedData.cachedRepresentation.storageLevel + cachedData.cachedRepresentation.cacheBuilder.storageLevel }.getOrElse(StorageLevel.NONE) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index a8794be7280c7..93bf91e56f1bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -71,7 +71,7 @@ class CacheManager extends Logging { /** Clears all cached tables. */ def clearCache(): Unit = writeLock { - cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) + cachedData.asScala.foreach(_.cachedRepresentation.cacheBuilder.clearCache()) cachedData.clear() } @@ -119,7 +119,7 @@ class CacheManager extends Logging { while (it.hasNext) { val cd = it.next() if (cd.plan.find(_.sameResult(plan)).isDefined) { - cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cd.cachedRepresentation.cacheBuilder.clearCache(blocking) it.remove() } } @@ -138,16 +138,14 @@ class CacheManager extends Logging { while (it.hasNext) { val cd = it.next() if (condition(cd.plan)) { - cd.cachedRepresentation.cachedColumnBuffers.unpersist() + cd.cachedRepresentation.cacheBuilder.clearCache() // Remove the cache entry before we create a new one, so that we can have a different // physical plan. it.remove() + val plan = spark.sessionState.executePlan(cd.plan).executedPlan val newCache = InMemoryRelation( - useCompression = cd.cachedRepresentation.useCompression, - batchSize = cd.cachedRepresentation.batchSize, - storageLevel = cd.cachedRepresentation.storageLevel, - child = spark.sessionState.executePlan(cd.plan).executedPlan, - tableName = cd.cachedRepresentation.tableName, + cacheBuilder = cd.cachedRepresentation + .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null), logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index a7ba9b86a176f..da35a4734e65a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -32,19 +32,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator -object InMemoryRelation { - def apply( - useCompression: Boolean, - batchSize: Int, - storageLevel: StorageLevel, - child: SparkPlan, - tableName: Option[String], - logicalPlan: LogicalPlan): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)( - statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) -} - - /** * CachedBatch is a cached batch of rows. * @@ -55,58 +42,41 @@ object InMemoryRelation { private[columnar] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) -case class InMemoryRelation( - output: Seq[Attribute], +case class CachedRDDBuilder( useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - @transient child: SparkPlan, + @transient cachedPlan: SparkPlan, tableName: Option[String])( - @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics, - override val outputOrdering: Seq[SortOrder]) - extends logical.LeafNode with MultiInstanceRelation { - - override protected def innerChildren: Seq[SparkPlan] = Seq(child) - - override def doCanonicalize(): logical.LogicalPlan = - copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)), - storageLevel = StorageLevel.NONE, - child = child.canonicalized, - tableName = None)( - _cachedColumnBuffers, - sizeInBytesStats, - statsOfPlanToCache, - outputOrdering) + @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null) { - override def producedAttributes: AttributeSet = outputSet - - @transient val partitionStatistics = new PartitionStatistics(output) + val sizeInBytesStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator - override def computeStats(): Statistics = { - if (sizeInBytesStats.value == 0L) { - // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. - // Note that we should drop the hint info here. We may cache a plan whose root node is a hint - // node. When we lookup the cache with a semantically same plan without hint info, the plan - // returned by cache lookup should not have hint info. If we lookup the cache with a - // semantically same plan with a different hint info, `CacheManager.useCachedData` will take - // care of it and retain the hint info in the lookup input plan. - statsOfPlanToCache.copy(hints = HintInfo()) - } else { - Statistics(sizeInBytes = sizeInBytesStats.value.longValue) + def cachedColumnBuffers: RDD[CachedBatch] = { + if (_cachedColumnBuffers == null) { + synchronized { + if (_cachedColumnBuffers == null) { + _cachedColumnBuffers = buildBuffers() + } + } } + _cachedColumnBuffers } - // If the cached column buffers were not passed in, we calculate them in the constructor. - // As in Spark, the actual work of caching is lazy. - if (_cachedColumnBuffers == null) { - buildBuffers() + def clearCache(blocking: Boolean = true): Unit = { + if (_cachedColumnBuffers != null) { + synchronized { + if (_cachedColumnBuffers != null) { + _cachedColumnBuffers.unpersist(blocking) + _cachedColumnBuffers = null + } + } + } } - private def buildBuffers(): Unit = { - val output = child.output - val cached = child.execute().mapPartitionsInternal { rowIterator => + private def buildBuffers(): RDD[CachedBatch] = { + val output = cachedPlan.output + val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator => new Iterator[CachedBatch] { def next(): CachedBatch = { val columnBuilders = output.map { attribute => @@ -154,32 +124,77 @@ case class InMemoryRelation( cached.setName( tableName.map(n => s"In-memory table $n") - .getOrElse(StringUtils.abbreviate(child.toString, 1024))) - _cachedColumnBuffers = cached + .getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024))) + cached + } +} + +object InMemoryRelation { + + def apply( + useCompression: Boolean, + batchSize: Int, + storageLevel: StorageLevel, + child: SparkPlan, + tableName: Option[String], + logicalPlan: LogicalPlan): InMemoryRelation = { + val cacheBuilder = CachedRDDBuilder(useCompression, batchSize, storageLevel, child, tableName)() + new InMemoryRelation(child.output, cacheBuilder)( + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) + } + + def apply(cacheBuilder: CachedRDDBuilder, logicalPlan: LogicalPlan): InMemoryRelation = { + new InMemoryRelation(cacheBuilder.cachedPlan.output, cacheBuilder)( + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) + } +} + +case class InMemoryRelation( + output: Seq[Attribute], + @transient cacheBuilder: CachedRDDBuilder)( + statsOfPlanToCache: Statistics, + override val outputOrdering: Seq[SortOrder]) + extends logical.LeafNode with MultiInstanceRelation { + + override protected def innerChildren: Seq[SparkPlan] = Seq(cachedPlan) + + override def doCanonicalize(): logical.LogicalPlan = + copy(output = output.map(QueryPlan.normalizeExprId(_, cachedPlan.output)), + cacheBuilder)( + statsOfPlanToCache, + outputOrdering) + + override def producedAttributes: AttributeSet = outputSet + + @transient val partitionStatistics = new PartitionStatistics(output) + + def cachedPlan: SparkPlan = cacheBuilder.cachedPlan + + override def computeStats(): Statistics = { + if (cacheBuilder.sizeInBytesStats.value == 0L) { + // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. + // Note that we should drop the hint info here. We may cache a plan whose root node is a hint + // node. When we lookup the cache with a semantically same plan without hint info, the plan + // returned by cache lookup should not have hint info. If we lookup the cache with a + // semantically same plan with a different hint info, `CacheManager.useCachedData` will take + // care of it and retain the hint info in the lookup input plan. + statsOfPlanToCache.copy(hints = HintInfo()) + } else { + Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue) + } } def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { - InMemoryRelation( - newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering) + InMemoryRelation(newOutput, cacheBuilder)(statsOfPlanToCache, outputOrdering) } override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), - useCompression, - batchSize, - storageLevel, - child, - tableName)( - _cachedColumnBuffers, - sizeInBytesStats, + cacheBuilder)( statsOfPlanToCache, outputOrdering).asInstanceOf[this.type] } - def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers - - override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) + override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index e73e1378d52e3..ea315fb71617c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -154,7 +154,7 @@ case class InMemoryTableScanExec( private def updateAttribute(expr: Expression): Expression = { // attributes can be pruned so using relation's output. // E.g., relation.output is [id, item] but this scan's output can be [item] only. - val attrMap = AttributeMap(relation.child.output.zip(relation.output)) + val attrMap = AttributeMap(relation.cachedPlan.output.zip(relation.output)) expr.transform { case attr: Attribute => attrMap.getOrElse(attr, attr) } @@ -163,16 +163,16 @@ case class InMemoryTableScanExec( // The cached version does not change the outputPartitioning of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputPartitioning: Partitioning = { - relation.child.outputPartitioning match { + relation.cachedPlan.outputPartitioning match { case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] - case _ => relation.child.outputPartitioning + case _ => relation.cachedPlan.outputPartitioning } } // The cached version does not change the outputOrdering of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputOrdering: Seq[SortOrder] = - relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) + relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) // Keeps relation's partition statistics because we don't serialize relation. private val stats = relation.partitionStatistics @@ -252,7 +252,7 @@ case class InMemoryTableScanExec( // within the map Partitions closure. val schema = stats.schema val schemaIndex = schema.zipWithIndex - val buffers = relation.cachedColumnBuffers + val buffers = relation.cacheBuilder.cachedColumnBuffers buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => val partitionFilter = newPredicate( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 669e5f2bf4e65..81b7e18773f81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,6 +22,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.spark.CleanerListener +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} @@ -52,7 +53,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val plan = spark.table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => - relation.cachedColumnBuffers.id + relation.cacheBuilder.cachedColumnBuffers.id case _ => fail(s"Table $tableName is not cached\n" + plan) }.head @@ -78,7 +79,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = { plan.collect { case InMemoryTableScanExec(_, _, relation) => - getNumInMemoryTablesRecursively(relation.child) + 1 + getNumInMemoryTablesRecursively(relation.cachedPlan) + 1 }.sum } @@ -200,7 +201,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.catalog.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { spark.table("testData").queryExecution.withCachedData.collect { - case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r + case r: InMemoryRelation if r.cachedPlan.isInstanceOf[InMemoryTableScanExec] => r }.size } @@ -367,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId2 @@ -794,4 +795,29 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } } } + + private def checkIfNoJobTriggered[T](f: => T): T = { + var numJobTrigered = 0 + val jobListener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + numJobTrigered += 1 + } + } + sparkContext.addSparkListener(jobListener) + try { + val result = f + sparkContext.listenerBus.waitUntilEmpty(10000L) + assert(numJobTrigered === 0) + result + } finally { + sparkContext.removeSparkListener(jobListener) + } + } + + test("SPARK-23880 table cache should be lazy and don't trigger any jobs") { + val cachedData = checkIfNoJobTriggered { + spark.range(1002).filter('id > 1000).orderBy('id.desc).cache() + } + assert(cachedData.collect === Seq(1001)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 40915a102bab0..f0dfe6b76f7ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -194,7 +194,7 @@ class PlannerSuite extends SharedSQLContext { test("CollectLimit can appear in the middle of a plan when caching is used") { val query = testData.select('key, 'value).limit(2).cache() val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] - assert(planned.child.isInstanceOf[CollectLimitExec]) + assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) } test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 9b7b316211d30..863703b15f4f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -45,8 +45,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None, data.logicalPlan) - assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) - inMemoryRelation.cachedColumnBuffers.collect().head match { + assert(inMemoryRelation.cacheBuilder.cachedColumnBuffers.getStorageLevel == storageLevel) + inMemoryRelation.cacheBuilder.cachedColumnBuffers.collect().head match { case _: CachedBatch => case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}") } @@ -337,7 +337,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(cached, expectedAnswer) // Check that the right size was calculated. - assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) + assert(cached.cacheBuilder.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) } test("access primitive-type columns in CachedBatch without whole stage codegen") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 48ab4eb9a6178..569f00c053e5f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -38,7 +38,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto val plan = table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => - relation.cachedColumnBuffers.id + relation.cacheBuilder.cachedColumnBuffers.id case _ => fail(s"Table $tableName is not cached\n" + plan) }.head From 396938ef02c70468e1695872f96b1e9aff28b7ea Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 25 Apr 2018 12:21:55 -0700 Subject: [PATCH 689/774] [SPARK-24050][SS] Calculate input / processing rates correctly for DataSourceV2 streaming sources ## What changes were proposed in this pull request? In some streaming queries, the input and processing rates are not calculated at all (shows up as zero) because MicroBatchExecution fails to associated metrics from the executed plan of a trigger with the sources in the logical plan of the trigger. The way this executed-plan-leaf-to-logical-source attribution works is as follows. With V1 sources, there was no way to identify which execution plan leaves were generated by a streaming source. So did a best-effort attempt to match logical and execution plan leaves when the number of leaves were same. In cases where the number of leaves is different, we just give up and report zero rates. An example where this may happen is as follows. ``` val cachedStaticDF = someStaticDF.union(anotherStaticDF).cache() val streamingInputDF = ... val query = streamingInputDF.join(cachedStaticDF).writeStream.... ``` In this case, the `cachedStaticDF` has multiple logical leaves, but in the trigger's execution plan it only has leaf because a cached subplan is represented as a single InMemoryTableScanExec leaf. This leads to a mismatch in the number of leaves causing the input rates to be computed as zero. With DataSourceV2, all inputs are represented in the executed plan using `DataSourceV2ScanExec`, each of which has a reference to the associated logical `DataSource` and `DataSourceReader`. So its easy to associate the metrics to the original streaming sources. In this PR, the solution is as follows. If all the streaming sources in a streaming query as v2 sources, then use a new code path where the execution-metrics-to-source mapping is done directly. Otherwise we fall back to existing mapping logic. ## How was this patch tested? - New unit tests using V2 memory source - Existing unit tests using V1 source Author: Tathagata Das Closes #21126 from tdas/SPARK-24050. --- .../kafka010/KafkaMicroBatchSourceSuite.scala | 9 +- .../streaming/ProgressReporter.scala | 146 +++++++++++++----- .../sql/streaming/StreamingQuerySuite.scala | 134 +++++++++++++++- 3 files changed, 245 insertions(+), 44 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index e017fd9b84d21..d2d04b68de6ab 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -563,7 +563,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { ) } - test("ensure stream-stream self-join generates only one offset in offset log") { + test("ensure stream-stream self-join generates only one offset in log and correct metrics") { val topic = newTopic() testUtils.createTopic(topic, partitions = 2) require(testUtils.getLatestOffsets(Set(topic)).size === 2) @@ -587,7 +587,12 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { AddKafkaData(Set(topic), 1, 2), CheckAnswer((1, 1, 1), (2, 2, 2)), AddKafkaData(Set(topic), 6, 3), - CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)) + CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)), + AssertOnQuery { q => + assert(q.availableOffsets.iterator.size == 1) + assert(q.recentProgress.map(_.numInputRows).sum == 4) + true + } ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index d1e5be9c12762..16ad3ef9a3d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -141,7 +143,7 @@ trait ProgressReporter extends Logging { } logDebug(s"Execution stats: $executionStats") - val sourceProgress = sources.map { source => + val sourceProgress = sources.distinct.map { source => val numRecords = executionStats.inputRows.getOrElse(source, 0L) new SourceProgress( description = source.toString, @@ -207,62 +209,126 @@ trait ProgressReporter extends Logging { return ExecutionStats(Map.empty, stateOperators, watermarkTimestamp) } - // We want to associate execution plan leaves to sources that generate them, so that we match - // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. - // Consider the translation from the streaming logical plan to the final executed plan. - // - // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan - // - // 1. We keep track of streaming sources associated with each leaf in the trigger's logical plan - // - Each logical plan leaf will be associated with a single streaming source. - // - There can be multiple logical plan leaves associated with a streaming source. - // - There can be leaves not associated with any streaming source, because they were - // generated from a batch source (e.g. stream-batch joins) - // - // 2. Assuming that the executed plan has same number of leaves in the same order as that of - // the trigger logical plan, we associate executed plan leaves with corresponding - // streaming sources. - // - // 3. For each source, we sum the metrics of the associated execution plan leaves. - // - val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => - logicalPlan.collectLeaves().map { leaf => leaf -> source } + val numInputRows = extractSourceToNumInputRows() + + val eventTimeStats = lastExecution.executedPlan.collect { + case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => + val stats = e.eventTimeStats.value + Map( + "max" -> stats.max, + "min" -> stats.min, + "avg" -> stats.avg.toLong).mapValues(formatTimestamp) + }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp + + ExecutionStats(numInputRows, stateOperators, eventTimeStats) + } + + /** Extract number of input sources for each streaming source in plan */ + private def extractSourceToNumInputRows(): Map[BaseStreamingSource, Long] = { + + import java.util.IdentityHashMap + import scala.collection.JavaConverters._ + + def sumRows(tuples: Seq[(BaseStreamingSource, Long)]): Map[BaseStreamingSource, Long] = { + tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source } - val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming - val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() - val numInputRows: Map[BaseStreamingSource, Long] = + + val onlyDataSourceV2Sources = { + // Check whether the streaming query's logical plan has only V2 data sources + val allStreamingLeaves = + logicalPlan.collect { case s: StreamingExecutionRelation => s } + allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReader] } + } + + if (onlyDataSourceV2Sources) { + // DataSourceV2ScanExec is the execution plan leaf that is responsible for reading data + // from a V2 source and has a direct reference to the V2 source that generated it. Each + // DataSourceV2ScanExec records the number of rows it has read using SQLMetrics. However, + // just collecting all DataSourceV2ScanExec nodes and getting the metric is not correct as + // a DataSourceV2ScanExec instance may be referred to in the execution plan from two (or + // even multiple times) points and considering it twice will leads to double counting. We + // can't dedup them using their hashcode either because two different instances of + // DataSourceV2ScanExec can have the same hashcode but account for separate sets of + // records read, and deduping them to consider only one of them would be undercounting the + // records read. Therefore the right way to do this is to consider the unique instances of + // DataSourceV2ScanExec (using their identity hash codes) and get metrics from them. + // Hence we calculate in the following way. + // + // 1. Collect all the unique DataSourceV2ScanExec instances using IdentityHashMap. + // + // 2. Extract the source and the number of rows read from the DataSourceV2ScanExec instanes. + // + // 3. Multiple DataSourceV2ScanExec instance may refer to the same source (can happen with + // self-unions or self-joins). Add up the number of rows for each unique source. + val uniqueStreamingExecLeavesMap = + new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]() + + lastExecution.executedPlan.collectLeaves().foreach { + case s: DataSourceV2ScanExec if s.reader.isInstanceOf[BaseStreamingSource] => + uniqueStreamingExecLeavesMap.put(s, s) + case _ => + } + + val sourceToInputRowsTuples = + uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf => + val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) + val source = execLeaf.reader.asInstanceOf[BaseStreamingSource] + source -> numRows + }.toSeq + logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) + sumRows(sourceToInputRowsTuples) + } else { + + // Since V1 source do not generate execution plan leaves that directly link with source that + // generated it, we can only do a best-effort association between execution plan leaves to the + // sources. This is known to fail in a few cases, see SPARK-24050. + // + // We want to associate execution plan leaves to sources that generate them, so that we match + // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. + // Consider the translation from the streaming logical plan to the final executed plan. + // + // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan + // + // 1. We keep track of streaming sources associated with each leaf in trigger's logical plan + // - Each logical plan leaf will be associated with a single streaming source. + // - There can be multiple logical plan leaves associated with a streaming source. + // - There can be leaves not associated with any streaming source, because they were + // generated from a batch source (e.g. stream-batch joins) + // + // 2. Assuming that the executed plan has same number of leaves in the same order as that of + // the trigger logical plan, we associate executed plan leaves with corresponding + // streaming sources. + // + // 3. For each source, we sum the metrics of the associated execution plan leaves. + // + val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => + logicalPlan.collectLeaves().map { leaf => leaf -> source } + } + val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming + val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) { val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap { case (lp, ep) => logicalPlanLeafToSource.get(lp).map { source => ep -> source } } - val sourceToNumInputRows = execLeafToSource.map { case (execLeaf, source) => + val sourceToInputRowsTuples = execLeafToSource.map { case (execLeaf, source) => val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) source -> numRows } - sourceToNumInputRows.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source + sumRows(sourceToInputRowsTuples) } else { if (!metricWarningLogged) { def toString[T](seq: Seq[T]): String = s"(size = ${seq.size}), ${seq.mkString(", ")}" + logWarning( "Could not report metrics as number leaves in trigger logical plan did not match that" + - s" of the execution plan:\n" + - s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + - s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") + s" of the execution plan:\n" + + s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + + s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") metricWarningLogged = true } Map.empty } - - val eventTimeStats = lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => - val stats = e.eventTimeStats.value - Map( - "max" -> stats.max, - "min" -> stats.min, - "avg" -> stats.avg.toLong).mapValues(formatTimestamp) - }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp - - ExecutionStats(numInputRows, stateOperators, eventTimeStats) + } } /** Records the duration of running `body` for the next query progress update. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 20942ed93897c..390d67d1feb27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -466,7 +466,17 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } - test("input row calculation with mixed batch and streaming sources") { + test("input row calculation with same V1 source used twice in self-join") { + val streamingTriggerDF = spark.createDataset(1 to 10).toDF + val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") + + val progress = getFirstProgress(streamingInputDF.join(streamingInputDF, "value")) + assert(progress.numInputRows === 20) // data is read multiple times in self-joins + assert(progress.sources.size === 1) + assert(progress.sources(0).numInputRows === 20) + } + + test("input row calculation with mixed batch and streaming V1 sources") { val streamingTriggerDF = spark.createDataset(1 to 10).toDF val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") @@ -479,7 +489,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources(0).numInputRows === 10) } - test("input row calculation with trigger input DF having multiple leaves") { + test("input row calculation with trigger input DF having multiple leaves in V1 source") { val streamingTriggerDF = spark.createDataset(1 to 5).toDF.union(spark.createDataset(6 to 10).toDF) require(streamingTriggerDF.logicalPlan.collectLeaves().size > 1) @@ -492,6 +502,121 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources(0).numInputRows === 10) } + test("input row calculation with same V2 source used twice in self-union") { + val streamInput = MemoryStream[Int] + + testStream(streamInput.toDF().union(streamInput.toDF()), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 1, 2, 2, 3, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 6) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 6) + true + } + ) + } + + test("input row calculation with same V2 source used twice in self-join") { + val streamInput = MemoryStream[Int] + val df = streamInput.toDF() + testStream(df.join(df, "value"), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 6) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 6) + true + } + ) + } + + test("input row calculation with trigger having data for only one of two V2 sources") { + val streamInput1 = MemoryStream[Int] + val streamInput2 = MemoryStream[Int] + + testStream(streamInput1.toDF().union(streamInput2.toDF()), useV2Sink = true)( + AddData(streamInput1, 1, 2, 3), + CheckLastBatch(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 2) + assert(lastProgress.get.sources(0).numInputRows == 3) + assert(lastProgress.get.sources(1).numInputRows == 0) + true + }, + AddData(streamInput2, 4, 5), + CheckLastBatch(4, 5), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 2) + assert(lastProgress.get.sources.length == 2) + assert(lastProgress.get.sources(0).numInputRows == 0) + assert(lastProgress.get.sources(1).numInputRows == 2) + true + } + ) + } + + test("input row calculation with mixed batch and streaming V2 sources") { + + val streamInput = MemoryStream[Int] + val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") + + testStream(streamInput.toDF().join(staticInputDF, "value"), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + AssertOnQuery { q => + q.processAllAvailable() + + // The number of leaves in the trigger's logical plan should be same as the executed plan. + require( + q.lastExecution.logical.collectLeaves().length == + q.lastExecution.executedPlan.collectLeaves().length) + + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 3) + true + } + ) + + val streamInput2 = MemoryStream[Int] + val staticInputDF2 = staticInputDF.union(staticInputDF).cache() + + testStream(streamInput2.toDF().join(staticInputDF2, "value"), useV2Sink = true)( + AddData(streamInput2, 1, 2, 3), + AssertOnQuery { q => + q.processAllAvailable() + // The number of leaves in the trigger's logical plan should be different from + // the executed plan. The static input will have two leaves in the logical plan + // (due to the union), but will be converted to a single leaf in the executed plan + // (due to the caching, the cached subplan is replaced by a single InMemoryTableScanExec). + require( + q.lastExecution.logical.collectLeaves().length != + q.lastExecution.executedPlan.collectLeaves().length) + + // Despite the mismatch in total number of leaves in the logical and executed plans, + // we should be able to attribute streaming input metrics to the streaming sources. + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 3) + true + } + ) + } + testQuietly("StreamExecution metadata garbage collection") { val inputData = MemoryStream[Int] val mapped = inputData.toDS().map(6 / _) @@ -733,6 +858,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + /** Returns the last query progress from query.recentProgress where numInputRows is positive */ + def getLastProgressWithData(q: StreamingQuery): Option[StreamingQueryProgress] = { + q.recentProgress.filter(_.numInputRows > 0).lastOption + } + /** * A [[StreamAction]] to test the behavior of `StreamingQuery.awaitTermination()`. * From ac4ca7c4dd3ff666ec70aeb26ac84cffa557ee12 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 25 Apr 2018 13:42:44 -0700 Subject: [PATCH 690/774] [SPARK-24012][SQL][TEST][FOLLOWUP] add unit test ## What changes were proposed in this pull request? a followup of https://github.com/apache/spark/pull/21100 ## How was this patch tested? N/A Author: Wenchen Fan Closes #21154 from cloud-fan/test. --- .../catalyst/analysis/TypeCoercionSuite.scala | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index fd6a3121663ed..1cc431aaf0a60 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -429,6 +429,24 @@ class TypeCoercionSuite extends AnalysisTest { Some(StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType)))), isSymmetric = false) } + + widenTest( + ArrayType(IntegerType, containsNull = true), + ArrayType(IntegerType, containsNull = false), + Some(ArrayType(IntegerType, containsNull = true))) + + widenTest( + MapType(IntegerType, StringType, valueContainsNull = true), + MapType(IntegerType, StringType, valueContainsNull = false), + Some(MapType(IntegerType, StringType, valueContainsNull = true))) + + widenTest( + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = false), + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = false), nullable = true), + Some(new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true))) } test("wider common type for decimal and array") { From 95a651339ec39d5753e849e578ad715be0d7c83e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 26 Apr 2018 09:12:38 +0800 Subject: [PATCH 691/774] [SPARK-24069][R] Add array_min / array_max functions ## What changes were proposed in this pull request? This PR proposes to add array_max and array_min in R side too. array_max: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$gear, df$am, df$carb)) head(select(mutated, array_max(mutated$v1))) ``` ``` array_max(v1) 1 4 2 4 3 4 4 3 5 3 6 3 ``` array_min: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) head(select(mutated, array_min(mutated$v1))) ``` ``` array_min(v1) 1 6 2 6 3 4 4 6 5 8 6 6 ``` ## How was this patch tested? Unit tests were added in `R/pkg/tests/fulltests/test_sparkSQL.R` and manually tested. Documentation was manually built and verified. Author: hyukjinkwon Closes #21142 from HyukjinKwon/sparkr_array_min_array_max. --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/functions.R | 27 +++++++++++++++++++++++++++ R/pkg/R/generics.R | 8 ++++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 9 ++++++++- 4 files changed, 45 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 55dec177ea853..f36d462a83cb0 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,6 +201,8 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_max", + "array_min", "array_position", "asc", "ascii", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 7b3aa05074563..ec4bd4e73c7e5 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -206,6 +206,7 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) @@ -2992,6 +2993,32 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_max}: Returns the maximum value of the array. +#' +#' @rdname column_collection_functions +#' @aliases array_max array_max,Column-method +#' @note array_max since 2.4.0 +setMethod("array_max", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_max", x@jc) + column(jc) + }) + +#' @details +#' \code{array_min}: Returns the minimum value of the array. +#' +#' @rdname column_collection_functions +#' @aliases array_min array_min,Column-method +#' @note array_min since 2.4.0 +setMethod("array_min", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_min", x@jc) + column(jc) + }) + #' @details #' \code{array_position}: Locates the position of the first occurrence of the given value #' in the given array. Returns NA if either of the arguments are NA. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index f30ac9e4295e4..562d3399ee9c8 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,14 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_max", function(x) { standardGeneric("array_max") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_min", function(x) { standardGeneric("array_min") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index a384997830276..8cc2db7a140f9 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,11 +1479,18 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains(), array_position(), element_at() and sort_array() + # Test array_contains(), array_max(), array_min(), array_position(), element_at() + # and sort_array() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) + result <- collect(select(df, array_max(df[[1]])))[[1]] + expect_equal(result, c(3, 6)) + + result <- collect(select(df, array_min(df[[1]])))[[1]] + expect_equal(result, c(1, 4)) + result <- collect(select(df, array_position(df[[1]], 1L)))[[1]] expect_equal(result, c(1, 0)) From 3f1e999d3d215bb3b867bcd83ec5c799448ec730 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 26 Apr 2018 09:14:24 +0800 Subject: [PATCH 692/774] [SPARK-23849][SQL] Tests for samplingRatio of json datasource ## What changes were proposed in this pull request? Added the `samplingRatio` option to the `json()` method of PySpark DataFrame Reader. Improving existing tests for Scala API according to review of the PR: https://github.com/apache/spark/pull/20959 ## How was this patch tested? Added new test for PySpark, updated 2 existing tests according to reviews of https://github.com/apache/spark/pull/20959 and added new negative test Author: Maxim Gekk Closes #21056 from MaxGekk/json-sampling. --- python/pyspark/sql/readwriter.py | 7 ++- python/pyspark/sql/tests.py | 8 +++ .../apache/spark/sql/DataFrameReader.scala | 2 + .../datasources/json/JsonSuite.scala | 63 ++++++++++--------- .../datasources/json/TestJsonData.scala | 12 ++++ 5 files changed, 61 insertions(+), 31 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6bd79bc2f43e5..df176c579fc8b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -239,6 +239,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, including tab and line feed characters) or not. :param lineSep: defines the line separator that should be used for parsing. If None is set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + :param samplingRatio: defines fraction of input JSON objects used for schema inferring. + If None is set, it uses the default value, ``1.0``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -256,7 +258,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, + samplingRatio=samplingRatio) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4e99c8e3c6b10..98fa1b54b0a17 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3018,6 +3018,14 @@ def test_sort_with_nulls_order(self): df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(), [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]) + def test_json_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x)) + schema = self.spark.read.option('inferSchema', True) \ + .option('samplingRatio', 0.5) \ + .json(rdd).schema + self.assertEquals(schema, StructType([StructField("a", LongType(), True)])) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index d640fdc530ce2..b44552f0eb17b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -374,6 +374,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * per file *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.
  • + *
  • `samplingRatio` (default is 1.0): defines fraction of input JSON objects used + * for schema inferring.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 70aee561ff0f6..a58dff827b92d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2128,38 +2128,43 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } - test("SPARK-23849: schema inferring touches less data if samplingRation < 1.0") { - val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, - 57, 62, 68, 72) - withTempPath { path => - val writer = Files.newBufferedWriter(Paths.get(path.getAbsolutePath), - StandardCharsets.UTF_8, StandardOpenOption.CREATE_NEW) - for (i <- 0 until 100) { - if (predefinedSample.contains(i)) { - writer.write(s"""{"f1":${i.toString}}""" + "\n") - } else { - writer.write(s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n") - } - } - writer.close() + test("SPARK-23849: schema inferring touches less data if samplingRatio < 1.0") { + // Set default values for the DataSource parameters to make sure + // that whole test file is mapped to only one partition. This will guarantee + // reliable sampling of the input file. + withSQLConf( + "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, + "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + )(withTempPath { path => + val ds = sampledTestData.coalesce(1) + ds.write.text(path.getAbsolutePath) + val readback = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) + + assert(readback.schema == new StructType().add("f1", LongType)) + }) + } - val ds = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) - assert(ds.schema == new StructType().add("f1", LongType)) - } + test("SPARK-23849: usage of samplingRatio while parsing a dataset of strings") { + val ds = sampledTestData.coalesce(1) + val readback = spark.read.option("samplingRatio", 0.1).json(ds) + + assert(readback.schema == new StructType().add("f1", LongType)) } - test("SPARK-23849: usage of samplingRation while parsing of dataset of strings") { - val dstr = spark.sparkContext.parallelize(0 until 100, 1).map { i => - val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, - 57, 62, 68, 72) - if (predefinedSample.contains(i)) { - s"""{"f1":${i.toString}}""" + "\n" - } else { - s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n" - } - }.toDS() - val ds = spark.read.option("samplingRatio", 0.1).json(dstr) + test("SPARK-23849: samplingRatio is out of the range (0, 1.0]") { + val ds = spark.range(0, 100, 1, 1).map(_.toString) + + val errorMsg0 = intercept[IllegalArgumentException] { + spark.read.option("samplingRatio", -1).json(ds) + }.getMessage + assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0")) + + val errorMsg1 = intercept[IllegalArgumentException] { + spark.read.option("samplingRatio", 0).json(ds) + }.getMessage + assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0")) - assert(ds.schema == new StructType().add("f1", LongType)) + val sampled = spark.read.option("samplingRatio", 1.0).json(ds) + assert(sampled.count() == ds.count()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 13084ba4a7f04..6e9559edf8ec2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -233,4 +233,16 @@ private[json] trait TestJsonData { spark.createDataset(spark.sparkContext.parallelize("""{"a":123}""" :: Nil))(Encoders.STRING) def empty: Dataset[String] = spark.emptyDataset(Encoders.STRING) + + def sampledTestData: Dataset[String] = { + spark.range(0, 100, 1).map { index => + val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(index)) { + s"""{"f1":${index.toString}}""" + } else { + s"""{"f1":${(index.toDouble + 0.1).toString}}""" + } + }(Encoders.STRING) + } } From 58c55cb4a6d72d72df908e37aa63f617b3cc5587 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 26 Apr 2018 12:19:20 +0900 Subject: [PATCH 693/774] [SPARK-23902][SQL] Add roundOff flag to months_between ## What changes were proposed in this pull request? HIVE-15511 introduced the `roundOff` flag in order to disable the rounding to 8 digits which is performed in `months_between`. Since this can be a computational intensive operation, skipping it may improve performances when the rounding is not needed. ## How was this patch tested? modified existing UT Author: Marco Gaido Closes #21008 from mgaido91/SPARK-23902. --- python/pyspark/sql/functions.py | 10 +++- .../expressions/datetimeExpressions.scala | 33 +++++++---- .../sql/catalyst/util/DateTimeUtils.scala | 32 ++++------ .../expressions/DateExpressionsSuite.scala | 59 +++++++++++-------- .../catalyst/util/DateTimeUtilsSuite.scala | 30 +++++++--- .../org/apache/spark/sql/functions.scala | 13 +++- .../apache/spark/sql/DateFunctionsSuite.scala | 7 +++ 7 files changed, 118 insertions(+), 66 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index de53b48b6f3b4..38ae41a5dafe6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1088,16 +1088,20 @@ def add_months(start, months): @since(1.5) -def months_between(date1, date2): +def months_between(date1, date2, roundOff=True): """ Returns the number of months between date1 and date2. + Unless `roundOff` is set to `False`, the result is rounded off to 8 digits. >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() - [Row(months=3.9495967...)] + [Row(months=3.94959677)] + >>> df.select(months_between(df.date1, df.date2, False).alias('months')).collect() + [Row(months=3.9495967741935485)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2))) + return Column(sc._jvm.functions.months_between( + _to_java_column(date1), _to_java_column(date2), roundOff)) @since(2.2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index b9b2cd5bdb9f0..d882d06cfd625 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1156,38 +1156,49 @@ case class AddMonths(startDate: Expression, numMonths: Expression) */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(timestamp1, timestamp2) - Returns number of months between `timestamp1` and `timestamp2`.", + usage = """ + _FUNC_(timestamp1, timestamp2[, roundOff]) - Returns number of months between `timestamp1` and `timestamp2`. + The result is rounded to 8 decimal places by default. Set roundOff=false otherwise."""", examples = """ Examples: > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30'); 3.94959677 + > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30', false); + 3.9495967741935485 """, since = "1.5.0") // scalastyle:on line.size.limit -case class MonthsBetween(date1: Expression, date2: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { +case class MonthsBetween( + date1: Expression, + date2: Expression, + roundOff: Expression, + timeZoneId: Option[String] = None) + extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None) - def this(date1: Expression, date2: Expression) = this(date1, date2, None) + def this(date1: Expression, date2: Expression, roundOff: Expression) = + this(date1, date2, roundOff, None) - override def left: Expression = date1 - override def right: Expression = date2 + override def children: Seq[Expression] = Seq(date1, date2, roundOff) - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType) + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType, BooleanType) override def dataType: DataType = DoubleType override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - override def nullSafeEval(t1: Any, t2: Any): Any = { - DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long], timeZone) + override def nullSafeEval(t1: Any, t2: Any, roundOff: Any): Any = { + DateTimeUtils.monthsBetween( + t1.asInstanceOf[Long], t2.asInstanceOf[Long], roundOff.asInstanceOf[Boolean], timeZone) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (l, r) => { - s"""$dtu.monthsBetween($l, $r, $tz)""" + defineCodeGen(ctx, ev, (d1, d2, roundOff) => { + s"""$dtu.monthsBetween($d1, $d2, $roundOff, $tz)""" }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index fa69b8af62c85..4b00a61c6cf91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -870,24 +870,14 @@ object DateTimeUtils { * If time1 and time2 having the same day of month, or both are the last day of month, * it returns an integer (time under a day will be ignored). * - * Otherwise, the difference is calculated based on 31 days per month, and rounding to - * 8 digits. + * Otherwise, the difference is calculated based on 31 days per month. + * If `roundOff` is set to true, the result is rounded to 8 decimal places. */ - def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp): Double = { - monthsBetween(time1, time2, defaultTimeZone()) - } - - /** - * Returns number of months between time1 and time2. time1 and time2 are expressed in - * microseconds since 1.1.1970. - * - * If time1 and time2 having the same day of month, or both are the last day of month, - * it returns an integer (time under a day will be ignored). - * - * Otherwise, the difference is calculated based on 31 days per month, and rounding to - * 8 digits. - */ - def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp, timeZone: TimeZone): Double = { + def monthsBetween( + time1: SQLTimestamp, + time2: SQLTimestamp, + roundOff: Boolean, + timeZone: TimeZone): Double = { val millis1 = time1 / 1000L val millis2 = time2 / 1000L val date1 = millisToDays(millis1, timeZone) @@ -906,8 +896,12 @@ object DateTimeUtils { val timeInDay2 = millis2 - daysToMillis(date2, timeZone) val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 - // rounding to 8 digits - math.round(diff * 1e8) / 1e8 + if (roundOff) { + // rounding to 8 digits + math.round(diff * 1e8) / 1e8 + } else { + diff + } } // Thursday = 0 since 1970/Jan/01 => Thursday diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 080ec487cfa6a..63b24fb9eb13a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -464,34 +464,47 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { MonthsBetween( Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)), Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), - timeZoneId), - 3.94959677) + Literal.TrueLiteral, + timeZoneId = timeZoneId), 3.94959677) checkEvaluation( MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), - timeZoneId), - 0.0) - checkEvaluation( - MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), - timeZoneId), - -2.0) - checkEvaluation( - MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), - timeZoneId), - 1.0) + Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)), + Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), + Literal.FalseLiteral, + timeZoneId = timeZoneId), 3.9495967741935485) + + Seq(Literal.FalseLiteral, Literal.TrueLiteral). foreach { roundOff => + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), + roundOff, + timeZoneId = timeZoneId), 0.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + roundOff, + timeZoneId = timeZoneId), -2.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), + roundOff, + timeZoneId = timeZoneId), 1.0) + } val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00")) val tnull = Literal.create(null, TimestampType) - checkEvaluation(MonthsBetween(t, tnull, timeZoneId), null) - checkEvaluation(MonthsBetween(tnull, t, timeZoneId), null) - checkEvaluation(MonthsBetween(tnull, tnull, timeZoneId), null) + checkEvaluation(MonthsBetween(t, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation(MonthsBetween(tnull, t, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation( + MonthsBetween(tnull, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation( + MonthsBetween(t, t, Literal.create(null, BooleanType), timeZoneId = timeZoneId), null) checkConsistencyBetweenInterpretedAndCodegen( - (time1: Expression, time2: Expression) => MonthsBetween(time1, time2, timeZoneId), - TimestampType, TimestampType) + (time1: Expression, time2: Expression, roundOff: Expression) => + MonthsBetween(time1, time2, roundOff, timeZoneId = timeZoneId), + TimestampType, TimestampType, BooleanType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 625ff38943fa3..cbf6106697f30 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -490,24 +490,36 @@ class DateTimeUtilsSuite extends SparkFunSuite { c1.set(1997, 1, 28, 10, 30, 0) val c2 = Calendar.getInstance() c2.set(1996, 9, 30, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 3.94959677) - c2.set(2000, 1, 28, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) - c2.set(2000, 1, 29, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) - c2.set(1996, 2, 31, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, true, c1.getTimeZone) === 3.94959677) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, false, c1.getTimeZone) + === 3.9495967741935485) + Seq(true, false).foreach { roundOff => + c2.set(2000, 1, 28, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36) + c2.set(2000, 1, 29, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36) + c2.set(1996, 2, 31, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === 11) + } val c3 = Calendar.getInstance(TimeZonePST) c3.set(2000, 1, 28, 16, 0, 0) val c4 = Calendar.getInstance(TimeZonePST) c4.set(1997, 1, 28, 16, 0, 0) assert( - monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZonePST) + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZonePST) === 36.0) assert( - monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZoneGMT) + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZoneGMT) === 35.90322581) + assert( + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, false, TimeZoneGMT) + === 35.903225806451616) } test("from UTC timestamp") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d2f057310f89b..f1587cd032adc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2691,11 +2691,22 @@ object functions { /** * Returns number of months between dates `date1` and `date2`. + * The result is rounded off to 8 digits. * @group datetime_funcs * @since 1.5.0 */ def months_between(date1: Column, date2: Column): Column = withExpr { - MonthsBetween(date1.expr, date2.expr) + new MonthsBetween(date1.expr, date2.expr) + } + + /** + * Returns number of months between dates `date1` and `date2`. If `roundOff` is set to true, the + * result is rounded off to 8 digits; it is not rounded otherwise. + * @group datetime_funcs + * @since 2.4.0 + */ + def months_between(date1: Column, date2: Column, roundOff: Boolean): Column = withExpr { + MonthsBetween(date1.expr, date2.expr, lit(roundOff).expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 6bbf38516cdf6..f712baa7a9134 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -327,6 +327,13 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s") checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0))) checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5))) + checkAnswer(df.selectExpr("months_between(t, s, true)"), Seq(Row(0.5), Row(-0.5))) + Seq(true, false).foreach { roundOff => + checkAnswer(df.select(months_between(col("t"), col("d"), roundOff)), + Seq(Row(-10.0), Row(7.0))) + checkAnswer(df.withColumn("r", lit(false)).selectExpr("months_between(t, s, r)"), + Seq(Row(0.5), Row(-0.5))) + } } test("function last_day") { From cd10f9df8284ee8a5d287b2cd204c70b8ba87f5e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 26 Apr 2018 13:37:13 +0900 Subject: [PATCH 694/774] [SPARK-23916][SQL] Add array_join function ## What changes were proposed in this pull request? The PR adds the SQL function `array_join`. The behavior of the function is based on Presto's one. The function accepts an `array` of `string` which is to be joined, a `string` which is the delimiter to use between the items of the first argument and optionally a `string` which is used to replace `null` values. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21011 from mgaido91/SPARK-23916. --- python/pyspark/sql/functions.py | 21 +++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 169 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 35 ++++ .../org/apache/spark/sql/functions.scala | 19 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 23 +++ 6 files changed, 268 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 38ae41a5dafe6..ad4bd6f5089e9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1834,6 +1834,27 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@ignore_unicode_prefix +@since(2.4) +def array_join(col, delimiter, null_replacement=None): + """ + Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + `null_replacement` if set, otherwise they are ignored. + + >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + >>> df.select(array_join(df.data, ",").alias("joined")).collect() + [Row(joined=u'a,b,c'), Row(joined=u'a')] + >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() + [Row(joined=u'a,b,c'), Row(joined=u'a,NULL')] + """ + sc = SparkContext._active_spark_context + if null_replacement is None: + return Column(sc._jvm.functions.array_join(_to_java_column(col), delimiter)) + else: + return Column(sc._jvm.functions.array_join( + _to_java_column(col), delimiter, null_replacement)) + + @since(1.5) @ignore_unicode_prefix def concat(*cols): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6afcf309bd690..6bc7b4e4f7cb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -401,6 +401,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index bc71b5f34ce4a..90223b9126555 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -378,6 +378,175 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Creates a String containing all the elements of the input array separated by the delimiter. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, delimiter[, nullReplacement]) - Concatenates the elements of the given array + using the delimiter and an optional string to replace nulls. If no value is set for + nullReplacement, any null value is filtered.""", + examples = """ + Examples: + > SELECT _FUNC_(array('hello', 'world'), ' '); + hello world + > SELECT _FUNC_(array('hello', null ,'world'), ' '); + hello world + > SELECT _FUNC_(array('hello', null ,'world'), ' ', ','); + hello , world + """, since = "2.4.0") +case class ArrayJoin( + array: Expression, + delimiter: Expression, + nullReplacement: Option[Expression]) extends Expression with ExpectsInputTypes { + + def this(array: Expression, delimiter: Expression) = this(array, delimiter, None) + + def this(array: Expression, delimiter: Expression, nullReplacement: Expression) = + this(array, delimiter, Some(nullReplacement)) + + override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { + Seq(ArrayType(StringType), StringType, StringType) + } else { + Seq(ArrayType(StringType), StringType) + } + + override def children: Seq[Expression] = if (nullReplacement.isDefined) { + Seq(array, delimiter, nullReplacement.get) + } else { + Seq(array, delimiter) + } + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = { + val arrayEval = array.eval(input) + if (arrayEval == null) return null + val delimiterEval = delimiter.eval(input) + if (delimiterEval == null) return null + val nullReplacementEval = nullReplacement.map(_.eval(input)) + if (nullReplacementEval.contains(null)) return null + + val buffer = new UTF8StringBuilder() + var firstItem = true + val nullHandling = nullReplacementEval match { + case Some(rep) => (prependDelimiter: Boolean) => { + if (!prependDelimiter) { + buffer.append(delimiterEval.asInstanceOf[UTF8String]) + } + buffer.append(rep.asInstanceOf[UTF8String]) + true + } + case None => (_: Boolean) => false + } + arrayEval.asInstanceOf[ArrayData].foreach(StringType, (_, item) => { + if (item == null) { + if (nullHandling(firstItem)) { + firstItem = false + } + } else { + if (!firstItem) { + buffer.append(delimiterEval.asInstanceOf[UTF8String]) + } + buffer.append(item.asInstanceOf[UTF8String]) + firstItem = false + } + }) + buffer.build() + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val code = nullReplacement match { + case Some(replacement) => + val replacementGen = replacement.genCode(ctx) + val nullHandling = (buffer: String, delimiter: String, firstItem: String) => { + s""" + |if (!$firstItem) { + | $buffer.append($delimiter); + |} + |$buffer.append(${replacementGen.value}); + |$firstItem = false; + """.stripMargin + } + val execCode = if (replacement.nullable) { + ctx.nullSafeExec(replacement.nullable, replacementGen.isNull) { + genCodeForArrayAndDelimiter(ctx, ev, nullHandling) + } + } else { + genCodeForArrayAndDelimiter(ctx, ev, nullHandling) + } + s""" + |${replacementGen.code} + |$execCode + """.stripMargin + case None => genCodeForArrayAndDelimiter(ctx, ev, + (_: String, _: String, _: String) => "// nulls are ignored") + } + if (nullable) { + ev.copy( + s""" + |boolean ${ev.isNull} = true; + |UTF8String ${ev.value} = null; + |$code + """.stripMargin) + } else { + ev.copy( + s""" + |UTF8String ${ev.value} = null; + |$code + """.stripMargin, FalseLiteral) + } + } + + private def genCodeForArrayAndDelimiter( + ctx: CodegenContext, + ev: ExprCode, + nullEval: (String, String, String) => String): String = { + val arrayGen = array.genCode(ctx) + val delimiterGen = delimiter.genCode(ctx) + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val i = ctx.freshName("i") + val firstItem = ctx.freshName("firstItem") + val resultCode = + s""" + |$bufferClass $buffer = new $bufferClass(); + |boolean $firstItem = true; + |for (int $i = 0; $i < ${arrayGen.value}.numElements(); $i ++) { + | if (${arrayGen.value}.isNullAt($i)) { + | ${nullEval(buffer, delimiterGen.value, firstItem)} + | } else { + | if (!$firstItem) { + | $buffer.append(${delimiterGen.value}); + | } + | $buffer.append(${CodeGenerator.getValue(arrayGen.value, StringType, i)}); + | $firstItem = false; + | } + |} + |${ev.value} = $buffer.build();""".stripMargin + + if (array.nullable || delimiter.nullable) { + arrayGen.code + ctx.nullSafeExec(array.nullable, arrayGen.isNull) { + delimiterGen.code + ctx.nullSafeExec(delimiter.nullable, delimiterGen.isNull) { + s""" + |${ev.isNull} = false; + |$resultCode""".stripMargin + } + } + } else { + s""" + |${arrayGen.code} + |${delimiterGen.code} + |$resultCode""".stripMargin + } + } + + override def dataType: DataType = StringType + +} + /** * Returns the minimum value in the array. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index b49fa76b2a781..7048d93fd5649 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -106,6 +106,41 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("ArrayJoin") { + def testArrays( + arrays: Seq[Expression], + nullReplacement: Option[Expression], + expected: Seq[String]): Unit = { + assert(arrays.length == expected.length) + arrays.zip(expected).foreach { case (arr, exp) => + checkEvaluation(ArrayJoin(arr, Literal(","), nullReplacement), exp) + } + } + + val arrays = Seq(Literal.create(Seq[String]("a", "b"), ArrayType(StringType)), + Literal.create(Seq[String]("a", null, "b"), ArrayType(StringType)), + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal.create(Seq[String]("a", "b", null), ArrayType(StringType)), + Literal.create(Seq[String](null, "a", "b"), ArrayType(StringType)), + Literal.create(Seq[String]("a"), ArrayType(StringType))) + + val withoutNullReplacement = Seq("a,b", "a,b", "", "a,b", "a,b", "a") + val withNullReplacement = Seq("a,b", "a,NULL,b", "NULL", "a,b,NULL", "NULL,a,b", "a") + testArrays(arrays, None, withoutNullReplacement) + testArrays(arrays, Some(Literal("NULL")), withNullReplacement) + + checkEvaluation(ArrayJoin( + Literal.create(null, ArrayType(StringType)), Literal(","), None), null) + checkEvaluation(ArrayJoin( + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal.create(null, StringType), + None), null) + checkEvaluation(ArrayJoin( + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal(","), + Some(Literal.create(null, StringType))), null) + } + test("Array Min") { checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11) checkEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f1587cd032adc..25afaacc38d6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3039,6 +3039,25 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + * `nullReplacement`. + * @group collection_funcs + * @since 2.4.0 + */ + def array_join(column: Column, delimiter: String, nullReplacement: String): Column = withExpr { + ArrayJoin(column.expr, Literal(delimiter), Some(Literal(nullReplacement))) + } + + /** + * Concatenates the elements of `column` using the `delimiter`. + * @group collection_funcs + * @since 2.4.0 + */ + def array_join(column: Column, delimiter: String): Column = withExpr { + ArrayJoin(column.expr, Literal(delimiter), None) + } + /** * Concatenates multiple input columns together into a single column. * The function works with strings, binary and compatible array columns. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 03605c30036a3..c216d1322a06c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_join function") { + val df = Seq( + (Seq[String]("a", "b"), ","), + (Seq[String]("a", null, "b"), ","), + (Seq.empty[String], ",") + ).toDF("x", "delimiter") + + checkAnswer( + df.select(array_join(df("x"), ";")), + Seq(Row("a;b"), Row("a;b"), Row("")) + ) + checkAnswer( + df.select(array_join(df("x"), ";", "NULL")), + Seq(Row("a;b"), Row("a;NULL;b"), Row("")) + ) + checkAnswer( + df.selectExpr("array_join(x, delimiter)"), + Seq(Row("a,b"), Row("a,b"), Row(""))) + checkAnswer( + df.selectExpr("array_join(x, delimiter, 'NULL')"), + Seq(Row("a,b"), Row("a,NULL,b"), Row(""))) + } + test("array_min function") { val df = Seq( Seq[Option[Int]](Some(1), Some(3), Some(2)), From ffaf0f9fd407aeba7006f3d785ea8a0e51187357 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 26 Apr 2018 13:27:33 +0800 Subject: [PATCH 695/774] [SPARK-24062][THRIFT SERVER] Fix SASL encryption cannot enabled issue in thrift server ## What changes were proposed in this pull request? For the details of the exception please see [SPARK-24062](https://issues.apache.org/jira/browse/SPARK-24062). The issue is: Spark on Yarn stores SASL secret in current UGI's credentials, this credentials will be distributed to AM and executors, so that executors and drive share the same secret to communicate. But STS/Hive library code will refresh the current UGI by UGI's loginFromKeytab() after Spark application is started, this will create a new UGI in the current driver's context with empty tokens and secret keys, so secret key is lost in the current context's UGI, that's why Spark driver throws secret key not found exception. In Spark 2.2 code, Spark also stores this secret key in SecurityManager's class variable, so even UGI is refreshed, the secret is still existed in the object, so STS with SASL can still be worked in Spark 2.2. But in Spark 2.3, we always search key from current UGI, which makes it fail to work in Spark 2.3. To fix this issue, there're two possible solutions: 1. Fix in STS/Hive library, when a new UGI is refreshed, copy the secret key from original UGI to the new one. The difficulty is that some codes to refresh the UGI is existed in Hive library, which makes us hard to change the code. 2. Roll back the logics in SecurityManager to match Spark 2.2, so that this issue can be fixed. 2nd solution seems a simple one. So I will propose a PR with 2nd solution. ## How was this patch tested? Verified in local cluster. CC vanzin tgravescs please help to review. Thanks! Author: jerryshao Closes #21138 from jerryshao/SPARK-24062. --- .../main/scala/org/apache/spark/SecurityManager.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 09ec8932353a0..dbfd5a514c189 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -89,6 +89,7 @@ private[spark] class SecurityManager( setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", "")); setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", "")); + private var secretKey: String = _ logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + "; users with view permissions: " + viewAcls.toString() + @@ -321,6 +322,12 @@ private[spark] class SecurityManager( val creds = UserGroupInformation.getCurrentUser().getCredentials() Option(creds.getSecretKey(SECRET_LOOKUP_KEY)) .map { bytes => new String(bytes, UTF_8) } + // Secret key may not be found in current UGI's credentials. + // This happens when UGI is refreshed in the driver side by UGI's loginFromKeytab but not + // copy secret key from original UGI to the new one. This exists in ThriftServer's Hive + // logic. So as a workaround, storing secret key in a local variable to make it visible + // in different context. + .orElse(Option(secretKey)) .orElse(Option(sparkConf.getenv(ENV_AUTH_SECRET))) .orElse(sparkConf.getOption(SPARK_AUTH_SECRET_CONF)) .getOrElse { @@ -364,8 +371,8 @@ private[spark] class SecurityManager( rnd.nextBytes(secretBytes) val creds = new Credentials() - val secretStr = HashCodes.fromBytes(secretBytes).toString() - creds.addSecretKey(SECRET_LOOKUP_KEY, secretStr.getBytes(UTF_8)) + secretKey = HashCodes.fromBytes(secretBytes).toString() + creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8)) UserGroupInformation.getCurrentUser().addCredentials(creds) } From d1eb8d3ddc877958512194cc8f5dd8119b41bed0 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 25 Apr 2018 23:24:05 -0700 Subject: [PATCH 696/774] [SPARK-24094][SS][MINOR] Change description strings of v2 streaming sources to reflect the change ## What changes were proposed in this pull request? This makes it easy to understand at runtime which version is running. Great for debugging production issues. ## How was this patch tested? Not necessary. Author: Tathagata Das Closes #21160 from tdas/SPARK-24094. --- .../org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala | 2 +- .../streaming/sources/RateStreamMicroBatchReader.scala | 2 +- .../apache/spark/sql/execution/streaming/sources/socket.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 2ed49ba3f5495..cbe655f9bff1f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -169,7 +169,7 @@ private[kafka010] class KafkaMicroBatchReader( kafkaOffsetReader.close() } - override def toString(): String = s"Kafka[$kafkaOffsetReader]" + override def toString(): String = s"KafkaV2[$kafkaOffsetReader]" /** * Read initial partition offsets from the checkpoint, or decide the offsets and write them to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index 6cf8520fc544f..f54291bea6678 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -177,7 +177,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: override def stop(): Unit = {} - override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + + override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 5aae46b463398..90f4a5ba4234d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -214,7 +214,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def toString: String = s"TextSocket[host: $host, port: $port]" + override def toString: String = s"TextSocketV2[host: $host, port: $port]" } class TextSocketSourceProvider extends DataSourceV2 From ce2f919f8df1b794ceaa23e1a59d5d541ed47bf5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 26 Apr 2018 19:07:13 +0800 Subject: [PATCH 697/774] [SPARK-23799][SQL][FOLLOW-UP] FilterEstimation.evaluateInSet produces wrong stats for STRING ## What changes were proposed in this pull request? `colStat.min` AND `colStat.max` are empty for string type. Thus, `evaluateInSet` should not return zero when either `colStat.min` or `colStat.max`. ## How was this patch tested? Added a test case. Author: gatorsmile Closes #21147 from gatorsmile/cached. --- .../logical/statsEstimation/FilterEstimation.scala | 12 ++++++++---- .../statsEstimation/FilterEstimationSuite.scala | 12 ++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 263c9ba60d145..5a3eeefaedb18 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -392,13 +392,13 @@ case class FilterEstimation(plan: Filter) extends Logging { val dataType = attr.dataType var newNdv = ndv - if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { - return Some(0.0) - } - // use [min, max] to filter the original hSet dataType match { case _: NumericType | BooleanType | DateType | TimestampType => + if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { + return Some(0.0) + } + val statsInterval = ValueInterval(colStat.min, colStat.max, dataType).asInstanceOf[NumericValueInterval] val validQuerySet = hSet.filter { v => @@ -422,6 +422,10 @@ case class FilterEstimation(plan: Filter) extends Logging { // We assume the whole set since there is no min/max information for String/Binary type case StringType | BinaryType => + if (ndv.toDouble == 0) { + return Some(0.0) + } + newNdv = ndv.min(BigInt(hSet.size)) if (update) { val newStats = colStat.copy(distinctCount = Some(newNdv), nullCount = Some(0)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 16cb5d032cf57..47bfa62569583 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -368,6 +368,18 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 0) } + test("evaluateInSet with string") { + validateEstimatedStats( + Filter(InSet(attrString, Set("A0")), + StatsTestPlan(Seq(attrString), 10, + AttributeMap(Seq(attrString -> + ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)))))), + Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), + expectedRowCount = 1) + } + test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), From 4f1e38649ebc7710850b7c40e6fb355775e7bb7f Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 26 Apr 2018 14:21:22 -0700 Subject: [PATCH 698/774] [SPARK-24057][PYTHON] put the real data type in the AssertionError message ## What changes were proposed in this pull request? Print out the data type in the AssertionError message to make it more meaningful. ## How was this patch tested? I manually tested the changed code on my local, but didn't add any test. Author: Huaxin Gao Closes #21159 from huaxingao/spark-24057. --- python/pyspark/sql/types.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1f6534836d64a..3cd7a2ef115af 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -289,7 +289,8 @@ def __init__(self, elementType, containsNull=True): >>> ArrayType(StringType(), False) == ArrayType(StringType()) False """ - assert isinstance(elementType, DataType), "elementType should be DataType" + assert isinstance(elementType, DataType),\ + "elementType %s should be an instance of %s" % (elementType, DataType) self.elementType = elementType self.containsNull = containsNull @@ -343,8 +344,10 @@ def __init__(self, keyType, valueType, valueContainsNull=True): ... == MapType(StringType(), FloatType())) False """ - assert isinstance(keyType, DataType), "keyType should be DataType" - assert isinstance(valueType, DataType), "valueType should be DataType" + assert isinstance(keyType, DataType),\ + "keyType %s should be an instance of %s" % (keyType, DataType) + assert isinstance(valueType, DataType),\ + "valueType %s should be an instance of %s" % (valueType, DataType) self.keyType = keyType self.valueType = valueType self.valueContainsNull = valueContainsNull @@ -402,8 +405,9 @@ def __init__(self, name, dataType, nullable=True, metadata=None): ... == StructField("f2", StringType(), True)) False """ - assert isinstance(dataType, DataType), "dataType should be DataType" - assert isinstance(name, basestring), "field name should be string" + assert isinstance(dataType, DataType),\ + "dataType %s should be an instance of %s" % (dataType, DataType) + assert isinstance(name, basestring), "field name %s should be string" % (name) if not isinstance(name, str): name = name.encode('utf-8') self.name = name From f7435bec6a9348cfbbe26b13c230c08545d16067 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 26 Apr 2018 15:11:42 -0700 Subject: [PATCH 699/774] [SPARK-24044][PYTHON] Explicitly print out skipped tests from unittest module ## What changes were proposed in this pull request? This PR proposes to remove duplicated dependency checking logics and also print out skipped tests from unittests. For example, as below: ``` Skipped tests in pyspark.sql.tests with pypy: test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' ... Skipped tests in pyspark.sql.tests with python3: test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' ... ``` Currently, it's not printed out in the console. I think we should better print out skipped tests in the console. ## How was this patch tested? Manually tested. Also, fortunately, Jenkins has good environment to test the skipped output. Author: hyukjinkwon Closes #21107 from HyukjinKwon/skipped-tests-print. --- python/pyspark/ml/tests.py | 16 +++-- python/pyspark/mllib/tests.py | 4 +- python/pyspark/sql/tests.py | 51 +++++++------ python/pyspark/streaming/tests.py | 4 +- python/pyspark/tests.py | 12 +--- python/run-tests.py | 115 +++++++++++++----------------- 6 files changed, 98 insertions(+), 104 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2ec0be60e9fa9..093593132e56d 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -2136,17 +2136,23 @@ class ImageReaderTest2(PySparkTestCase): @classmethod def setUpClass(cls): super(ImageReaderTest2, cls).setUpClass() + cls.hive_available = True # Note that here we enable Hive's support. cls.spark = None try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False except TypeError: cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") - cls.spark = HiveContext._createForTesting(cls.sc) + cls.hive_available = False + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") @classmethod def tearDownClass(cls): @@ -2662,6 +2668,6 @@ def testDefaultFitMultiple(self): if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 1037bab7f1088..14d788b0bef60 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -1767,9 +1767,9 @@ def test_pca(self): if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) if not _have_scipy: print("NOTE: SciPy tests were skipped as it does not seem to be installed") sc.stop() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 98fa1b54b0a17..6b28c557a803e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3096,23 +3096,28 @@ def setUpClass(cls): filename_pattern = ( "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" "TestQueryExecutionListener.class") - if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): - raise unittest.SkipTest( + cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern))) + + if cls.has_listener: + # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. + cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.queryExecutionListeners", + "org.apache.spark.sql.TestQueryExecutionListener") \ + .getOrCreate() + + def setUp(self): + if not self.has_listener: + raise self.skipTest( "'org.apache.spark.sql.TestQueryExecutionListener' is not " "available. Will skip the related tests.") - # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. - cls.spark = SparkSession.builder \ - .master("local[4]") \ - .appName(cls.__name__) \ - .config( - "spark.sql.queryExecutionListeners", - "org.apache.spark.sql.TestQueryExecutionListener") \ - .getOrCreate() - @classmethod def tearDownClass(cls): - cls.spark.stop() + if hasattr(cls, "spark"): + cls.spark.stop() def tearDown(self): self.spark._jvm.OnSuccessCall.clear() @@ -3196,18 +3201,22 @@ class HiveContextSQLTests(ReusedPySparkTestCase): def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + cls.hive_available = True try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: - cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False except TypeError: - cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False os.unlink(cls.tempdir.name) - cls.spark = HiveContext._createForTesting(cls.sc) - cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - cls.df = cls.sc.parallelize(cls.testData).toDF() + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.df = cls.sc.parallelize(cls.testData).toDF() + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") @classmethod def tearDownClass(cls): @@ -5316,6 +5325,6 @@ def test_invalid_args(self): if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 103940923dd4d..d77f1baa1f344 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1590,11 +1590,11 @@ def search_kinesis_asl_assembly_jar(): sys.stderr.write("[Running %s]\n" % (testcase)) tests = unittest.TestLoader().loadTestsFromTestCase(testcase) if xmlrunner: - result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=3).run(tests) + result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2).run(tests) if not result.wasSuccessful(): failed = True else: - result = unittest.TextTestRunner(verbosity=3).run(tests) + result = unittest.TextTestRunner(verbosity=2).run(tests) if not result.wasSuccessful(): failed = True sys.exit(failed) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 9111dbbed5929..8392d7f29af53 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2353,15 +2353,7 @@ def test_statcounter_array(self): if __name__ == "__main__": from pyspark.tests import * - if not _have_scipy: - print("NOTE: Skipping SciPy tests as it does not seem to be installed") - if not _have_numpy: - print("NOTE: Skipping NumPy tests as it does not seem to be installed") if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() - if not _have_scipy: - print("NOTE: SciPy tests were skipped as it does not seem to be installed") - if not _have_numpy: - print("NOTE: NumPy tests were skipped as it does not seem to be installed") + unittest.main(verbosity=2) diff --git a/python/run-tests.py b/python/run-tests.py index 6b41b5ee22814..f408fc5082b3d 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -32,6 +32,7 @@ else: import queue as Queue from distutils.version import LooseVersion +from multiprocessing import Manager # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -50,6 +51,7 @@ def print_red(text): print('\033[31m' + text + '\033[0m') +SKIPPED_TESTS = Manager().dict() LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") FAILURE_REPORTING_LOCK = Lock() LOGGER = logging.getLogger() @@ -109,8 +111,34 @@ def run_individual_python_test(test_name, pyspark_python): # this code is invoked from a thread other than the main thread. os._exit(-1) else: - per_test_output.close() - LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration) + skipped_counts = 0 + try: + per_test_output.seek(0) + # Here expects skipped test output from unittest when verbosity level is + # 2 (or --verbose option is enabled). + decoded_lines = map(lambda line: line.decode(), iter(per_test_output)) + skipped_tests = list(filter( + lambda line: re.search('test_.* \(pyspark\..*\) ... skipped ', line), + decoded_lines)) + skipped_counts = len(skipped_tests) + if skipped_counts > 0: + key = (pyspark_python, test_name) + SKIPPED_TESTS[key] = skipped_tests + per_test_output.close() + except: + import traceback + print_red("\nGot an exception while trying to store " + "skipped test output:\n%s" % traceback.format_exc()) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(-1) + if skipped_counts != 0: + LOGGER.info( + "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python, test_name, + duration, skipped_counts) + else: + LOGGER.info( + "Finished test(%s): %s (%is)", pyspark_python, test_name, duration) def get_default_python_executables(): @@ -152,65 +180,17 @@ def parse_opts(): return opts -def _check_dependencies(python_exec, modules_to_test): - if "COVERAGE_PROCESS_START" in os.environ: - # Make sure if coverage is installed. - try: - subprocess_check_output( - [python_exec, "-c", "import coverage"], - stderr=open(os.devnull, 'w')) - except: - print_red("Coverage is not installed in Python executable '%s' " - "but 'COVERAGE_PROCESS_START' environment variable is set, " - "exiting." % python_exec) - sys.exit(-1) - - # If we should test 'pyspark-sql', it checks if PyArrow and Pandas are installed and - # explicitly prints out. See SPARK-23300. - if pyspark_sql in modules_to_test: - # TODO(HyukjinKwon): Relocate and deduplicate these version specifications. - minimum_pyarrow_version = '0.8.0' - minimum_pandas_version = '0.19.2' - - try: - pyarrow_version = subprocess_check_output( - [python_exec, "-c", "import pyarrow; print(pyarrow.__version__)"], - universal_newlines=True, - stderr=open(os.devnull, 'w')).strip() - if LooseVersion(pyarrow_version) >= LooseVersion(minimum_pyarrow_version): - LOGGER.info("Will test PyArrow related features against Python executable " - "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) - else: - LOGGER.warning( - "Will skip PyArrow related features against Python executable " - "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " - "%s was found." % ( - python_exec, pyspark_sql.name, minimum_pyarrow_version, pyarrow_version)) - except: - LOGGER.warning( - "Will skip PyArrow related features against Python executable " - "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " - "was not found." % (python_exec, pyspark_sql.name, minimum_pyarrow_version)) - - try: - pandas_version = subprocess_check_output( - [python_exec, "-c", "import pandas; print(pandas.__version__)"], - universal_newlines=True, - stderr=open(os.devnull, 'w')).strip() - if LooseVersion(pandas_version) >= LooseVersion(minimum_pandas_version): - LOGGER.info("Will test Pandas related features against Python executable " - "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) - else: - LOGGER.warning( - "Will skip Pandas related features against Python executable " - "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " - "%s was found." % ( - python_exec, pyspark_sql.name, minimum_pandas_version, pandas_version)) - except: - LOGGER.warning( - "Will skip Pandas related features against Python executable " - "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " - "was not found." % (python_exec, pyspark_sql.name, minimum_pandas_version)) +def _check_coverage(python_exec): + # Make sure if coverage is installed. + try: + subprocess_check_output( + [python_exec, "-c", "import coverage"], + stderr=open(os.devnull, 'w')) + except: + print_red("Coverage is not installed in Python executable '%s' " + "but 'COVERAGE_PROCESS_START' environment variable is set, " + "exiting." % python_exec) + sys.exit(-1) def main(): @@ -237,9 +217,10 @@ def main(): task_queue = Queue.PriorityQueue() for python_exec in python_execs: - # Check if the python executable has proper dependencies installed to run tests - # for given modules properly. - _check_dependencies(python_exec, modules_to_test) + # Check if the python executable has coverage installed when 'COVERAGE_PROCESS_START' + # environmental variable is set. + if "COVERAGE_PROCESS_START" in os.environ: + _check_coverage(python_exec) python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], @@ -281,6 +262,12 @@ def process_queue(task_queue): total_duration = time.time() - start_time LOGGER.info("Tests passed in %i seconds", total_duration) + for key, lines in sorted(SKIPPED_TESTS.items()): + pyspark_python, test_name = key + LOGGER.info("\nSkipped tests in %s with %s:" % (test_name, pyspark_python)) + for line in lines: + LOGGER.info(" %s" % line.rstrip()) + if __name__ == "__main__": main() From 9ee9fcf5223efdf7543161b7bc99131111876b92 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Thu, 26 Apr 2018 15:38:11 -0700 Subject: [PATCH 700/774] [SPARK-24083][YARN] Log stacktrace for uncaught exception ## What changes were proposed in this pull request? Log stacktrace for uncaught exception ## How was this patch tested? UT and manually test Author: zhoukang Closes #21151 from caneGuy/zhoukang/log-stacktrace. --- .../scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index d04989e138f83..650840045361c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -308,7 +308,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logError("Uncaught exception: ", e) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, - "Uncaught exception: " + e) + "Uncaught exception: " + StringUtils.stringifyException(e)) } } From 8aa1d7b0ede5115297541d29eab4ce5f4fe905cb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 27 Apr 2018 11:00:41 +0800 Subject: [PATCH 701/774] [SPARK-23355][SQL] convertMetastore should not ignore table properties ## What changes were proposed in this pull request? Previously, SPARK-22158 fixed for `USING hive` syntax. This PR aims to fix for `STORED AS` syntax. Although the test case covers ORC part, the patch considers both `convertMetastoreOrc` and `convertMetastoreParquet`. ## How was this patch tested? Pass newly added test cases. Author: Dongjoon Hyun Closes #20522 from dongjoon-hyun/SPARK-22158-2. --- .../spark/sql/hive/HiveStrategies.scala | 17 +++- .../sql/hive/CompressionCodecSuite.scala | 7 +- .../sql/hive/execution/HiveDDLSuite.scala | 81 +++++++++++++++++++ 3 files changed, 97 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 8df05cbb20361..a0c197b06ddab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -186,15 +186,28 @@ case class RelationConversions( serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } + // Return true for Apache ORC and Hive ORC-related configuration names. + // Note that Spark doesn't support configurations like `hive.merge.orcfile.stripe.level`. + private def isOrcProperty(key: String) = + key.startsWith("orc.") || key.contains(".orc.") + + private def isParquetProperty(key: String) = + key.startsWith("parquet.") || key.contains(".parquet.") + private def convert(relation: HiveTableRelation): LogicalRelation = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + + // Consider table and storage properties. For properties existing in both sides, storage + // properties will supersede table properties. if (serde.contains("parquet")) { - val options = relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> + val options = relation.tableMeta.properties.filterKeys(isParquetProperty) ++ + relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") } else { - val options = relation.tableMeta.storage.properties + val options = relation.tableMeta.properties.filterKeys(isOrcProperty) ++ + relation.tableMeta.storage.properties if (conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") { sessionCatalog.metastoreCatalog.convertToLogicalRelation( relation, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index d10a6f25c64fc..4550d350f6db2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -268,12 +268,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo compressionCodecs = compressCodecs, tableCompressionCodecs = compressCodecs) { case (tableCodec, sessionCodec, realCodec, tableSize) => - // For non-partitioned table and when convertMetastore is true, Expect session-level - // take effect, and in other cases expect table-level take effect - // TODO: It should always be table-level taking effect when the bug(SPARK-22926) - // is fixed - val expectCodec = - if (convertMetastore && !isPartitioned) sessionCodec else tableCodec.get + val expectCodec = tableCodec.get assert(expectCodec == realCodec) assert(checkTableSize( format, expectCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index c85db78c732de..daac6af9b557f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METAS import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -2144,6 +2145,86 @@ class HiveDDLSuite } } + private def getReader(path: String): org.apache.orc.Reader = { + val conf = spark.sessionState.newHadoopConf() + val files = org.apache.spark.sql.execution.datasources.orc.OrcUtils.listOrcFiles(path, conf) + assert(files.length == 1) + val file = files.head + val fs = file.getFileSystem(conf) + val readerOptions = org.apache.orc.OrcFile.readerOptions(conf).filesystem(fs) + org.apache.orc.OrcFile.createReader(file, readerOptions) + } + + test("SPARK-23355 convertMetastoreOrc should not ignore table properties - STORED AS") { + Seq("native", "hive").foreach { orcImpl => + withSQLConf(ORC_IMPLEMENTATION.key -> orcImpl, CONVERT_METASTORE_ORC.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) STORED AS ORC + |TBLPROPERTIES ( + | orc.compress 'ZLIB', + | orc.compress.size '1001', + | orc.row.index.stride '2002', + | hive.exec.orc.default.block.size '3003', + | hive.exec.orc.compression.strategy 'COMPRESSION') + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains("orc")) + val properties = table.properties + assert(properties.get("orc.compress") == Some("ZLIB")) + assert(properties.get("orc.compress.size") == Some("1001")) + assert(properties.get("orc.row.index.stride") == Some("2002")) + assert(properties.get("hive.exec.orc.default.block.size") == Some("3003")) + assert(properties.get("hive.exec.orc.compression.strategy") == Some("COMPRESSION")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + + val reader = getReader(maybeFile.head.getCanonicalPath) + assert(reader.getCompressionKind.name === "ZLIB") + assert(reader.getCompressionSize == 1001) + assert(reader.getRowIndexStride == 2002) + } + } + } + } + } + + test("SPARK-23355 convertMetastoreParquet should not ignore table properties - STORED AS") { + withSQLConf(CONVERT_METASTORE_PARQUET.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) STORED AS PARQUET + |TBLPROPERTIES ( + | parquet.compression 'GZIP' + |) + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains("parquet")) + val properties = table.properties + assert(properties.get("parquet.compression") == Some("GZIP")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + + assertCompression(maybeFile, "parquet", "GZIP") + } + } + } + } + test("load command for non local invalid path validation") { withTable("tbl") { sql("CREATE TABLE tbl(i INT, j STRING)") From 109935fc5d8b3d381bb1b09a4a570040a0a1846f Mon Sep 17 00:00:00 2001 From: eric-maynard Date: Fri, 27 Apr 2018 15:25:07 +0800 Subject: [PATCH 702/774] [SPARK-23830][YARN] added check to ensure main method is found ## What changes were proposed in this pull request? When a user specifies the wrong class -- or, in fact, a class instead of an object -- Spark throws an NPE which is not useful for debugging. This was reported in [SPARK-23830](https://issues.apache.org/jira/browse/SPARK-23830). This PR adds a check to ensure the main method was found and logs a useful error in the event that it's null. ## How was this patch tested? * Unit tests + Manual testing * The scope of the changes is very limited Author: eric-maynard Author: Eric Maynard Closes #21168 from eric-maynard/feature/SPARK-23830. --- .../spark/deploy/yarn/ApplicationMaster.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 650840045361c..595077e7e809f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} -import java.lang.reflect.InvocationTargetException +import java.lang.reflect.{InvocationTargetException, Modifier} import java.net.{Socket, URI, URL} import java.security.PrivilegedExceptionAction import java.util.concurrent.{TimeoutException, TimeUnit} @@ -675,9 +675,14 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val userThread = new Thread { override def run() { try { - mainMethod.invoke(null, userArgs.toArray) - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) - logDebug("Done running users class") + if (!Modifier.isStatic(mainMethod.getModifiers)) { + logError(s"Could not find static main method in object ${args.userClass}") + finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS) + } else { + mainMethod.invoke(null, userArgs.toArray) + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + logDebug("Done running user class") + } } catch { case e: InvocationTargetException => e.getCause match { From 2824f12b8bac5d86a82339d4dfb4d2625e978a15 Mon Sep 17 00:00:00 2001 From: Patrick McGloin Date: Fri, 27 Apr 2018 23:04:14 +0800 Subject: [PATCH 703/774] [SPARK-23565][SS] New error message for structured streaming sources assertion ## What changes were proposed in this pull request? A more informative message to tell you why a structured streaming query cannot continue if you have added more sources, than there are in the existing checkpoint offsets. ## How was this patch tested? I added a Unit Test. Author: Patrick McGloin Closes #20946 from patrickmcgloin/master. --- .../org/apache/spark/sql/execution/streaming/OffsetSeq.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 73945b39b8967..787174481ff08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -39,7 +39,9 @@ case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMet * cannot be serialized). */ def toStreamProgress(sources: Seq[BaseStreamingSource]): StreamProgress = { - assert(sources.size == offsets.size) + assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " + + s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. " + + s"Cannot continue.") new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) } } From 3fd297af6dc568357c97abf86760c570309d6597 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 27 Apr 2018 11:43:29 -0700 Subject: [PATCH 704/774] [SPARK-24085][SQL] Query returns UnsupportedOperationException when scalar subquery is present in partitioning expression ## What changes were proposed in this pull request? In this case, the partition pruning happens before the planning phase of scalar subquery expressions. For scalar subquery expressions, the planning occurs late in the cycle (after the physical planning) in "PlanSubqueries" just before execution. Currently we try to execute the scalar subquery expression as part of partition pruning and fail as it implements Unevaluable. The fix attempts to ignore the Subquery expressions from partition pruning computation. Another option can be to somehow plan the subqueries before the partition pruning. Since this may not be a commonly occuring expression, i am opting for a simpler fix. Repro ``` SQL CREATE TABLE test_prc_bug ( id_value string ) partitioned by (id_type string) location '/tmp/test_prc_bug' stored as parquet; insert into test_prc_bug values ('1','a'); insert into test_prc_bug values ('2','a'); insert into test_prc_bug values ('3','b'); insert into test_prc_bug values ('4','b'); select * from test_prc_bug where id_type = (select 'b'); ``` ## How was this patch tested? Added test in SubquerySuite and hive/SQLQuerySuite Author: Dilip Biswal Closes #21174 from dilipbiswal/spark-24085. --- .../datasources/FileSourceStrategy.scala | 5 ++- .../PruneFileSourcePartitions.scala | 4 ++- .../org/apache/spark/sql/SubquerySuite.scala | 15 +++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 31 +++++++++++++++++++ 4 files changed, 53 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 16b22717b8d92..0a568d6b8adce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -76,7 +76,10 @@ object FileSourceStrategy extends Strategy with Logging { fsRelation.partitionSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters + .filterNot(SubqueryExpression.hasSubquery(_)) + .filter(_.references.subsetOf(partitionSet))) + logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") val dataColumns = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 3b830accb83f0..16b2367bfdd5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -55,7 +55,9 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { partitionSchema, sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters + .filterNot(SubqueryExpression.hasSubquery(_)) + .filter(_.references.subsetOf(partitionSet))) if (partitionKeyFilters.nonEmpty) { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 31e8b0e8dede0..acef62d81ee12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -955,4 +955,19 @@ class SubquerySuite extends QueryTest with SharedSQLContext { // before the fix this would throw AnalysisException spark.range(10).where("(id,id) in (select id, null from range(3))").count } + + test("SPARK-24085 scalar subquery in partitioning expression") { + withTable("parquet_part") { + Seq("1" -> "a", "2" -> "a", "3" -> "b", "4" -> "b") + .toDF("id_value", "id_type") + .write + .mode(SaveMode.Overwrite) + .partitionBy("id_type") + .format("parquet") + .saveAsTable("parquet_part") + checkAnswer( + sql("SELECT * FROM parquet_part WHERE id_type = (SELECT 'b')"), + Row("3", "b") :: Row("4", "b") :: Nil) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 73f83d593bbfb..704a410b6a37b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2156,4 +2156,35 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("SPARK-24085 scalar subquery in partitioning expression") { + Seq("orc", "parquet").foreach { format => + Seq(true, false).foreach { isConverted => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted", + HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted", + "hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(format) { + withTempPath { tempDir => + sql( + s""" + |CREATE TABLE ${format} (id_value string) + |PARTITIONED BY (id_type string) + |LOCATION '${tempDir.toURI}' + |STORED AS ${format} + """.stripMargin) + sql(s"insert into $format values ('1','a')") + sql(s"insert into $format values ('2','a')") + sql(s"insert into $format values ('3','b')") + sql(s"insert into $format values ('4','b')") + checkAnswer( + sql(s"SELECT * FROM $format WHERE id_type = (SELECT 'b')"), + Row("3", "b") :: Row("4", "b") :: Nil) + } + } + } + } + } + } + } From 8614edd445264007144caa6743a8c2ca2b5082e0 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 27 Apr 2018 14:14:28 -0700 Subject: [PATCH 705/774] [SPARK-24104] SQLAppStatusListener overwrites metrics onDriverAccumUpdates instead of updating them ## What changes were proposed in this pull request? Event `SparkListenerDriverAccumUpdates` may happen multiple times in a query - e.g. every `FileSourceScanExec` and `BroadcastExchangeExec` call `postDriverMetricUpdates`. In Spark 2.2 `SQLListener` updated the map with new values. `SQLAppStatusListener` overwrites it. Unless `update` preserved it in the KV store (dependant on `exec.lastWriteTime`), only the metrics from the last operator that does `postDriverMetricUpdates` are preserved. ## How was this patch tested? Unit test added. Author: Juliusz Sompolski Closes #21171 from juliuszsompolski/SPARK-24104. --- .../execution/ui/SQLAppStatusListener.scala | 2 +- .../ui/SQLAppStatusListenerSuite.scala | 24 +++++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 2b6bb48467eb3..d254af400a7cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -289,7 +289,7 @@ class SQLAppStatusListener( private def onDriverAccumUpdates(event: SparkListenerDriverAccumUpdates): Unit = { val SparkListenerDriverAccumUpdates(executionId, accumUpdates) = event Option(liveExecutions.get(executionId)).foreach { exec => - exec.driverAccumUpdates = accumUpdates.toMap + exec.driverAccumUpdates = exec.driverAccumUpdates ++ accumUpdates update(exec) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index f3f08839c1d3a..02df45d1b7989 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -443,7 +443,8 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with val oldCount = statusStore.executionsList().size val expectedAccumValue = 12345 - val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue) + val expectedAccumValue2 = 54321 + val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue, expectedAccumValue2) val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) { override lazy val sparkPlan = physicalPlan override lazy val executedPlan = physicalPlan @@ -466,10 +467,14 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with val execId = statusStore.executionsList().last.executionId val metrics = statusStore.executionMetrics(execId) val driverMetric = physicalPlan.metrics("dummy") + val driverMetric2 = physicalPlan.metrics("dummy2") val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Seq(expectedAccumValue)) + val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType, Seq(expectedAccumValue2)) assert(metrics.contains(driverMetric.id)) assert(metrics(driverMetric.id) === expectedValue) + assert(metrics.contains(driverMetric2.id)) + assert(metrics(driverMetric2.id) === expectedValue2) } test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol (SPARK-18462)") { @@ -562,20 +567,31 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with * A dummy [[org.apache.spark.sql.execution.SparkPlan]] that updates a [[SQLMetrics]] * on the driver. */ -private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExecNode { +private case class MyPlan(sc: SparkContext, expectedValue: Long, expectedValue2: Long) + extends LeafExecNode { + override def sparkContext: SparkContext = sc override def output: Seq[Attribute] = Seq() override val metrics: Map[String, SQLMetric] = Map( - "dummy" -> SQLMetrics.createMetric(sc, "dummy")) + "dummy" -> SQLMetrics.createMetric(sc, "dummy"), + "dummy2" -> SQLMetrics.createMetric(sc, "dummy2")) override def doExecute(): RDD[InternalRow] = { longMetric("dummy") += expectedValue + longMetric("dummy2") += expectedValue2 + + // postDriverMetricUpdates may happen multiple time in a query. + // (normally from different operators, but for the sake of testing, from one operator) + SQLMetrics.postDriverMetricUpdates( + sc, + sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), + Seq(metrics("dummy"))) SQLMetrics.postDriverMetricUpdates( sc, sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), - metrics.values.toSeq) + Seq(metrics("dummy2"))) sc.emptyRDD } } From 1fb46f30f83e4751169ff288ad406f26b7c11f7e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 28 Apr 2018 09:55:56 +0800 Subject: [PATCH 706/774] [SPARK-23688][SS] Refactor tests away from rate source ## What changes were proposed in this pull request? Replace rate source with memory source in continuous mode test suite. Keep using "rate" source if the tests intend to put data periodically in background, or need to put short source name to load, since "memory" doesn't have provider for source. ## How was this patch tested? Ran relevant test suite from IDE. Author: Jungtaek Lim Closes #21152 from HeartSaVioR/SPARK-23688. --- .../continuous/ContinuousSuite.scala | 163 +++++++----------- 1 file changed, 61 insertions(+), 102 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index c318b951ff992..5f222e7885994 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -75,73 +75,50 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("map") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .map(r => r.getLong(0) * 2) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().map(_.getInt(0) * 2) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 40, 2).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer(0, 2), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(0, 2, 4, 6, 8)) } test("flatMap") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .flatMap(r => Seq(0, r.getLong(0), r.getLong(0) * 2)) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().flatMap(r => Seq(0, r.getInt(0), r.getInt(0) * 2)) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).flatMap(n => Seq(0, n, n * 2)).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer((0 to 1).flatMap(n => Seq(0, n, n * 2)): _*), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer((0 to 4).flatMap(n => Seq(0, n, n * 2)): _*)) } test("filter") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .where('value > 5) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().where('value > 2) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer(), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(3, 4)) } test("deduplicate") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .dropDuplicates() + val input = ContinuousMemoryStream[Int] + val df = input.toDF().dropDuplicates() val except = intercept[AnalysisException] { - testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + testStream(df)(StartStream()) } assert(except.message.contains( @@ -149,15 +126,11 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("timestamp") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select(current_timestamp()) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().select(current_timestamp()) val except = intercept[AnalysisException] { - testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + testStream(df)(StartStream()) } assert(except.message.contains( @@ -165,58 +138,43 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("subquery alias") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .createOrReplaceTempView("rate") - val test = spark.sql("select value from rate where value > 5") + val input = ContinuousMemoryStream[Int] + input.toDF().createOrReplaceTempView("memory") + val test = spark.sql("select value from memory where value > 2") - testStream(test, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + testStream(test)( + AddData(input, 0, 1), + CheckAnswer(), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(3, 4)) } test("repeatedly restart") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + val input = ContinuousMemoryStream[Int] + val df = input.toDF() - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + testStream(df)( + StartStream(), + AddData(input, 0, 1), + CheckAnswer(0, 1), StopStream, - StartStream(longContinuousTrigger), + StartStream(), StopStream, - StartStream(longContinuousTrigger), + StartStream(), StopStream, - StartStream(longContinuousTrigger), - AwaitEpoch(2), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), + StartStream(), + StopStream, + AddData(input, 2, 3), + StartStream(), + CheckAnswer(0, 1, 2, 3), StopStream) } test("task failure kills the query") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + val input = ContinuousMemoryStream[Int] + val df = input.toDF() // Get an arbitrary task from this query to kill. It doesn't matter which one. var taskId: Long = -1 @@ -227,9 +185,9 @@ class ContinuousSuite extends ContinuousSuiteBase { } spark.sparkContext.addSparkListener(listener) try { - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(Trigger.Continuous(100)), - Execute(waitForRateSourceTriggers(_, 2)), + AddData(input, 0, 1, 2, 3), Execute { _ => // Wait until a task is started, then kill its first attempt. eventually(timeout(streamingTimeout)) { @@ -252,6 +210,7 @@ class ContinuousSuite extends ContinuousSuiteBase { .option("rowsPerSecond", "2") .load() .select('value) + val query = df.writeStream .format("memory") .queryName("noharness") From ad94e8592b2e8f4c1bdbd958e110797c6658af84 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 28 Apr 2018 10:47:43 +0800 Subject: [PATCH 707/774] [SPARK-23736][SQL][FOLLOWUP] Error message should contains SQL types ## What changes were proposed in this pull request? In the error messages we should return the SQL types (like `string` rather than the internal types like `StringType`). ## How was this patch tested? added UT Author: Marco Gaido Closes #21181 from mgaido91/SPARK-23736_followup. --- .../sql/catalyst/expressions/collectionOperations.scala | 5 +++-- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 5 +++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 90223b9126555..6d63a531e3b74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -863,8 +863,9 @@ case class Concat(children: Seq[Expression]) extends Expression { val childTypes = children.map(_.dataType) if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) { return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + - s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + s"input to function $prettyName should have been ${StringType.simpleString}," + + s" ${BinaryType.simpleString} or ${ArrayType.simpleString}, but it's " + + childTypes.map(_.simpleString).mkString("[", ", ", "]")) } TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c216d1322a06c..470a1c8e331ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -712,6 +712,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { intercept[AnalysisException] { df.selectExpr("concat(i1, array(i1, i2))") } + + val e = intercept[AnalysisException] { + df.selectExpr("concat(map(1, 2), map(3, 4))") + } + assert(e.getMessage.contains("string, binary or array")) } test("flatten function") { From 4df51361a5ff1fba20524f1b580f4049b328ed32 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 28 Apr 2018 16:57:41 +0800 Subject: [PATCH 708/774] [SPARK-22732][SS][FOLLOW-UP] Fix MemorySinkV2 toString error ## What changes were proposed in this pull request? Fix `MemorySinkV2` toString() error ## How was this patch tested? N/A Author: Yuming Wang Closes #21170 from wangyum/SPARK-22732. --- .../spark/sql/execution/streaming/sources/memoryV2.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 5f58246083bb2..d871d37ad37c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -96,7 +96,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { case _ => throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") + s"Output mode $outputMode is not supported by MemorySinkV2") } } else { logDebug(s"Skipping already committed batch: $batchId") @@ -107,7 +107,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { batches.clear() } - override def toString(): String = "MemorySink" + override def toString(): String = "MemorySinkV2" } case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} @@ -175,7 +175,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode) /** - * Used to query the data that has been written into a [[MemorySink]]. + * Used to query the data that has been written into a [[MemorySinkV2]]. */ case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { private val sizePerRow = output.map(_.dataType.defaultSize).sum From bd14da6fd5a77cc03efff193a84ffccbe892cc13 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 29 Apr 2018 11:25:31 +0800 Subject: [PATCH 709/774] [SPARK-23094][SPARK-23723][SPARK-23724][SQL] Support custom encoding for json files ## What changes were proposed in this pull request? I propose new option for JSON datasource which allows to specify encoding (charset) of input and output files. Here is an example of using of the option: ``` spark.read.schema(schema) .option("multiline", "true") .option("encoding", "UTF-16LE") .json(fileName) ``` If the option is not specified, charset auto-detection mechanism is used by default. The option can be used for saving datasets to jsons. Currently Spark is able to save datasets into json files in `UTF-8` charset only. The changes allow to save data in any supported charset. Here is the approximate list of supported charsets by Oracle Java SE: https://docs.oracle.com/javase/8/docs/technotes/guides/intl/encoding.doc.html . An user can specify the charset of output jsons via the charset option like `.option("charset", "UTF-16BE")`. By default the output charset is still `UTF-8` to keep backward compatibility. The solution has the following restrictions for per-line mode (`multiline = false`): - If charset is different from UTF-8, the lineSep option must be specified. The option required because Hadoop LineReader cannot detect the line separator correctly. Here is the ticket for solving the issue: https://issues.apache.org/jira/browse/SPARK-23725 - Encoding with [BOM](https://en.wikipedia.org/wiki/Byte_order_mark) are not supported. For example, the `UTF-16` and `UTF-32` encodings are blacklisted. The problem can be solved by https://github.com/MaxGekk/spark-1/pull/2 ## How was this patch tested? I added the following tests: - reads an json file in `UTF-16LE` encoding with BOM in `multiline` mode - read json file by using charset auto detection (`UTF-32BE` with BOM) - read json file using of user's charset (`UTF-16LE`) - saving in `UTF-32BE` and read the result by standard library (not by Spark) - checking that default charset is `UTF-8` - handling wrong (unsupported) charset Author: Maxim Gekk Author: Maxim Gekk Closes #20937 from MaxGekk/json-encoding-line-sep. --- python/pyspark/sql/readwriter.py | 15 +- python/pyspark/sql/tests.py | 7 + .../sql/people_array_utf16le.json | Bin 0 -> 182 bytes .../catalyst/json/CreateJacksonParser.scala | 49 +++- .../spark/sql/catalyst/json/JSONOptions.scala | 39 ++- .../sql/catalyst/json/JacksonParser.scala | 10 +- .../apache/spark/sql/DataFrameReader.scala | 3 + .../apache/spark/sql/DataFrameWriter.scala | 8 +- .../datasources/json/JsonDataSource.scala | 60 +++-- .../datasources/json/JsonFileFormat.scala | 10 +- .../datasources/text/TextOptions.scala | 18 +- .../src/test/resources/test-data/utf16LE.json | Bin 0 -> 98 bytes .../resources/test-data/utf16WithBOM.json | Bin 0 -> 200 bytes .../resources/test-data/utf32BEWithBOM.json | Bin 0 -> 172 bytes .../datasources/json/JsonBenchmarks.scala | 179 +++++++++++++ .../datasources/json/JsonSuite.scala | 245 +++++++++++++++++- 16 files changed, 599 insertions(+), 44 deletions(-) create mode 100644 python/test_support/sql/people_array_utf16le.json create mode 100644 sql/core/src/test/resources/test-data/utf16LE.json create mode 100644 sql/core/src/test/resources/test-data/utf16WithBOM.json create mode 100644 sql/core/src/test/resources/test-data/utf32BEWithBOM.json create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index df176c579fc8b..6811fa6b3b156 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, + encoding=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -237,6 +238,10 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param encoding: allows to forcibly set one of standard basic or extended encoding for + the JSON files. For example UTF-16BE, UTF-32LE. If None is set, + the encoding of input JSON will be detected automatically + when the multiLine option is set to ``true``. :param lineSep: defines the line separator that should be used for parsing. If None is set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. :param samplingRatio: defines fraction of input JSON objects used for schema inferring. @@ -259,7 +264,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, - samplingRatio=samplingRatio) + samplingRatio=samplingRatio, encoding=encoding) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -752,7 +757,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) @since(1.4) def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, - lineSep=None): + lineSep=None, encoding=None): """Saves the content of the :class:`DataFrame` in JSON format (`JSON Lines text format or newline-delimited JSON `_) at the specified path. @@ -776,6 +781,8 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. + :param encoding: specifies encoding (charset) of saved json files. If None is set, + the default UTF-8 charset will be used. :param lineSep: defines the line separator that should be used for writing. If None is set, it uses the default value, ``\\n``. @@ -784,7 +791,7 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm self.mode(mode) self._set_opts( compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, - lineSep=lineSep) + lineSep=lineSep, encoding=encoding) self._jwrite.json(path) @since(1.4) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6b28c557a803e..e0cd2aa41a2d0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -685,6 +685,13 @@ def test_multiline_json(self): multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_encoding_json(self): + people_array = self.spark.read\ + .json("python/test_support/sql/people_array_utf16le.json", + multiLine=True, encoding="UTF-16LE") + expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')] + self.assertEqual(people_array.collect(), expected) + def test_linesep_json(self): df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") expected = [Row(_corrupt_record=None, name=u'Michael'), diff --git a/python/test_support/sql/people_array_utf16le.json b/python/test_support/sql/people_array_utf16le.json new file mode 100644 index 0000000000000000000000000000000000000000..9c657fa30ac9c651076ff8aa3676baa400b121fb GIT binary patch literal 182 zcma!M;9^h!!fGfDVk require(sep.nonEmpty, "'lineSep' cannot be an empty string.") sep } - // Note that the option 'lineSep' uses a different default value in read and write. - val lineSeparatorInRead: Option[Array[Byte]] = - lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) - // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + + /** + * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE. + * If the encoding is not specified (None), it will be detected automatically + * when the multiLine option is set to `true`. + */ + val encoding: Option[String] = parameters.get("encoding") + .orElse(parameters.get("charset")).map { enc => + // The following encodings are not supported in per-line mode (multiline is false) + // because they cause some problems in reading files with BOM which is supposed to + // present in the files with such encodings. After splitting input files by lines, + // only the first lines will have the BOM which leads to impossibility for reading + // the rest lines. Besides of that, the lineSep option must have the BOM in such + // encodings which can never present between lines. + val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32")) + val isBlacklisted = blacklist.contains(Charset.forName(enc)) + require(multiLine || !isBlacklisted, + s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: + | ${blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = !(multiLine == false && + Charset.forName(enc) != StandardCharsets.UTF_8 && lineSeparator.isEmpty) + require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") + + enc + } + + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.getOrElse("UTF-8")) + } val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n") /** Sets config options on a Jackson [[JsonFactory]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 7f6956994f31f..a5a4a13eb608b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.json -import java.io.ByteArrayOutputStream +import java.io.{ByteArrayOutputStream, CharConversionException} import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -361,6 +361,14 @@ class JacksonParser( // For such records, all fields other than the field configured by // `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => recordLiteral(record), () => None, e) + case e: CharConversionException if options.encoding.isEmpty => + val msg = + """JSON parser cannot handle a character in its input. + |Specifying encoding as an input option explicitly might help to resolve the issue. + |""".stripMargin + e.getMessage + val wrappedCharException = new CharConversionException(msg) + wrappedCharException.initCause(e) + throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b44552f0eb17b..6b2ea6c06d3ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -372,6 +372,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • + *
  • `encoding` (by default it is not set): allows to forcibly set one of standard basic + * or extended encoding for the JSON files. For example UTF-16BE, UTF-32LE. If the encoding + * is not specified and `multiLine` is set to `true`, it will be detected automatically.
  • *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.
  • *
  • `samplingRatio` (default is 1.0): defines fraction of input JSON objects used diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index bbc063148a72c..e183fa6f9542b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -518,8 +518,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `lineSep` (default `\n`): defines the line separator that should - * be used for writing.
  • + *
  • `encoding` (by default it is not set): specifies encoding (charset) of saved json + * files. If it is not set, the UTF-8 charset will be used.
  • + *
  • `lineSep` (default `\n`): defines the line separator that should be used for writing.
  • * * * @since 1.4.0 @@ -589,8 +590,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
  • - *
  • `lineSep` (default `\n`): defines the line separator that should - * be used for writing.
  • + *
  • `lineSep` (default `\n`): defines the line separator that should be used for writing.
  • * * * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 5769c09c9a1d9..983a5f0dcade2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -31,11 +31,11 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} -import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions} +import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -92,26 +92,30 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: Dataset[String] = createBaseDataset( - sparkSession, inputPaths, parsedOptions.lineSeparator) + val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) + inferFromDataset(json, parsedOptions) } def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = { val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions) - val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0)) - JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String) + val rdd: RDD[InternalRow] = sampled.queryExecution.toRdd + val rowParser = parsedOptions.encoding.map { enc => + CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) + }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) + + JsonInferSchema.infer(rdd, parsedOptions, rowParser) } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - lineSeparator: Option[String]): Dataset[String] = { - val textOptions = lineSeparator.map { lineSep => - Map(TextOptions.LINE_SEPARATOR -> lineSep) - }.getOrElse(Map.empty[String, String]) - + parsedOptions: JSONOptions): Dataset[String] = { val paths = inputPaths.map(_.getPath.toString) + val textOptions = Map.empty[String, String] ++ + parsedOptions.encoding.map("encoding" -> _) ++ + parsedOptions.lineSeparator.map("lineSep" -> _) + sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, @@ -129,8 +133,12 @@ object TextInputJsonDataSource extends JsonDataSource { schema: StructType): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + val textParser = parser.options.encoding + .map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text)) + .getOrElse(CreateJacksonParser.text(_: JsonFactory, _: Text)) + val safeParser = new FailureSafeParser[Text]( - input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), + input => parser.parse(input, textParser, textToUTF8String), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) @@ -153,7 +161,11 @@ object MultiLineJsonDataSource extends JsonDataSource { parsedOptions: JSONOptions): StructType = { val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths) val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions) - JsonInferSchema.infer(sampled, parsedOptions, createParser) + val parser = parsedOptions.encoding + .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) + .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) + + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) } private def createBaseRdd( @@ -175,11 +187,18 @@ object MultiLineJsonDataSource extends JsonDataSource { .values } - private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { - val path = new Path(record.getPath()) - CreateJacksonParser.inputStream( - jsonFactory, - CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path)) + private def dataToInputStream(dataStream: PortableDataStream): InputStream = { + val path = new Path(dataStream.getPath()) + CodecStreams.createInputStreamWithCloseResource(dataStream.getConfiguration, path) + } + + private def createParser(jsonFactory: JsonFactory, stream: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream(jsonFactory, dataToInputStream(stream)) + } + + private def createParser(enc: String, jsonFactory: JsonFactory, + stream: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream(enc, jsonFactory, dataToInputStream(stream)) } override def readFile( @@ -194,9 +213,12 @@ object MultiLineJsonDataSource extends JsonDataSource { UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) } } + val streamParser = parser.options.encoding + .map(enc => CreateJacksonParser.inputStream(enc, _: JsonFactory, _: InputStream)) + .getOrElse(CreateJacksonParser.inputStream(_: JsonFactory, _: InputStream)) val safeParser = new FailureSafeParser[InputStream]( - input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString), + input => parser.parse[InputStream](input, streamParser, partitionedFileString), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 0862c746fffad..3b04510d29695 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.json +import java.nio.charset.{Charset, StandardCharsets} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -151,7 +153,13 @@ private[json] class JsonOutputWriter( context: TaskAttemptContext) extends OutputWriter with Logging { - private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) + private val encoding = options.encoding match { + case Some(charsetName) => Charset.forName(charsetName) + case None => StandardCharsets.UTF_8 + } + + private val writer = CodecStreams.createOutputStreamWriter( + context, new Path(path), encoding) // create the Generator without separator inserted between 2 records private[this] val gen = new JacksonGenerator(dataSchema, writer, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 5c1a35434f7b5..e4e201995faa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.text -import java.nio.charset.StandardCharsets +import java.nio.charset.{Charset, StandardCharsets} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} @@ -41,13 +41,18 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti */ val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean - private val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { sep => - require(sep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") - sep + val encoding: Option[String] = parameters.get(ENCODING) + + val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { lineSep => + require(lineSep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") + + lineSep } + // Note that the option 'lineSep' uses a different default value in read and write. - val lineSeparatorInRead: Option[Array[Byte]] = - lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.map(Charset.forName(_)).getOrElse(StandardCharsets.UTF_8)) + } val lineSeparatorInWrite: Array[Byte] = lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } @@ -55,5 +60,6 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti private[datasources] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" + val ENCODING = "encoding" val LINE_SEPARATOR = "lineSep" } diff --git a/sql/core/src/test/resources/test-data/utf16LE.json b/sql/core/src/test/resources/test-data/utf16LE.json new file mode 100644 index 0000000000000000000000000000000000000000..ce4117fd299dfcbc7089e7c0530098bfcaf5a27e GIT binary patch literal 98 zcmbi20w;GhFpeJpqLd9J2PYe#WR62N(?$cl`!==KvkHk Roq(bsb5ek+xfp7J7y!-s4k`cu literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/test-data/utf16WithBOM.json b/sql/core/src/test/resources/test-data/utf16WithBOM.json new file mode 100644 index 0000000000000000000000000000000000000000..cf4d29328b860ffe8288edea437222c6d432a100 GIT binary patch literal 200 zcmezWFPedufr~)_3ac5E7}6Lr8HyN+8A=%Z7!nzB8B&2_RzU2`kO36W1j;Be=m6C# tG2{T{G1WN%ML{N{09DiiRT68y3qw9bDMLB|(}RGj@}XvfOpXPc4*2#AY;xCDs(fH)C|bAdP&h(YSCptLiP&H!SNdXPSl f9+12a5Gz30IY1hupBVF;plV@mNP(JB3#7RK --jars + */ +object JSONBenchmarks { + val conf = new SparkConf() + + val spark = SparkSession.builder + .master("local[1]") + .appName("benchmark-json-datasource") + .config(conf) + .getOrCreate() + import spark.implicits._ + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + + def schemaInferring(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON schema inferring", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write + .option("encoding", "UTF-8") + .json(path.getAbsolutePath) + + benchmark.addCase("No encoding", 3) { _ => + spark.read.json(path.getAbsolutePath) + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .json(path.getAbsolutePath) + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON schema inferring: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 38902 / 39282 2.6 389.0 1.0X + UTF-8 is set 56959 / 57261 1.8 569.6 0.7X + */ + benchmark.run() + } + } + + def perlineParsing(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON per-line parsing", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write.json(path.getAbsolutePath) + val schema = new StructType().add("fieldA", StringType) + + benchmark.addCase("No encoding", 3) { _ => + spark.read + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON per-line parsing: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 25947 / 26188 3.9 259.5 1.0X + UTF-8 is set 46319 / 46417 2.2 463.2 0.6X + */ + benchmark.run() + } + } + + def perlineParsingOfWideColumn(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON parsing of wide lines", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map { i => + val s = "abcdef0123456789ABCDEF" * 20 + s"""{"a":"$s","b": $i,"c":"$s","d":$i,"e":"$s","f":$i,"x":"$s","y":$i,"z":"$s"}""" + } + .toDF().write.text(path.getAbsolutePath) + val schema = new StructType() + .add("a", StringType).add("b", LongType) + .add("c", StringType).add("d", LongType) + .add("e", StringType).add("f", LongType) + .add("x", StringType).add("y", LongType) + .add("z", StringType) + + benchmark.addCase("No encoding", 3) { _ => + spark.read + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON parsing of wide lines: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 45543 / 45660 0.2 4554.3 1.0X + UTF-8 is set 65737 / 65957 0.2 6573.7 0.7X + */ + benchmark.run() + } + } + + def main(args: Array[String]): Unit = { + schemaInferring(100 * 1000 * 1000) + perlineParsing(100 * 1000 * 1000) + perlineParsingOfWideColumn(10 * 1000 * 1000) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index a58dff827b92d..0db688fec9a67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.{File, StringWriter} -import java.nio.charset.StandardCharsets +import java.io.{File, FileOutputStream, StringWriter} +import java.nio.charset.{StandardCharsets, UnsupportedCharsetException} import java.nio.file.{Files, Paths, StandardOpenOption} import java.sql.{Date, Timestamp} import java.util.Locale @@ -48,6 +48,10 @@ class TestFileFilter extends PathFilter { class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ + def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } + test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { assert(expected.getClass == actual.getClass, @@ -2167,4 +2171,241 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val sampled = spark.read.option("samplingRatio", 1.0).json(ds) assert(sampled.count() == ds.count()) } + + test("SPARK-23723: json in UTF-16 with BOM") { + val fileName = "test-data/utf16WithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .option("encoding", "UTF-16") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"), Row("Doug", "Rood"))) + } + + test("SPARK-23723: multi-line json in UTF-32BE with BOM") { + val fileName = "test-data/utf32BEWithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("SPARK-23723: Use user's encoding in reading of multi-line json in UTF-16LE") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .options(Map("encoding" -> "UTF-16LE")) + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("SPARK-23723: Unsupported encoding name") { + val invalidCharset = "UTF-128" + val exception = intercept[UnsupportedCharsetException] { + spark.read + .options(Map("encoding" -> invalidCharset, "lineSep" -> "\n")) + .json(testFile("test-data/utf16LE.json")) + .count() + } + + assert(exception.getMessage.contains(invalidCharset)) + } + + test("SPARK-23723: checking that the encoding option is case agnostic") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .options(Map("encoding" -> "uTf-16lE")) + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + + test("SPARK-23723: specified encoding is not matched to actual encoding") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val exception = intercept[SparkException] { + spark.read.schema(schema) + .option("mode", "FAILFAST") + .option("multiline", "true") + .options(Map("encoding" -> "UTF-16BE")) + .json(testFile(fileName)) + .count() + } + val errMsg = exception.getMessage + + assert(errMsg.contains("Malformed records are detected in record parsing")) + } + + def checkEncoding(expectedEncoding: String, pathToJsonFiles: String, + expectedContent: String): Unit = { + val jsonFiles = new File(pathToJsonFiles) + .listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("json")) + val actualContent = jsonFiles.map { file => + new String(Files.readAllBytes(file.toPath), expectedEncoding) + }.mkString.trim + + assert(actualContent == expectedContent) + } + + test("SPARK-23723: save json in UTF-32BE") { + val encoding = "UTF-32BE" + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write + .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .json(path.getCanonicalPath) + + checkEncoding( + expectedEncoding = encoding, + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""") + } + } + + test("SPARK-23723: save json in default encoding - UTF-8") { + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write.json(path.getCanonicalPath) + + checkEncoding( + expectedEncoding = "UTF-8", + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""") + } + } + + test("SPARK-23723: wrong output encoding") { + val encoding = "UTF-128" + val exception = intercept[UnsupportedCharsetException] { + withTempPath { path => + val df = spark.createDataset(Seq((0))) + df.write + .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .json(path.getCanonicalPath) + } + } + + assert(exception.getMessage == encoding) + } + + test("SPARK-23723: read back json in UTF-16LE") { + val options = Map("encoding" -> "UTF-16LE", "lineSep" -> "\n") + withTempPath { path => + val ds = spark.createDataset(Seq(("a", 1), ("b", 2), ("c", 3))).repartition(2) + ds.write.options(options).json(path.getCanonicalPath) + + val readBack = spark + .read + .options(options) + .json(path.getCanonicalPath) + + checkAnswer(readBack.toDF(), ds.toDF()) + } + } + + def checkReadJson(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { + test(s"SPARK-23724: checks reading json in ${encoding} #${id}") { + val schema = new StructType().add("f1", StringType).add("f2", IntegerType) + withTempPath { path => + val records = List(("a", 1), ("b", 2)) + val data = records + .map(rec => s"""{"f1":"${rec._1}", "f2":${rec._2}}""".getBytes(encoding)) + .reduce((a1, a2) => a1 ++ lineSep.getBytes(encoding) ++ a2) + val os = new FileOutputStream(path) + os.write(data) + os.close() + val reader = if (inferSchema) { + spark.read + } else { + spark.read.schema(schema) + } + val readBack = reader + .option("encoding", encoding) + .option("lineSep", lineSep) + .json(path.getCanonicalPath) + checkAnswer(readBack, records.map(rec => Row(rec._1, rec._2))) + } + } + } + + // scalastyle:off nonascii + List( + (0, "|", "UTF-8", false), + (1, "^", "UTF-16BE", true), + (2, "::", "ISO-8859-1", true), + (3, "!!!@3", "UTF-32LE", false), + (4, 0x1E.toChar.toString, "UTF-8", true), + (5, "아", "UTF-32BE", false), + (6, "куку", "CP1251", true), + (7, "sep", "utf-8", false), + (8, "\r\n", "UTF-16LE", false), + (9, "\r\n", "utf-16be", true), + (10, "\u000d\u000a", "UTF-32BE", false), + (11, "\u000a\u000d", "UTF-8", true), + (12, "===", "US-ASCII", false), + (13, "$^+", "utf-32le", true) + ).foreach { + case (testNum, sep, encoding, inferSchema) => checkReadJson(sep, encoding, inferSchema, testNum) + } + // scalastyle:on nonascii + + test("SPARK-23724: lineSep should be set if encoding if different from UTF-8") { + val encoding = "UTF-16LE" + val exception = intercept[IllegalArgumentException] { + spark.read + .options(Map("encoding" -> encoding)) + .json(testFile("test-data/utf16LE.json")) + .count() + } + + assert(exception.getMessage.contains( + s"""The lineSep option must be specified for the $encoding encoding""")) + } + + private val badJson = "\u0000\u0000\u0000A\u0001AAA" + + test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is enabled") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson + """{"a":1}""").toDS().write.text(path) + val expected = s"""${badJson}{"a":1}\n""" + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = spark.read.format("json") + .option("mode", "PERMISSIVE") + .option("multiLine", true) + .option("encoding", "UTF-8") + .schema(schema).load(path) + checkAnswer(df, Row(null, expected)) + } + } + + test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is disabled") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.text(path) + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = spark.read.format("json") + .option("mode", "PERMISSIVE") + .option("multiLine", false) + .option("encoding", "UTF-8") + .schema(schema).load(path) + checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) + } + } + + test("SPARK-23094: permissively parse a dataset contains JSON with leading nulls") { + checkAnswer( + spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()), + Row(badJson)) + } } From 56f501e1c0cec3be7d13008bd2c0182ec83ed2a2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 30 Apr 2018 09:40:46 +0800 Subject: [PATCH 710/774] [MINOR][DOCS] Fix a broken link for Arrow's supported types in the programming guide ## What changes were proposed in this pull request? This PR fixes a broken link for Arrow's supported types in the programming guide. ## How was this patch tested? Manually tested via `SKIP_API=1 jekyll watch`. "Supported SQL Types" here in https://spark.apache.org/docs/latest/sql-programming-guide.html#enabling-for-conversion-tofrom-pandas is broken. It should be https://spark.apache.org/docs/latest/sql-programming-guide.html#supported-sql-types Author: hyukjinkwon Closes #21191 from HyukjinKwon/minor-arrow-link. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e8ff1470970f7..836ce990205a9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1703,7 +1703,7 @@ Using the above optimizations with Arrow will produce the same results as when A enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the DataFrame to the driver program and should be done on a small subset of the data. Not all Spark data types are currently supported and an error can be raised if a column has an unsupported type, -see [Supported SQL Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, +see [Supported SQL Types](#supported-sql-types). If an error occurs during `createDataFrame()`, Spark will fall back to create the DataFrame without Arrow. ## Pandas UDFs (a.k.a. Vectorized UDFs) From 3121b411f748859ed3ed1c97cbc21e6ae980a35c Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 30 Apr 2018 09:45:22 +0800 Subject: [PATCH 711/774] [SPARK-23846][SQL] The samplingRatio option for CSV datasource ## What changes were proposed in this pull request? I propose to support the `samplingRatio` option for schema inferring of CSV datasource similar to the same option of JSON datasource: https://github.com/apache/spark/blob/b14993e1fcb68e1c946a671c6048605ab4afdf58/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala#L49-L50 ## How was this patch tested? Added 2 tests for json and 2 tests for csv datasources. The tests checks that only subset of input dataset is used for schema inferring. Author: Maxim Gekk Author: Maxim Gekk Closes #20959 from MaxGekk/csv-sampling. --- python/pyspark/sql/readwriter.py | 7 ++- python/pyspark/sql/tests.py | 7 +++ .../apache/spark/sql/DataFrameReader.scala | 1 + .../datasources/csv/CSVDataSource.scala | 6 ++- .../datasources/csv/CSVOptions.scala | 3 ++ .../execution/datasources/csv/CSVUtils.scala | 28 +++++++++++ .../execution/datasources/csv/CSVSuite.scala | 47 ++++++++++++++++++- .../datasources/csv/TestCsvData.scala | 36 ++++++++++++++ 8 files changed, 129 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6811fa6b3b156..9899eb5058b82 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -345,7 +345,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, + samplingRatio=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -428,6 +429,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non the quote character. If None is set, the default value is escape character when escape and quote characters are different, ``\0`` otherwise. + :param samplingRatio: defines fraction of rows used for schema inferring. + If None is set, it uses the default value, ``1.0``. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -446,7 +449,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e0cd2aa41a2d0..bc3eaf16b4de7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3033,6 +3033,13 @@ def test_json_sampling_ratio(self): .json(rdd).schema self.assertEquals(schema, StructType([StructField("a", LongType(), True)])) + def test_csv_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '0.1' if x == 1 else str(x)) + schema = self.spark.read.option('inferSchema', True)\ + .csv(rdd, samplingRatio=0.5).schema + self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 6b2ea6c06d3ae..53f44888ebaff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -539,6 +539,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `header` (default `false`): uses the first line as names of columns.
  • *
  • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
  • + *
  • `samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.
  • *
  • `ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading * whitespaces from values being read should be skipped.
  • *
  • `ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 4870d75fc5f08..bc1f4ab3bb053 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -161,7 +161,8 @@ object TextInputCSVDataSource extends CSVDataSource { val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.rdd.mapPartitions { iter => + val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions) + val tokenRDD = sampled.rdd.mapPartitions { iter => val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) @@ -235,7 +236,8 @@ object MultiLineCSVDataSource extends CSVDataSource { parsedOptions.headerFlag, new CsvParser(parsedOptions.asParserSettings)) } - CSVInferSchema.infer(tokenRDD, header, parsedOptions) + val sampled = CSVUtils.sample(tokenRDD, parsedOptions) + CSVInferSchema.infer(sampled, header, parsedOptions) case None => // If the first row could not be read, just return the empty schema. StructType(Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index c16790630ce17..2ec0fc605a84b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -150,6 +150,9 @@ class CSVOptions( val isCommentSet = this.comment != '\u0000' + val samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 72b053d2092ca..31464f1bcc68e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.execution.datasources.csv +import org.apache.spark.input.PortableDataStream +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -131,4 +134,29 @@ object CSVUtils { schema.foreach(field => verifyType(field.dataType)) } + /** + * Sample CSV dataset as configured by `samplingRatio`. + */ + def sample(csv: Dataset[String], options: CSVOptions): Dataset[String] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + csv + } else { + csv.sample(withReplacement = false, options.samplingRatio, 1) + } + } + + /** + * Sample CSV RDD as configured by `samplingRatio`. + */ + def sample(csv: RDD[Array[String]], options: CSVOptions): RDD[Array[String]] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + csv + } else { + csv.sample(withReplacement = false, options.samplingRatio, 1) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 4398e547d9217..461abdd96d3f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -30,12 +30,11 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.functions.{col, regexp_replace} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ -class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { +class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with TestCsvData { import testImplicits._ private val carsFile = "test-data/cars.csv" @@ -1279,4 +1278,48 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil ) } + + test("SPARK-23846: schema inferring touches less data if samplingRatio < 1.0") { + // Set default values for the DataSource parameters to make sure + // that whole test file is mapped to only one partition. This will guarantee + // reliable sampling of the input file. + withSQLConf( + "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, + "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + )(withTempPath { path => + val ds = sampledTestData.coalesce(1) + ds.write.text(path.getAbsolutePath) + + val readback = spark.read + .option("inferSchema", true).option("samplingRatio", 0.1) + .csv(path.getCanonicalPath) + assert(readback.schema == new StructType().add("_c0", IntegerType)) + }) + } + + test("SPARK-23846: usage of samplingRatio while parsing a dataset of strings") { + val ds = sampledTestData.coalesce(1) + val readback = spark.read + .option("inferSchema", true).option("samplingRatio", 0.1) + .csv(ds) + + assert(readback.schema == new StructType().add("_c0", IntegerType)) + } + + test("SPARK-23846: samplingRatio is out of the range (0, 1.0]") { + val ds = spark.range(0, 100, 1, 1).map(_.toString) + + val errorMsg0 = intercept[IllegalArgumentException] { + spark.read.option("inferSchema", true).option("samplingRatio", -1).csv(ds) + }.getMessage + assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0")) + + val errorMsg1 = intercept[IllegalArgumentException] { + spark.read.option("inferSchema", true).option("samplingRatio", 0).csv(ds) + }.getMessage + assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0")) + + val sampled = spark.read.option("inferSchema", true).option("samplingRatio", 1.0).csv(ds) + assert(sampled.count() == ds.count()) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala new file mode 100644 index 0000000000000..3e20cc47dca2c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} + +private[csv] trait TestCsvData { + protected def spark: SparkSession + + def sampledTestData: Dataset[String] = { + spark.range(0, 100, 1).map { index => + val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(index)) { + index.toString + } else { + (index.toDouble + 0.1).toString + } + }(Encoders.STRING) + } +} From b42ad165bb93c96cc5be9ed05b5026f9baafdfa2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 30 Apr 2018 09:13:32 -0700 Subject: [PATCH 712/774] [SPARK-24072][SQL] clearly define pushed filters ## What changes were proposed in this pull request? filters like parquet row group filter, which is actually pushed to the data source but still to be evaluated by Spark, should also count as `pushedFilters`. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #21143 from cloud-fan/step1. --- .../SupportsPushDownCatalystFilters.java | 11 +++- .../v2/reader/SupportsPushDownFilters.java | 10 ++- .../datasources/v2/DataSourceV2Relation.scala | 63 +++++++++++-------- .../datasources/v2/DataSourceV2ScanExec.scala | 1 + .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../v2/DataSourceV2StringFormat.scala | 19 ++---- .../v2/PushDownOperatorsToDataSource.scala | 6 +- .../continuous/ContinuousSuite.scala | 2 +- 8 files changed, 68 insertions(+), 48 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 290d614805ac7..4543c143a9aca 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -39,7 +39,16 @@ public interface SupportsPushDownCatalystFilters extends DataSourceReader { Expression[] pushCatalystFilters(Expression[] filters); /** - * Returns the catalyst filters that are pushed in {@link #pushCatalystFilters(Expression[])}. + * Returns the catalyst filters that are pushed to the data source via + * {@link #pushCatalystFilters(Expression[])}. + * + * There are 3 kinds of filters: + * 1. pushable filters which don't need to be evaluated again after scanning. + * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet + * row group filter. + * 3. non-pushable filters. + * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. + * * It's possible that there is no filters in the query and * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for * this case. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 1cff024232a44..b6a90a3d0b681 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -37,7 +37,15 @@ public interface SupportsPushDownFilters extends DataSourceReader { Filter[] pushFilters(Filter[] filters); /** - * Returns the filters that are pushed in {@link #pushFilters(Filter[])}. + * Returns the filters that are pushed to the data source via {@link #pushFilters(Filter[])}. + * + * There are 3 kinds of filters: + * 1. pushable filters which don't need to be evaluated again after scanning. + * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet + * row group filter. + * 3. non-pushable filters. + * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. + * * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} * is never called, empty array should be returned for this case. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 2b282ffae2390..90fb5a14c9fc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -54,9 +54,12 @@ case class DataSourceV2Relation( private lazy val v2Options: DataSourceOptions = makeV2Options(options) + // postScanFilters: filters that need to be evaluated after the scan. + // pushedFilters: filters that will be pushed down and evaluated in the underlying data sources. + // Note: postScanFilters and pushedFilters can overlap, e.g. the parquet row group filter. lazy val ( reader: DataSourceReader, - unsupportedFilters: Seq[Expression], + postScanFilters: Seq[Expression], pushedFilters: Seq[Expression]) = { val newReader = userSpecifiedSchema match { case Some(s) => @@ -67,14 +70,16 @@ case class DataSourceV2Relation( DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - val (remainingFilters, pushedFilters) = filters match { + val (postScanFilters, pushedFilters) = filters match { case Some(filterSeq) => DataSourceV2Relation.pushFilters(newReader, filterSeq) case _ => (Nil, Nil) } + logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") + logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") - (newReader, remainingFilters, pushedFilters) + (newReader, postScanFilters, pushedFilters) } override def doCanonicalize(): LogicalPlan = { @@ -121,6 +126,8 @@ case class StreamingDataSourceV2Relation( override def simpleString: String = "Streaming RelationV2 " + metadataString + override def pushedFilters: Seq[Expression] = Nil + override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) // TODO: unify the equal/hashCode implementation for all data source v2 query plans. @@ -217,31 +224,35 @@ object DataSourceV2Relation { reader: DataSourceReader, filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { reader match { - case catalystFilterSupport: SupportsPushDownCatalystFilters => - ( - catalystFilterSupport.pushCatalystFilters(filters.toArray), - catalystFilterSupport.pushedCatalystFilters() - ) - - case filterSupport: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, Filter] = filters.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = filterSupport.pushFilters(translatedMap.values.toArray).toSet - val (unhandledPredicates, pushedPredicates) = translatedMap.partition { case (_, f) => - unhandledFilters.contains(f) + case r: SupportsPushDownCatalystFilters => + val postScanFilters = r.pushCatalystFilters(filters.toArray) + val pushedFilters = r.pushedCatalystFilters() + (postScanFilters, pushedFilters) + + case r: SupportsPushDownFilters => + // A map from translated data source filters to original catalyst filter expressions. + val translatedFilterToExpr = scala.collection.mutable.HashMap.empty[Filter, Expression] + // Catalyst filter expression that can't be translated to data source filters. + val untranslatableExprs = scala.collection.mutable.ArrayBuffer.empty[Expression] + + for (filterExpr <- filters) { + val translated = DataSourceStrategy.translateFilter(filterExpr) + if (translated.isDefined) { + translatedFilterToExpr(translated.get) = filterExpr + } else { + untranslatableExprs += filterExpr + } } - (nonConvertiblePredicates ++ unhandledPredicates.keys, pushedPredicates.keys.toSeq) + // Data source filters that need to be evaluated again after scanning. which means + // the data source cannot guarantee the rows returned can pass these filters. + // As a result we must return it so Spark can plan an extra filter operator. + val postScanFilters = + r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) + // The filters which are marked as pushed to this data source + val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) + + (untranslatableExprs ++ postScanFilters, pushedFilters) case _ => (filters, Nil) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 3a5e7bf89e142..41bdda47c8c3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -41,6 +41,7 @@ case class DataSourceV2ScanExec( output: Seq[AttributeReference], @transient source: DataSourceV2, @transient options: Map[String, String], + @transient pushedFilters: Seq[Expression], @transient reader: DataSourceReader) extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index c2a31442d2be5..1b7c639f10f98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -25,10 +25,10 @@ import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDat object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil + DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil case r: StreamingDataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil + DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala index aed55a429bfd7..693e67dcd108e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -19,11 +19,9 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.commons.lang3.StringUtils -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.util.Utils /** @@ -49,16 +47,9 @@ trait DataSourceV2StringFormat { def options: Map[String, String] /** - * The created data source reader. Here we use it to get the filters that has been pushed down - * so far, itself doesn't take part in the equals/hashCode. + * The filters which have been pushed to the data source. */ - def reader: DataSourceReader - - private lazy val filters = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Set.empty - } + def pushedFilters: Seq[Expression] private def sourceName: String = source match { case registered: DataSourceRegister => registered.shortName() @@ -68,8 +59,8 @@ trait DataSourceV2StringFormat { def metadataString: String = { val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] - if (filters.nonEmpty) { - entries += "Filters" -> filters.mkString("[", ", ", "]") + if (pushedFilters.nonEmpty) { + entries += "Filters" -> pushedFilters.mkString("[", ", ", "]") } // TODO: we should only display some standard options like path, table, etc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index f23d228567241..9293d4f831bff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -57,9 +57,9 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { projection = projection.asInstanceOf[Seq[AttributeReference]], filters = Some(filters)) - // Add a Filter for any filters that could not be pushed - val unpushedFilter = newRelation.unsupportedFilters.reduceLeftOption(And) - val filtered = unpushedFilter.map(Filter(_, newRelation)).getOrElse(newRelation) + // Add a Filter for any filters that need to be evaluated after scan. + val postScanFilterCond = newRelation.postScanFilters.reduceLeftOption(And) + val filtered = postScanFilterCond.map(Filter(_, newRelation)).getOrElse(newRelation) // Add a Project to ensure the output matches the required projection if (newRelation.output != projectAttrs) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 5f222e7885994..cd1704ac2fdad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -41,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, _, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 007ae6878f4b4defe1f08114212fa7289fc9ee4a Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Mon, 30 Apr 2018 13:40:03 -0700 Subject: [PATCH 713/774] [SPARK-24003][CORE] Add support to provide spark.executor.extraJavaOptions in terms of App Id and/or Executor Id's ## What changes were proposed in this pull request? Added support to specify the 'spark.executor.extraJavaOptions' value in terms of the `{{APP_ID}}` and/or `{{EXECUTOR_ID}}`, `{{APP_ID}}` will be replaced by Application Id and `{{EXECUTOR_ID}}` will be replaced by Executor Id while starting the executor. ## How was this patch tested? I have verified this by checking the executor process command and gc logs. I verified the same in different deployment modes(Standalone, YARN, Mesos) client and cluster modes. Author: Devaraj K Closes #21088 from devaraj-kavali/SPARK-24003. --- .../spark/deploy/worker/ExecutorRunner.scala | 8 ++++++-- .../main/scala/org/apache/spark/util/Utils.scala | 15 +++++++++++++++ docs/configuration.md | 5 +++++ .../k8s/features/BasicExecutorFeatureStep.scala | 4 +++- .../MesosCoarseGrainedSchedulerBackend.scala | 4 +++- .../mesos/MesosFineGrainedSchedulerBackend.scala | 4 +++- .../org/apache/spark/deploy/yarn/Client.scala | 8 ++++++-- .../spark/deploy/yarn/ExecutorRunnable.scala | 3 ++- 8 files changed, 43 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index d4d8521cc8204..dc6a3076a5113 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import com.google.common.io.Files import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} +import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef @@ -142,7 +142,11 @@ private[deploy] class ExecutorRunner( private def fetchAndRunExecutor() { try { // Launch the process - val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), + val subsOpts = appDesc.command.javaOpts.map { + Utils.substituteAppNExecIds(_, appId, execId.toString) + } + val subsCommand = appDesc.command.copy(javaOpts = subsOpts) + val builder = CommandUtils.buildProcessBuilder(subsCommand, new SecurityManager(conf), memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"") diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d2be93226e2a2..dcad1b914038f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2689,6 +2689,21 @@ private[spark] object Utils extends Logging { s"k8s://$resolvedURL" } + + /** + * Replaces all the {{EXECUTOR_ID}} occurrences with the Executor Id + * and {{APP_ID}} occurrences with the App Id. + */ + def substituteAppNExecIds(opt: String, appId: String, execId: String): String = { + opt.replace("{{APP_ID}}", appId).replace("{{EXECUTOR_ID}}", execId) + } + + /** + * Replaces all the {{APP_ID}} occurrences with the App Id. + */ + def substituteAppId(opt: String, appId: String): String = { + opt.replace("{{APP_ID}}", appId) + } } private[util] object CallerContext extends Logging { diff --git a/docs/configuration.md b/docs/configuration.md index fb02d7ea1d4ea..8a1aacef85760 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -328,6 +328,11 @@ Apart from these, the following properties are also available, and may be useful Note that it is illegal to set Spark properties or maximum heap size (-Xmx) settings with this option. Spark properties should be set using a SparkConf object or the spark-defaults.conf file used with the spark-submit script. Maximum heap size settings can be set with spark.executor.memory. + + The following symbols, if present will be interpolated: {{APP_ID}} will be replaced by + application ID and {{EXECUTOR_ID}} will be replaced by executor ID. For example, to enable + verbose gc logging to a file named for the executor ID of the app in /tmp, pass a 'value' of: + -verbose:gc -Xloggc:/tmp/{{APP_ID}}-{{EXECUTOR_ID}}.gc
  • diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index d22097587aafe..529069d3b8a0c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -89,7 +89,9 @@ private[spark] class BasicExecutorFeatureStep( val executorExtraJavaOptionsEnv = kubernetesConf .get(EXECUTOR_JAVA_OPTIONS) .map { opts => - val delimitedOpts = Utils.splitCommandString(opts) + val subsOpts = Utils.substituteAppNExecIds(opts, kubernetesConf.appId, + kubernetesConf.roleSpecificConf.executorId) + val delimitedOpts = Utils.splitCommandString(subsOpts) delimitedOpts.zipWithIndex.map { case (opt, index) => new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 53f5f61cca486..9b75e4c98344a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -227,7 +227,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "") + val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions").map { + Utils.substituteAppNExecIds(_, appId, taskId) + }.getOrElse("") // Set the environment variable through a command prefix // to append to the existing value of the variable diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index d6d939d246109..71a70ff048ccc 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -111,7 +111,9 @@ private[spark] class MesosFineGrainedSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("") + val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").map { + Utils.substituteAppNExecIds(_, appId, execId) + }.getOrElse("") val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p => Utils.libraryPathEnvPrefix(Seq(p)) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 5763c3dbc5a8a..bafb129032b49 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -892,7 +892,9 @@ private[spark] class Client( // Include driver-specific java options if we are launching a driver if (isClusterMode) { sparkConf.get(DRIVER_JAVA_OPTIONS).foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + javaOpts ++= Utils.splitCommandString(opts) + .map(Utils.substituteAppId(_, appId.toString)) + .map(YarnSparkHadoopUtil.escapeForShell) } val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), sys.props.get("spark.driver.libraryPath")).flatten @@ -914,7 +916,9 @@ private[spark] class Client( s"(was '$opts'). Use spark.yarn.am.memory instead." throw new SparkException(msg) } - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + javaOpts ++= Utils.splitCommandString(opts) + .map(Utils.substituteAppId(_, appId.toString)) + .map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(AM_LIBRARY_PATH).foreach { paths => prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index ab08698035c98..a2a18cdff65af 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -141,7 +141,8 @@ private[yarn] class ExecutorRunnable( // Set extra Java options for the executor, if defined sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + val subsOpt = Utils.substituteAppNExecIds(opts, appId, executorId) + javaOpts ++= Utils.splitCommandString(subsOpt).map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p => prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) From b857fb549f3bf4e6f289ba11f3903db0a3696dec Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 1 May 2018 09:06:23 +0800 Subject: [PATCH 714/774] [SPARK-23853][PYSPARK][TEST] Run Hive-related PySpark tests only for `-Phive` ## What changes were proposed in this pull request? When `PyArrow` or `Pandas` are not available, the corresponding PySpark tests are skipped automatically. Currently, PySpark tests fail when we are not using `-Phive`. This PR aims to skip Hive related PySpark tests when `-Phive` is not given. **BEFORE** ```bash $ build/mvn -DskipTests clean package $ python/run-tests.py --python-executables python2.7 --modules pyspark-sql File "/Users/dongjoon/spark/python/pyspark/sql/readwriter.py", line 295, in pyspark.sql.readwriter.DataFrameReader.table ... IllegalArgumentException: u"Error while instantiating 'org.apache.spark.sql.hive.HiveExternalCatalog':" ********************************************************************** 1 of 3 in pyspark.sql.readwriter.DataFrameReader.table ***Test Failed*** 1 failures. ``` **AFTER** ```bash $ build/mvn -DskipTests clean package $ python/run-tests.py --python-executables python2.7 --modules pyspark-sql ... Tests passed in 138 seconds Skipped tests in pyspark.sql.tests with python2.7: ... test_hivecontext (pyspark.sql.tests.HiveSparkSubmitTests) ... skipped 'Hive is not available.' ``` ## How was this patch tested? This is a test-only change. First, this should pass the Jenkins. Then, manually do the following. ```bash build/mvn -DskipTests clean package python/run-tests.py --python-executables python2.7 --modules pyspark-sql ``` Author: Dongjoon Hyun Closes #21141 from dongjoon-hyun/SPARK-23853. --- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/tests.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 9899eb5058b82..448a4732001b5 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -979,7 +979,7 @@ def _test(): globs = pyspark.sql.readwriter.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') try: - spark = SparkSession.builder.enableHiveSupport().getOrCreate() + spark = SparkSession.builder.getOrCreate() except py4j.protocol.Py4JError: spark = SparkSession(sc) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index bc3eaf16b4de7..cc6acfdb07d99 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3043,6 +3043,26 @@ def test_csv_sampling_ratio(self): class HiveSparkSubmitTests(SparkSubmitTests): + @classmethod + def setUpClass(cls): + # get a SparkContext to check for availability of Hive + sc = SparkContext('local[4]', cls.__name__) + cls.hive_available = True + try: + sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.hive_available = False + except TypeError: + cls.hive_available = False + finally: + # we don't need this SparkContext for the test + sc.stop() + + def setUp(self): + super(HiveSparkSubmitTests, self).setUp() + if not self.hive_available: + self.skipTest("Hive is not available.") + def test_hivecontext(self): # This test checks that HiveContext is using Hive metastore (SPARK-16224). # It sets a metastore url and checks if there is a derby dir created by From 7bbec0dced35aeed79c1a24b6f7a1e0a3508b0fb Mon Sep 17 00:00:00 2001 From: wangyanlin01 Date: Tue, 1 May 2018 16:22:52 +0800 Subject: [PATCH 715/774] [SPARK-24061][SS] Add TypedFilter support for continuous processing ## What changes were proposed in this pull request? Add TypedFilter support for continuous processing application. ## How was this patch tested? unit tests Author: wangyanlin01 Closes #21136 from yanlin-Lynn/SPARK-24061. --- .../UnsupportedOperationChecker.scala | 3 ++- .../analysis/UnsupportedOperationsSuite.scala | 23 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index ff9d6d7a7dded..d3d6c636c4ba8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -345,7 +345,8 @@ object UnsupportedOperationChecker { plan.foreachUp { implicit subPlan => subPlan match { case (_: Project | _: Filter | _: MapElements | _: MapPartitions | - _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias) => + _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias | + _: TypedFilter) => case node if node.nodeName == "StreamingRelationV2" => case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 60d1351fda264..cb487c8893541 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -621,6 +621,13 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, expectedMsgs = Seq("monotonically_increasing_id")) + assertSupportedForContinuousProcessing( + "TypedFilter", TypedFilter( + null, + null, + null, + null, + new TestStreamingRelationV2(attribute)), OutputMode.Append()) /* ======================================================================================= @@ -771,6 +778,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite { } } + /** Assert that the logical plan is supported for continuous procsssing mode */ + def assertSupportedForContinuousProcessing( + name: String, + plan: LogicalPlan, + outputMode: OutputMode): Unit = { + test(s"continuous processing - $name: supported") { + UnsupportedOperationChecker.checkForContinuous(plan, outputMode) + } + } + /** * Assert that the logical plan is not supported inside a streaming plan. * @@ -840,4 +857,10 @@ class UnsupportedOperationsSuite extends SparkFunSuite { def this(attribute: Attribute) = this(Seq(attribute)) override def isStreaming: Boolean = true } + + case class TestStreamingRelationV2(output: Seq[Attribute]) extends LeafNode { + def this(attribute: Attribute) = this(Seq(attribute)) + override def isStreaming: Boolean = true + override def nodeName: String = "StreamingRelationV2" + } } From 6782359a04356e4cde32940861bf2410ef37f445 Mon Sep 17 00:00:00 2001 From: Bounkong Khamphousone Date: Tue, 1 May 2018 08:28:21 -0700 Subject: [PATCH 716/774] [SPARK-23941][MESOS] Mesos task failed on specific spark app name MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Shell escaped the name passed to spark-submit and change how conf attributes are shell escaped. ## How was this patch tested? This test has been tested manually with Hive-on-spark with mesos or with the use case described in the issue with the sparkPi application with a custom name which contains illegal shell characters. With this PR, hive-on-spark on mesos works like a charm with hive 3.0.0-SNAPSHOT. I state that this contribution is my original work and that I license the work to the project under the project’s open source license Author: Bounkong Khamphousone Closes #21014 from tiboun/fix/SPARK-23941. --- .../spark/scheduler/cluster/mesos/MesosClusterScheduler.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index d224a7325820a..b36f46456f9a5 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -530,9 +530,9 @@ private[spark] class MesosClusterScheduler( .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } .toMap (defaultConf ++ driverConf).foreach { case (key, value) => - options ++= Seq("--conf", s""""$key=${shellEscape(value)}"""".stripMargin) } + options ++= Seq("--conf", s"${key}=${value}") } - options + options.map(shellEscape) } /** From e15850be6e0210614a734a307f5b83bdf44e2456 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 2 May 2018 10:55:01 +0800 Subject: [PATCH 717/774] [SPARK-24131][PYSPARK] Add majorMinorVersion API to PySpark for determining Spark versions ## What changes were proposed in this pull request? We need to determine Spark major and minor versions in PySpark. We can add a `majorMinorVersion` API to PySpark which is similar to the Scala API in `VersionUtils.majorMinorVersion`. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #21203 from viirya/SPARK-24131. --- python/pyspark/util.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 49afc13640332..04df835bf6717 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -16,6 +16,7 @@ # limitations under the License. # +import re import sys import inspect from py4j.protocol import Py4JJavaError @@ -61,6 +62,26 @@ def _get_argspec(f): return argspec +def majorMinorVersion(version): + """ + Get major and minor version numbers for given Spark version string. + + >>> version = "2.4.0" + >>> majorMinorVersion(version) + (2, 4) + + >>> version = "abc" + >>> majorMinorVersion(version) is None + True + + """ + m = re.search('^(\d+)\.(\d+)(\..*)?$', version) + if m is None: + return None + else: + return (int(m.group(1)), int(m.group(2))) + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() From 9215ee7a16b57c56ae927d65e024cf7afe542cbb Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 2 May 2018 10:41:34 +0200 Subject: [PATCH 718/774] [SPARK-23976][CORE] Detect length overflow in UTF8String.concat()/ByteArray.concat() ## What changes were proposed in this pull request? This PR detects length overflow if total elements in inputs are not acceptable. For example, when the three inputs has `0x7FFF_FF00`, `0x7FFF_FF00`, and `0xE00`, we should detect length overflow since we cannot allocate such a large structure on `byte[]`. On the other hand, the current algorithm can allocate the result structure with `0x1000`-byte length due to integer sum overflow. ## How was this patch tested? Existing UTs. If we would create UTs, we need large heap (6-8GB). It may make test environment unstable. If it is necessary to create UTs, I will create them. Author: Kazuaki Ishizaki Closes #21064 from kiszk/SPARK-23976. --- .../org/apache/spark/unsafe/types/ByteArray.java | 12 +++++++----- .../org/apache/spark/unsafe/types/UTF8String.java | 8 ++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index c03caf0076f61..ecd7c19f2c634 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -17,10 +17,12 @@ package org.apache.spark.unsafe.types; -import org.apache.spark.unsafe.Platform; - import java.util.Arrays; +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + public final class ByteArray { public static final byte[] EMPTY_BYTE = new byte[0]; @@ -77,17 +79,17 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) { public static byte[] concat(byte[]... inputs) { // Compute the total length of the result - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { - totalLength += inputs[i].length; + totalLength += (long)inputs[i].length; } else { return null; } } // Allocate a new byte array, and copy the inputs one by one into it - final byte[] result = new byte[totalLength]; + final byte[] result = new byte[Ints.checkedCast(totalLength)]; int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].length; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e9b3d9b045af5..e91fc4391425c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -29,8 +29,8 @@ import com.esotericsoftware.kryo.KryoSerializable; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; - import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; @@ -877,17 +877,17 @@ public UTF8String lpad(int len, UTF8String pad) { */ public static UTF8String concat(UTF8String... inputs) { // Compute the total length of the result. - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { - totalLength += inputs[i].numBytes; + totalLength += (long)inputs[i].numBytes; } else { return null; } } // Allocate a new byte array, and copy the inputs one by one into it. - final byte[] result = new byte[totalLength]; + final byte[] result = new byte[Ints.checkedCast(totalLength)]; int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; From 152eaf6ae698cd0df7f5a5be3f17ee46e0be929d Mon Sep 17 00:00:00 2001 From: WangJinhai02 Date: Wed, 2 May 2018 22:40:14 +0800 Subject: [PATCH 719/774] [SPARK-24107][CORE] ChunkedByteBuffer.writeFully method has not reset the limit value MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit JIRA Issue: https://issues.apache.org/jira/browse/SPARK-24107?jql=text%20~%20%22ChunkedByteBuffer%22 ChunkedByteBuffer.writeFully method has not reset the limit value. When chunks larger than bufferWriteChunkSize, such as 80 * 1024 * 1024 larger than config.BUFFER_WRITE_CHUNK_SIZE(64 * 1024 * 1024),only while once, will lost 16 * 1024 * 1024 byte Author: WangJinhai02 Closes #21175 from manbuyun/bugfix-ChunkedByteBuffer. --- .../spark/util/io/ChunkedByteBuffer.scala | 13 +++++++++---- .../spark/io/ChunkedByteBufferSuite.scala | 17 +++++++++++++++-- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 7367af7888bd8..3ae8dfcc1cb66 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -63,10 +63,15 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { */ def writeFully(channel: WritableByteChannel): Unit = { for (bytes <- getChunks()) { - while (bytes.remaining() > 0) { - val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) - bytes.limit(bytes.position() + ioSize) - channel.write(bytes) + val curChunkLimit = bytes.limit() + while (bytes.hasRemaining) { + try { + val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) + bytes.limit(bytes.position() + ioSize) + channel.write(bytes) + } finally { + bytes.limit(curChunkLimit) + } } } } diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index 3b798e36b0499..2107559572d78 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -21,11 +21,12 @@ import java.nio.ByteBuffer import com.google.common.io.ByteStreams -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.network.util.ByteArrayWritableChannel import org.apache.spark.util.io.ChunkedByteBuffer -class ChunkedByteBufferSuite extends SparkFunSuite { +class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { test("no chunks") { val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer]) @@ -56,6 +57,18 @@ class ChunkedByteBufferSuite extends SparkFunSuite { assert(chunkedByteBuffer.getChunks().head.position() === 0) } + test("SPARK-24107: writeFully() write buffer which is larger than bufferWriteChunkSize") { + try { + sc.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 32L * 1024L * 1024L) + val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(40 * 1024 * 1024))) + val byteArrayWritableChannel = new ByteArrayWritableChannel(chunkedByteBuffer.size.toInt) + chunkedByteBuffer.writeFully(byteArrayWritableChannel) + assert(byteArrayWritableChannel.length() === chunkedByteBuffer.size) + } finally { + sc.conf.remove(config.BUFFER_WRITE_CHUNK_SIZE) + } + } + test("toArray()") { val empty = ByteBuffer.wrap(Array.empty[Byte]) val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte)) From 8dbf56c055218ff0f3fabae84b63c022f43afbfd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 2 May 2018 11:58:55 -0700 Subject: [PATCH 720/774] [SPARK-24013][SQL] Remove unneeded compress in ApproximatePercentile ## What changes were proposed in this pull request? `ApproximatePercentile` contains a workaround logic to compress the samples since at the beginning `QuantileSummaries` was ignoring the compression threshold. This problem was fixed in SPARK-17439, but the workaround logic was not removed. So we are compressing the samples many more times than needed: this could lead to critical performance degradation. This can create serious performance issues in queries like: ``` select approx_percentile(id, array(0.1)) from range(10000000) ``` ## How was this patch tested? added UT Author: Marco Gaido Closes #21133 from mgaido91/SPARK-24013. --- .../aggregate/ApproximatePercentile.scala | 33 ++++--------------- .../sql/catalyst/util/QuantileSummaries.scala | 11 ++++--- .../sql/ApproximatePercentileQuerySuite.scala | 13 ++++++++ 3 files changed, 26 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index a45854a3b5146..f1bbbdabb41f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -206,27 +206,15 @@ object ApproximatePercentile { * with limited memory. PercentileDigest is backed by [[QuantileSummaries]]. * * @param summaries underlying probabilistic data structure [[QuantileSummaries]]. - * @param isCompressed An internal flag from class [[QuantileSummaries]] to indicate whether the - * underlying quantileSummaries is compressed. */ - class PercentileDigest( - private var summaries: QuantileSummaries, - private var isCompressed: Boolean) { - - // Trigger compression if the QuantileSummaries's buffer length exceeds - // compressThresHoldBufferLength. The buffer length can be get by - // quantileSummaries.sampled.length - private[this] final val compressThresHoldBufferLength: Int = { - // Max buffer length after compression. - val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2 - // A safe upper bound for buffer length before compression - maxBufferLengthAfterCompression * 2 - } + class PercentileDigest(private var summaries: QuantileSummaries) { def this(relativeError: Double) = { - this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true) + this(new QuantileSummaries(defaultCompressThreshold, relativeError, compressed = true)) } + private[sql] def isCompressed: Boolean = summaries.compressed + /** Returns compressed object of [[QuantileSummaries]] */ def quantileSummaries: QuantileSummaries = { if (!isCompressed) compress() @@ -236,14 +224,6 @@ object ApproximatePercentile { /** Insert an observation value into the PercentileDigest data structure. */ def add(value: Double): Unit = { summaries = summaries.insert(value) - // The result of QuantileSummaries.insert is un-compressed - isCompressed = false - - // Currently, QuantileSummaries ignores the construction parameter compressThresHold, - // which may cause QuantileSummaries to occupy unbounded memory. We have to hack around here - // to make sure QuantileSummaries doesn't occupy infinite memory. - // TODO: Figure out why QuantileSummaries ignores construction parameter compressThresHold - if (summaries.sampled.length >= compressThresHoldBufferLength) compress() } /** In-place merges in another PercentileDigest. */ @@ -280,7 +260,6 @@ object ApproximatePercentile { private final def compress(): Unit = { summaries = summaries.compress() - isCompressed = true } } @@ -335,8 +314,8 @@ object ApproximatePercentile { sampled(i) = Stats(value, g, delta) i += 1 } - val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count) - new PercentileDigest(summary, isCompressed = true) + val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count, true) + new PercentileDigest(summary) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index b013add9c9778..3190e511e2cb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -40,12 +40,14 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats * See the G-K article for more details. * @param count the count of all the elements *inserted in the sampled buffer* * (excluding the head buffer) + * @param compressed whether the statistics have been compressed */ class QuantileSummaries( val compressThreshold: Int, val relativeError: Double, val sampled: Array[Stats] = Array.empty, - val count: Long = 0L) extends Serializable { + val count: Long = 0L, + var compressed: Boolean = false) extends Serializable { // a buffer of latest samples seen so far private val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty @@ -60,6 +62,7 @@ class QuantileSummaries( */ def insert(x: Double): QuantileSummaries = { headSampled += x + compressed = false if (headSampled.size >= defaultHeadSize) { val result = this.withHeadBufferInserted if (result.sampled.length >= compressThreshold) { @@ -135,11 +138,11 @@ class QuantileSummaries( assert(inserted.count == count + headSampled.size) val compressed = compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count) - new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count) + new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count, true) } private def shallowCopy: QuantileSummaries = { - new QuantileSummaries(compressThreshold, relativeError, sampled, count) + new QuantileSummaries(compressThreshold, relativeError, sampled, count, compressed) } /** @@ -163,7 +166,7 @@ class QuantileSummaries( val res = (sampled ++ other.sampled).sortBy(_.value) val comp = compressImmut(res, mergeThreshold = 2 * relativeError * count) new QuantileSummaries( - other.compressThreshold, other.relativeError, comp, other.count + count) + other.compressThreshold, other.relativeError, comp, other.count + count, true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index 137c5bea2abb9..d635912cf7205 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -279,4 +280,16 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(query, expected) } } + + test("SPARK-24013: unneeded compress can cause performance issues with sorted input") { + val buffer = new PercentileDigest(1.0D / ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY) + var compressCounts = 0 + (1 to 10000000).foreach { i => + buffer.add(i) + if (buffer.isCompressed) compressCounts += 1 + } + assert(compressCounts > 0) + buffer.quantileSummaries + assert(buffer.isCompressed) + } } From 8bd27025b7cf0b44726b6f4020d294ef14dbbb7e Mon Sep 17 00:00:00 2001 From: Ala Luszczak Date: Wed, 2 May 2018 12:43:19 -0700 Subject: [PATCH 721/774] [SPARK-24133][SQL] Check for integer overflows when resizing WritableColumnVectors ## What changes were proposed in this pull request? `ColumnVector`s store string data in one big byte array. Since the array size is capped at just under Integer.MAX_VALUE, a single `ColumnVector` cannot store more than 2GB of string data. But since the Parquet files commonly contain large blobs stored as strings, and `ColumnVector`s by default carry 4096 values, it's entirely possible to go past that limit. In such cases a negative capacity is requested from `WritableColumnVector.reserve()`. The call succeeds (requested capacity is smaller than already allocated capacity), and consequently `java.lang.ArrayIndexOutOfBoundsException` is thrown when the reader actually attempts to put the data into the array. This change introduces a simple check for integer overflow to `WritableColumnVector.reserve()` which should help catch the error earlier and provide more informative exception. Additionally, the error message in `WritableColumnVector.throwUnsupportedException()` was corrected, as it previously encouraged users to increase rather than reduce the batch size. ## How was this patch tested? New units tests were added. Author: Ala Luszczak Closes #21206 from ala/overflow-reserve. --- .../vectorized/WritableColumnVector.java | 21 ++++++++++++------- .../vectorized/ColumnarBatchSuite.scala | 7 +++++++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5275e4a91eac0..b0e119d658cb4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -81,7 +81,9 @@ public void close() { } public void reserve(int requiredCapacity) { - if (requiredCapacity > capacity) { + if (requiredCapacity < 0) { + throwUnsupportedException(requiredCapacity, null); + } else if (requiredCapacity > capacity) { int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); if (requiredCapacity <= newCapacity) { try { @@ -96,13 +98,16 @@ public void reserve(int requiredCapacity) { } private void throwUnsupportedException(int requiredCapacity, Throwable cause) { - String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + - "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + - "vectorized reader, or increase the vectorized reader batch size. For parquet file " + - "format, refer to " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + " and " + - SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + "; for orc file format, refer to " + - SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + " and " + - SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + "."; + String message = "Cannot reserve additional contiguous bytes in the vectorized reader (" + + (requiredCapacity >= 0 ? "requested " + requiredCapacity + " bytes" : "integer overflow") + + "). As a workaround, you can reduce the vectorized reader batch size, or disable the " + + "vectorized reader. For parquet file format, refer to " + + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + + " (default " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + + ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for orc file format, " + + "refer to " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + + " (default " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + + ") and " + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + "."; throw new RuntimeException(message, cause); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 772f687526008..f57f07b498261 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1333,4 +1333,11 @@ class ColumnarBatchSuite extends SparkFunSuite { column.close() } + + testVector("WritableColumnVector.reserve(): requested capacity is negative", 1024, ByteType) { + column => + val ex = intercept[RuntimeException] { column.reserve(-1) } + assert(ex.getMessage.contains( + "Cannot reserve additional contiguous bytes in the vectorized reader (integer overflow)")) + } } From 504c9cfd21ef45a13d9428fef3b197dcbf6786cd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 2 May 2018 13:49:15 -0700 Subject: [PATCH 722/774] [SPARK-24123][SQL] Fix precision issues in monthsBetween with more than 8 digits ## What changes were proposed in this pull request? SPARK-23902 introduced the ability to retrieve more than 8 digits in `monthsBetween`. Unfortunately, current implementation can cause precision loss in such a case. This was causing also a flaky UT. This PR mirrors Hive's implementation in order to avoid precision loss also when more than 8 digits are returned. ## How was this patch tested? running 10000000 times the flaky UT Author: Marco Gaido Closes #21196 from mgaido91/SPARK-24123. --- .../spark/sql/catalyst/util/DateTimeUtils.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 4b00a61c6cf91..d2fe15c48c6dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -888,14 +888,19 @@ object DateTimeUtils { val months1 = year1 * 12 + monthInYear1 val months2 = year2 * 12 + monthInYear2 + val monthDiff = (months1 - months2).toDouble + if (dayInMonth1 == dayInMonth2 || ((daysToMonthEnd1 == 0) && (daysToMonthEnd2 == 0))) { - return (months1 - months2).toDouble + return monthDiff } - // milliseconds is enough for 8 digits precision on the right side - val timeInDay1 = millis1 - daysToMillis(date1, timeZone) - val timeInDay2 = millis2 - daysToMillis(date2, timeZone) - val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY - val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 + // using milliseconds can cause precision loss with more than 8 digits + // we follow Hive's implementation which uses seconds + val secondsInDay1 = (millis1 - daysToMillis(date1, timeZone)) / 1000L + val secondsInDay2 = (millis2 - daysToMillis(date2, timeZone)) / 1000L + val secondsDiff = (dayInMonth1 - dayInMonth2) * SECONDS_PER_DAY + secondsInDay1 - secondsInDay2 + // 2678400D is the number of seconds in 31 days + // every month is considered to be 31 days long in this function + val diff = monthDiff + secondsDiff / 2678400D if (roundOff) { // rounding to 8 digits math.round(diff * 1e8) / 1e8 From 5be8aab14468e55b1049a0c83f02dcec0651162f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 2 May 2018 13:53:10 -0700 Subject: [PATCH 723/774] [SPARK-23923][SQL] Add cardinality function ## What changes were proposed in this pull request? The PR adds the SQL function `cardinality`. The behavior of the function is based on Presto's one. The function returns the length of the array or map stored in the column as `int` while the Presto version returns the value as `BigInt` (`long` in Spark). The discussions regarding the difference of return type are [here](https://github.com/apache/spark/pull/21031#issuecomment-381284638) and [there](https://github.com/apache/spark/pull/21031#discussion_r181622107). ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21031 from kiszk/SPARK-23923. --- .../spark/sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6bc7b4e4f7cb3..3ffbc9c8069fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -409,6 +409,7 @@ object FunctionRegistry { expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), + expression[Size]("cardinality"), expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 470a1c8e331ba..a5163accb1bb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -341,6 +341,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(-1)) ) + + checkAnswer( + df.selectExpr("cardinality(a)"), + Seq(Row(2L), Row(0L), Row(3L), Row(-1L)) + ) } test("map size function") { From e4c91c089a701117af82f585d14d8afc5245fc64 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 2 May 2018 16:12:21 -0700 Subject: [PATCH 724/774] [SPARK-24111][SQL] Add the TPCDS v2.7 (latest) queries in TPCDSQueryBenchmark ## What changes were proposed in this pull request? This pr added the TPCDS v2.7 (latest) queries in `TPCDSQueryBenchmark`. These query files have been added in `SPARK-23167`. ## How was this patch tested? Manually checked. Author: Takeshi Yamamuro Closes #21177 from maropu/AddTpcdsV2_7InBenchmark. --- .../benchmark/TPCDSQueryBenchmark.scala | 52 +++++++++++++------ 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index 69247d7f4e9aa..abe61a2c2b9c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -58,10 +58,13 @@ object TPCDSQueryBenchmark extends Logging { }.toMap } - def tpcdsAll(dataLocation: String, queries: Seq[String]): Unit = { - val tableSizes = setupTables(dataLocation) + def runTpcdsQueries( + queryLocation: String, + queries: Seq[String], + tableSizes: Map[String, Long], + nameSuffix: String = ""): Unit = { queries.foreach { name => - val queryString = resourceToString(s"tpcds/$name.sql", + val queryString = resourceToString(s"$queryLocation/$name.sql", classLoader = Thread.currentThread().getContextClassLoader) // This is an indirect hack to estimate the size of each query's input by traversing the @@ -78,7 +81,7 @@ object TPCDSQueryBenchmark extends Logging { } val numRows = queryRelations.map(tableSizes.getOrElse(_, 0L)).sum val benchmark = new Benchmark(s"TPCDS Snappy", numRows, 5) - benchmark.addCase(name) { i => + benchmark.addCase(s"$name$nameSuffix") { _ => spark.sql(queryString).collect() } logInfo(s"\n\n===== TPCDS QUERY BENCHMARK OUTPUT FOR $name =====\n") @@ -87,10 +90,20 @@ object TPCDSQueryBenchmark extends Logging { } } + def filterQueries( + origQueries: Seq[String], + args: TPCDSQueryBenchmarkArguments): Seq[String] = { + if (args.queryFilter.nonEmpty) { + origQueries.filter(args.queryFilter.contains) + } else { + origQueries + } + } + def main(args: Array[String]): Unit = { val benchmarkArgs = new TPCDSQueryBenchmarkArguments(args) - // List of all TPC-DS queries + // List of all TPC-DS v1.4 queries val tpcdsQueries = Seq( "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", @@ -103,20 +116,25 @@ object TPCDSQueryBenchmark extends Logging { "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") + // This list only includes TPC-DS v2.7 queries that are different from v1.4 ones + val tpcdsQueriesV2_7 = Seq( + "q5a", "q6", "q10a", "q11", "q12", "q14", "q14a", "q18a", + "q20", "q22", "q22a", "q24", "q27a", "q34", "q35", "q35a", "q36a", "q47", "q49", + "q51a", "q57", "q64", "q67a", "q70a", "q72", "q74", "q75", "q77a", "q78", + "q80a", "q86a", "q98") + // If `--query-filter` defined, filters the queries that this option selects - val queriesToRun = if (benchmarkArgs.queryFilter.nonEmpty) { - val queries = tpcdsQueries.filter { case queryName => - benchmarkArgs.queryFilter.contains(queryName) - } - if (queries.isEmpty) { - throw new RuntimeException( - s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}") - } - queries - } else { - tpcdsQueries + val queriesV1_4ToRun = filterQueries(tpcdsQueries, benchmarkArgs) + val queriesV2_7ToRun = filterQueries(tpcdsQueriesV2_7, benchmarkArgs) + + if ((queriesV1_4ToRun ++ queriesV2_7ToRun).isEmpty) { + throw new RuntimeException( + s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}") } - tpcdsAll(benchmarkArgs.dataLocation, queries = queriesToRun) + val tableSizes = setupTables(benchmarkArgs.dataLocation) + runTpcdsQueries(queryLocation = "tpcds", queries = queriesV1_4ToRun, tableSizes) + runTpcdsQueries(queryLocation = "tpcds-v2.7.0", queries = queriesV2_7ToRun, tableSizes, + nameSuffix = "-v2.7") } } From bf4352ca6c96dfab16b286c54720685e32b216f1 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 3 May 2018 09:28:14 +0800 Subject: [PATCH 725/774] [SPARK-24110][THRIFT-SERVER] Avoid UGI.loginUserFromKeytab in STS ## What changes were proposed in this pull request? Spark ThriftServer will call UGI.loginUserFromKeytab twice in initialization. This is unnecessary and will cause various potential problems, like Hadoop IPC failure after 7 days, or RM failover issue and so on. So here we need to remove all the unnecessary login logics and make sure UGI in the context never be created again. Note this is actually a HS2 issue, If later on we upgrade supported Hive version, the issue may already be fixed in Hive side. ## How was this patch tested? Local verification in secure cluster. Author: jerryshao Closes #21178 from jerryshao/SPARK-24110. --- .../hive/service/auth/HiveAuthFactory.java | 62 +++++++++++++++++-- .../thriftserver/SparkSQLCLIService.scala | 20 +++++- 2 files changed, 75 insertions(+), 7 deletions(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java index c5ade65283045..10000f12ab329 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java @@ -18,6 +18,8 @@ package org.apache.hive.service.auth; import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.net.UnknownHostException; import java.util.ArrayList; @@ -26,6 +28,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import javax.net.ssl.SSLServerSocket; import javax.security.auth.login.LoginException; @@ -92,7 +95,30 @@ public String getAuthName() { public static final String HS2_PROXY_USER = "hive.server2.proxy.user"; public static final String HS2_CLIENT_TOKEN = "hiveserver2ClientToken"; - public HiveAuthFactory(HiveConf conf) throws TTransportException { + private static Field keytabFile = null; + private static Method getKeytab = null; + static { + Class clz = UserGroupInformation.class; + try { + keytabFile = clz.getDeclaredField("keytabFile"); + keytabFile.setAccessible(true); + } catch (NoSuchFieldException nfe) { + LOG.debug("Cannot find private field \"keytabFile\" in class: " + + UserGroupInformation.class.getCanonicalName(), nfe); + keytabFile = null; + } + + try { + getKeytab = clz.getDeclaredMethod("getKeytab"); + getKeytab.setAccessible(true); + } catch(NoSuchMethodException nme) { + LOG.debug("Cannot find private method \"getKeytab\" in class:" + + UserGroupInformation.class.getCanonicalName(), nme); + getKeytab = null; + } + } + + public HiveAuthFactory(HiveConf conf) throws TTransportException, IOException { this.conf = conf; transportMode = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_TRANSPORT_MODE); authTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION); @@ -107,9 +133,16 @@ public HiveAuthFactory(HiveConf conf) throws TTransportException { authTypeStr = AuthTypes.NONE.getAuthName(); } if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) { - saslServer = ShimLoader.getHadoopThriftAuthBridge() - .createServer(conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB), - conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL)); + String principal = conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL); + String keytab = conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB); + if (needUgiLogin(UserGroupInformation.getCurrentUser(), + SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keytab)) { + saslServer = ShimLoader.getHadoopThriftAuthBridge().createServer(principal, keytab); + } else { + // Using the default constructor to avoid unnecessary UGI login. + saslServer = new HadoopThriftAuthBridge.Server(); + } + // start delegation token manager try { // rawStore is only necessary for DBTokenStore @@ -362,4 +395,25 @@ public static void verifyProxyAccess(String realUser, String proxyUser, String i } } + public static boolean needUgiLogin(UserGroupInformation ugi, String principal, String keytab) { + return null == ugi || !ugi.hasKerberosCredentials() || !ugi.getUserName().equals(principal) || + !Objects.equals(keytab, getKeytabFromUgi()); + } + + private static String getKeytabFromUgi() { + synchronized (UserGroupInformation.class) { + try { + if (keytabFile != null) { + return (String) keytabFile.get(null); + } else if (getKeytab != null) { + return (String) getKeytab.invoke(UserGroupInformation.getCurrentUser()); + } else { + return null; + } + } catch (Exception e) { + LOG.debug("Fail to get keytabFile path via reflection", e); + return null; + } + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index ad1f5eb9ca3a7..1335e16e35882 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -27,7 +27,7 @@ import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.shims.Utils -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{SecurityUtil, UserGroupInformation} import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory @@ -52,8 +52,22 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC if (UserGroupInformation.isSecurityEnabled) { try { - HiveAuthFactory.loginFromKeytab(hiveConf) - sparkServiceUGI = Utils.getUGI() + val principal = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL) + val keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB) + if (principal.isEmpty || keyTabFile.isEmpty) { + throw new IOException( + "HiveServer2 Kerberos principal or keytab is not correctly configured") + } + + val originalUgi = UserGroupInformation.getCurrentUser + sparkServiceUGI = if (HiveAuthFactory.needUgiLogin(originalUgi, + SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keyTabFile)) { + HiveAuthFactory.loginFromKeytab(hiveConf) + Utils.getUGI() + } else { + originalUgi + } + setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => From c9bfd1c6f8d16890ea1e5bc2bcb654a3afb32591 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 3 May 2018 15:15:05 +0800 Subject: [PATCH 726/774] [SPARK-23489][SQL][TEST] HiveExternalCatalogVersionsSuite should verify the downloaded file ## What changes were proposed in this pull request? Although [SPARK-22654](https://issues.apache.org/jira/browse/SPARK-22654) made `HiveExternalCatalogVersionsSuite` download from Apache mirrors three times, it has been flaky because it didn't verify the downloaded file. Some Apache mirrors terminate the downloading abnormally, the *corrupted* file shows the following errors. ``` gzip: stdin: not in gzip format tar: Child returned status 1 tar: Error is not recoverable: exiting now 22:46:32.700 WARN org.apache.spark.sql.hive.HiveExternalCatalogVersionsSuite: ===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.hive.HiveExternalCatalogVersionsSuite, thread names: Keep-Alive-Timer ===== *** RUN ABORTED *** java.io.IOException: Cannot run program "./bin/spark-submit" (in directory "/tmp/test-spark/spark-2.2.0"): error=2, No such file or directory ``` This has been reported weirdly in two ways. For example, the above case is reported as Case 2 `no failures`. - Case 1. [Test Result (1 failure / +1)](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/4389/) - Case 2. [Test Result (no failures)](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.6/4811/) This PR aims to make `HiveExternalCatalogVersionsSuite` more robust by verifying the downloaded `tgz` file by extracting and checking the existence of `bin/spark-submit`. If it turns out that the file is empty or corrupted, `HiveExternalCatalogVersionsSuite` will do retry logic like the download failure. ## How was this patch tested? Pass the Jenkins. Author: Dongjoon Hyun Closes #21210 from dongjoon-hyun/SPARK-23489. --- .../HiveExternalCatalogVersionsSuite.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 6ca58e68d31eb..ea86ab9772bc7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -67,7 +67,21 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { logInfo(s"Downloading Spark $version from $url") try { getFileFromUrl(url, path, filename) - return + val downloaded = new File(sparkTestingDir, filename).getCanonicalPath + val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath + + Seq("mkdir", targetDir).! + val exitCode = Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! + Seq("rm", downloaded).! + + // For a corrupted file, `tar` returns non-zero values. However, we also need to check + // the extracted file because `tar` returns 0 for empty file. + val sparkSubmit = new File(sparkTestingDir, s"spark-$version/bin/spark-submit") + if (exitCode == 0 && sparkSubmit.exists()) { + return + } else { + Seq("rm", "-rf", targetDir).! + } } catch { case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) } @@ -75,20 +89,6 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { fail(s"Unable to download Spark $version") } - - private def downloadSpark(version: String): Unit = { - tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) - - val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath - val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath - - Seq("mkdir", targetDir).! - - Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! - - Seq("rm", downloaded).! - } - private def genDataDir(name: String): String = { new File(tmpDataDir, name).getCanonicalPath } @@ -161,7 +161,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => val sparkHome = new File(sparkTestingDir, s"spark-$version") if (!sparkHome.exists()) { - downloadSpark(version) + tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) } val args = Seq( From 417ad92502e714da71552f64d0e1257d2fd5d3d0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 19:27:01 +0800 Subject: [PATCH 727/774] [SPARK-23715][SQL] the input of to/from_utc_timestamp can not have timezone ## What changes were proposed in this pull request? `from_utc_timestamp` assumes its input is in UTC timezone and shifts it to the specified timezone. When the timestamp contains timezone(e.g. `2018-03-13T06:18:23+00:00`), Spark breaks the semantic and respect the timezone in the string. This is not what user expects and the result is different from Hive/Impala. `to_utc_timestamp` has the same problem. More details please refer to the JIRA ticket. This PR fixes this by returning null if the input timestamp contains timezone. ## How was this patch tested? new tests Author: Wenchen Fan Closes #21169 from cloud-fan/from_utc_timezone. --- docs/sql-programming-guide.md | 13 +- .../sql/catalyst/analysis/TypeCoercion.scala | 30 +++- .../expressions/datetimeExpressions.scala | 42 ++++++ .../sql/catalyst/util/DateTimeUtils.scala | 22 ++- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../catalyst/analysis/TypeCoercionSuite.scala | 12 +- .../resources/sql-tests/inputs/datetime.sql | 33 +++++ .../sql-tests/results/datetime.sql.out | 135 +++++++++++++++++- .../apache/spark/sql/DateFunctionsSuite.scala | 8 ++ 9 files changed, 283 insertions(+), 19 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 836ce990205a9..075b953a0898e 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1805,12 +1805,13 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. + - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. + - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. + - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 25bad28a2a209..b2817b0538a7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -59,7 +59,7 @@ object TypeCoercion { IfCoercion :: StackCoercion :: Division :: - ImplicitTypeCasts :: + new ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -776,12 +776,33 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - object ImplicitTypeCasts extends TypeCoercionRule { + class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + + private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) + override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e + // Special rules for `from/to_utc_timestamp`. These 2 functions assume the input timestamp + // string is in a specific timezone, so the string itself should not contain timezone. + // TODO: We should move the type coercion logic to expressions instead of a central + // place to put all the rules. + case e: FromUTCTimestamp if e.left.dataType == StringType => + if (rejectTzInString) { + e.copy(left = StringToTimestampWithoutTimezone(e.left)) + } else { + e.copy(left = Cast(e.left, TimestampType)) + } + + case e: ToUTCTimestamp if e.left.dataType == StringType => + if (rejectTzInString) { + e.copy(left = StringToTimestampWithoutTimezone(e.left)) + } else { + e.copy(left = Cast(e.left, TimestampType)) + } + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonType(left.dataType, right.dataType).map { commonType => if (b.inputType.acceptsType(commonType)) { @@ -798,7 +819,7 @@ object TypeCoercion { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. - implicitCast(in, expected).getOrElse(in) + ImplicitTypeCasts.implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) @@ -814,6 +835,9 @@ object TypeCoercion { } e.withNewChildren(children) } + } + + object ImplicitTypeCasts { /** * Given an expected data type, try to cast the expression and return the cast expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index d882d06cfd625..76aa61415a11f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1016,6 +1016,48 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S } } +/** + * A special expression used to convert the string input of `to/from_utc_timestamp` to timestamp, + * which requires the timestamp string to not have timezone information, otherwise null is returned. + */ +case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def dataType: DataType = TimestampType + override def nullable: Boolean = true + override def toString: String = child.toString + override def sql: String = child.sql + + override def nullSafeEval(input: Any): Any = { + DateTimeUtils.stringToTimestamp( + input.asInstanceOf[UTF8String], timeZone, rejectTzInString = true).orNull + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tz = ctx.addReferenceObj("timeZone", timeZone) + val longOpt = ctx.freshName("longOpt") + val eval = child.genCode(ctx) + val code = s""" + |${eval.code} + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; + |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; + |if (!${eval.isNull}) { + | scala.Option $longOpt = $dtu.stringToTimestamp(${eval.value}, $tz, true); + | if ($longOpt.isDefined()) { + | ${ev.value} = ((Long) $longOpt.get()).longValue(); + | ${ev.isNull} = false; + | } + |} + """.stripMargin + ev.copy(code = code) + } +} + /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index d2fe15c48c6dd..e646da0659e85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -296,10 +296,28 @@ object DateTimeUtils { * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` */ def stringToTimestamp(s: UTF8String): Option[SQLTimestamp] = { - stringToTimestamp(s, defaultTimeZone()) + stringToTimestamp(s, defaultTimeZone(), rejectTzInString = false) } def stringToTimestamp(s: UTF8String, timeZone: TimeZone): Option[SQLTimestamp] = { + stringToTimestamp(s, timeZone, rejectTzInString = false) + } + + /** + * Converts a timestamp string to microseconds from the unix epoch, w.r.t. the given timezone. + * Returns None if the input string is not a valid timestamp format. + * + * @param s the input timestamp string. + * @param timeZone the timezone of the timestamp string, will be ignored if the timestamp string + * already contains timezone information and `forceTimezone` is false. + * @param rejectTzInString if true, rejects timezone in the input string, i.e., if the + * timestamp string contains timezone, like `2000-10-10 00:00:00+00:00`, + * return None. + */ + def stringToTimestamp( + s: UTF8String, + timeZone: TimeZone, + rejectTzInString: Boolean): Option[SQLTimestamp] = { if (s == null) { return None } @@ -417,6 +435,8 @@ object DateTimeUtils { return None } + if (tz.isDefined && rejectTzInString) return None + val c = if (tz.isEmpty) { Calendar.getInstance(timeZone) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3729bd5293eca..3942240c442b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1208,6 +1208,13 @@ object SQLConf { .stringConf .createWithDefault("") + val REJECT_TIMEZONE_IN_STRING = buildConf("spark.sql.function.rejectTimezoneInString") + .internal() + .doc("If true, `to_utc_timestamp` and `from_utc_timestamp` return null if the input string " + + "contains a timezone part, e.g. `2000-10-10 00:00:00+00:00`.") + .booleanConf + .createWithDefault(true) + object PartitionOverwriteMode extends Enumeration { val STATIC, DYNAMIC = Value } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 1cc431aaf0a60..0acd3b490447d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -524,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -536,11 +536,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } @@ -823,7 +823,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) @@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 547c2bef02b24..4950a4b7a4e5a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -27,3 +27,36 @@ select current_date = current_date(), current_timestamp = current_timestamp(), a select a, b from ttf2 order by a, current_date; select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15'); + +select from_utc_timestamp('2015-07-24 00:00:00', 'PST'); + +select from_utc_timestamp('2015-01-24 00:00:00', 'PST'); + +select from_utc_timestamp(null, 'PST'); + +select from_utc_timestamp('2015-07-24 00:00:00', null); + +select from_utc_timestamp(null, null); + +select from_utc_timestamp(cast(0 as timestamp), 'PST'); + +select from_utc_timestamp(cast('2015-01-24' as date), 'PST'); + +select to_utc_timestamp('2015-07-24 00:00:00', 'PST'); + +select to_utc_timestamp('2015-01-24 00:00:00', 'PST'); + +select to_utc_timestamp(null, 'PST'); + +select to_utc_timestamp('2015-07-24 00:00:00', null); + +select to_utc_timestamp(null, null); + +select to_utc_timestamp(cast(0 as timestamp), 'PST'); + +select to_utc_timestamp(cast('2015-01-24' as date), 'PST'); + +-- SPARK-23715: the input of to/from_utc_timestamp can not have timezone +select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); + +select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 4e1cfa6e48c1c..9eede305dbdcc 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 26 -- !query 0 @@ -82,9 +82,138 @@ struct 1 2 2 3 + -- !query 9 select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') --- !query 3 schema +-- !query 9 schema struct --- !query 3 output +-- !query 9 output 5 3 5 NULL 4 + + +-- !query 10 +select from_utc_timestamp('2015-07-24 00:00:00', 'PST') +-- !query 10 schema +struct +-- !query 10 output +2015-07-23 17:00:00 + + +-- !query 11 +select from_utc_timestamp('2015-01-24 00:00:00', 'PST') +-- !query 11 schema +struct +-- !query 11 output +2015-01-23 16:00:00 + + +-- !query 12 +select from_utc_timestamp(null, 'PST') +-- !query 12 schema +struct +-- !query 12 output +NULL + + +-- !query 13 +select from_utc_timestamp('2015-07-24 00:00:00', null) +-- !query 13 schema +struct +-- !query 13 output +NULL + + +-- !query 14 +select from_utc_timestamp(null, null) +-- !query 14 schema +struct +-- !query 14 output +NULL + + +-- !query 15 +select from_utc_timestamp(cast(0 as timestamp), 'PST') +-- !query 15 schema +struct +-- !query 15 output +1969-12-31 08:00:00 + + +-- !query 16 +select from_utc_timestamp(cast('2015-01-24' as date), 'PST') +-- !query 16 schema +struct +-- !query 16 output +2015-01-23 16:00:00 + + +-- !query 17 +select to_utc_timestamp('2015-07-24 00:00:00', 'PST') +-- !query 17 schema +struct +-- !query 17 output +2015-07-24 07:00:00 + + +-- !query 18 +select to_utc_timestamp('2015-01-24 00:00:00', 'PST') +-- !query 18 schema +struct +-- !query 18 output +2015-01-24 08:00:00 + + +-- !query 19 +select to_utc_timestamp(null, 'PST') +-- !query 19 schema +struct +-- !query 19 output +NULL + + +-- !query 20 +select to_utc_timestamp('2015-07-24 00:00:00', null) +-- !query 20 schema +struct +-- !query 20 output +NULL + + +-- !query 21 +select to_utc_timestamp(null, null) +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +select to_utc_timestamp(cast(0 as timestamp), 'PST') +-- !query 22 schema +struct +-- !query 22 output +1970-01-01 00:00:00 + + +-- !query 23 +select to_utc_timestamp(cast('2015-01-24' as date), 'PST') +-- !query 23 schema +struct +-- !query 23 output +2015-01-24 08:00:00 + + +-- !query 24 +select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') +-- !query 24 schema +struct +-- !query 24 output +NULL + + +-- !query 25 +select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') +-- !query 25 schema +struct +-- !query 25 output +NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index f712baa7a9134..237412aa692e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -23,6 +23,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval @@ -696,4 +697,11 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-25 07:00:00")))) } + test("SPARK-23715: to/from_utc_timestamp can retain the previous behavior") { + withSQLConf(SQLConf.REJECT_TIMEZONE_IN_STRING.key -> "false") { + checkAnswer( + sql("SELECT from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')"), + Row(Timestamp.valueOf("2000-10-09 18:00:00"))) + } + } } From 991b526992bcf1dc1268578b650916569b12f583 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 19:56:30 +0800 Subject: [PATCH 728/774] [SPARK-24166][SQL] InMemoryTableScanExec should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from https://github.com/apache/spark/pull/21190 , to make it easier to backport. `InMemoryTableScanExec#createAndDecompressColumn` is executed inside `rdd.map`, we can't access `conf.offHeapColumnVectorEnabled` there. ## How was this patch tested? it's tested in #21190 Author: Wenchen Fan Closes #21223 from cloud-fan/minor1. --- .../execution/columnar/InMemoryTableScanExec.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index ea315fb71617c..0b4dd76c7d860 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -78,10 +78,12 @@ case class InMemoryTableScanExec( private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i))) - private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = { + private def createAndDecompressColumn( + cachedColumnarBatch: CachedBatch, + offHeapColumnVectorEnabled: Boolean): ColumnarBatch = { val rowCount = cachedColumnarBatch.numRows val taskContext = Option(TaskContext.get()) - val columnVectors = if (!conf.offHeapColumnVectorEnabled || taskContext.isEmpty) { + val columnVectors = if (!offHeapColumnVectorEnabled || taskContext.isEmpty) { OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) } else { OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) @@ -101,10 +103,13 @@ case class InMemoryTableScanExec( private lazy val inputRDD: RDD[InternalRow] = { val buffers = filteredCachedBatches() + val offHeapColumnVectorEnabled = conf.offHeapColumnVectorEnabled if (supportsBatch) { // HACK ALERT: This is actually an RDD[ColumnarBatch]. // We're taking advantage of Scala's type erasure here to pass these batches along. - buffers.map(createAndDecompressColumn).asInstanceOf[RDD[InternalRow]] + buffers + .map(createAndDecompressColumn(_, offHeapColumnVectorEnabled)) + .asInstanceOf[RDD[InternalRow]] } else { val numOutputRows = longMetric("numOutputRows") From 96a50016bb0fb1cc57823a6706bff2467d671efd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 23:36:09 +0800 Subject: [PATCH 729/774] [SPARK-24169][SQL] JsonToStructs should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from #21190 , to make it easier to backport. `JsonToStructs` can be serialized to executors and evaluate, we should not call `SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)` in the body. ## How was this patch tested? tested in #21190 Author: Wenchen Fan Closes #21226 from cloud-fan/minor4. --- .../catalyst/analysis/FunctionRegistry.scala | 4 +- .../expressions/jsonExpressions.scala | 16 ++-- .../expressions/JsonExpressionsSuite.scala | 76 +++++++++---------- .../org/apache/spark/sql/functions.scala | 2 +- .../sql-tests/results/json-functions.sql.out | 4 +- 5 files changed, 54 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3ffbc9c8069fd..51bb6b0abe408 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -534,7 +534,9 @@ object FunctionRegistry { // Otherwise, find a constructor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - val validParametersCount = constructors.map(_.getParameterCount).distinct.sorted + val validParametersCount = constructors + .filter(_.getParameterTypes.forall(_ == classOf[Expression])) + .map(_.getParameterCount).distinct.sorted val expectedNumberOfParameters = if (validParametersCount.length == 1) { validParametersCount.head.toString } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index fdd672c416a03..34161f0f03f4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -514,11 +514,10 @@ case class JsonToStructs( schema: DataType, options: Map[String, String], child: Expression, - timeZoneId: Option[String] = None) + timeZoneId: Option[String], + forceNullableSchema: Boolean) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { - val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) - // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder // can generate incorrect files if values are missing in columns declared as non-nullable. @@ -532,14 +531,21 @@ case class JsonToStructs( schema = JsonExprUtils.validateSchemaLiteral(schema), options = Map.empty[String, String], child = child, - timeZoneId = None) + timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) def this(child: Expression, schema: Expression, options: Expression) = this( schema = JsonExprUtils.validateSchemaLiteral(schema), options = JsonExprUtils.convertToMapData(options), child = child, - timeZoneId = None) + timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + + // Used in `org.apache.spark.sql.functions` + def this(schema: DataType, options: Map[String, String], child: Expression) = + this(schema, options, child, timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { case _: StructType | ArrayType(_: StructType, _) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 7812319756eae..00e97637eee7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -392,7 +392,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a": 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), InternalRow(1) ) } @@ -401,13 +401,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a" 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), null ) // Other modes should still return `null`. checkEvaluation( - JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId), + JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId, true), null ) } @@ -416,62 +416,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val input = """[{"a": 1}, {"a": 2}]""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: InternalRow(2) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=object, schema=array, output=array of single row") { val input = """{"a": 1}""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty array, schema=array, output=empty array") { val input = "[ ]" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty object, schema=array, output=array of single row with null") { val input = "{ }" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(null) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=array of single object, schema=struct, output=single row") { val input = """[{"a": 1}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(1) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=array, schema=struct, output=null") { val input = """[{"a": 1}, {"a": 2}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty array, schema=struct, output=null") { val input = """[]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty object, schema=struct, output=single row with null") { val input = """{ }""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(null) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId, true), null ) } @@ -479,7 +479,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-20549: from_json bad UTF-8") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), + JsonToStructs(schema, Map.empty, Literal(badJson), gmtId, true), null) } @@ -491,14 +491,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 123) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId, true), InternalRow(c.getTimeInMillis * 1000L) ) // The result doesn't change because the json string includes timezone string ("Z" here), // which means the string represents the timestamp string in the timezone regardless of // the timeZoneId parameter. checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")), + JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST"), true), InternalRow(c.getTimeInMillis * 1000L) ) @@ -512,7 +512,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with schema, Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), Literal(jsonData2), - Option(tz.getID)), + Option(tz.getID), + true), InternalRow(c.getTimeInMillis * 1000L) ) checkEvaluation( @@ -521,7 +522,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> tz.getID), Literal(jsonData2), - gmtId), + gmtId, + true), InternalRow(c.getTimeInMillis * 1000L) ) } @@ -530,7 +532,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-19543: from_json empty input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId, true), null ) } @@ -685,27 +687,23 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("from_json missing fields") { for (forceJsonNullableSchema <- Seq(false, true)) { - withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) { - val input = - """{ - | "a": 1, - | "c": "foo" - |} - |""".stripMargin - val jsonSchema = new StructType() - .add("a", LongType, nullable = false) - .add("b", StringType, nullable = false) - .add("c", StringType, nullable = false) - val output = InternalRow(1L, null, UTF8String.fromString("foo")) - checkEvaluation( - JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId), - output - ) - val schema = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) - .dataType - val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema - assert(schemaToCompare == schema) - } + val input = + """{ + | "a": 1, + | "c": "foo" + |} + |""".stripMargin + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + val expr = JsonToStructs( + jsonSchema, Map.empty, Literal.create(input, StringType), gmtId, forceJsonNullableSchema) + checkEvaluation(expr, output) + val schema = expr.dataType + val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema + assert(schemaToCompare == schema) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 25afaacc38d6f..d2e22fa355514 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3179,7 +3179,7 @@ object functions { * @since 2.2.0 */ def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { - JsonToStructs(schema, options, e.expr) + new JsonToStructs(schema, options, e.expr) } /** diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 581dddc89d0bb..14a69128ffb41 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -129,7 +129,7 @@ select to_json() struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function to_json. Expected: one of 1, 2 and 3; Found: 0; line 1 pos 7 +Invalid number of arguments for function to_json. Expected: one of 1 and 2; Found: 0; line 1 pos 7 -- !query 13 @@ -225,7 +225,7 @@ select from_json() struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function from_json. Expected: one of 2, 3 and 4; Found: 0; line 1 pos 7 +Invalid number of arguments for function from_json. Expected: one of 2 and 3; Found: 0; line 1 pos 7 -- !query 22 From 94641fe6cc68e5977dd8663b8f232a287a783acb Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 3 May 2018 10:59:18 -0500 Subject: [PATCH 730/774] [SPARK-23433][CORE] Late zombie task completions update all tasksets Fetch failure lead to multiple tasksets which are active for a given stage. While there is only one "active" version of the taskset, the earlier attempts can still have running tasks, which can complete successfully. So a task completion needs to update every taskset so that it knows the partition is completed. That way the final active taskset does not try to submit another task for the same partition, and so that it knows when it is completed and when it should be marked as a "zombie". Added a regression test. Author: Imran Rashid Closes #21131 from squito/SPARK-23433. --- .../spark/scheduler/TaskSchedulerImpl.scala | 14 +++ .../spark/scheduler/TaskSetManager.scala | 20 +++- .../scheduler/TaskSchedulerImplSuite.scala | 104 ++++++++++++++++++ 3 files changed, 137 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0c11806b3981b..8e97b3da33820 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -689,6 +689,20 @@ private[spark] class TaskSchedulerImpl( } } + /** + * Marks the task has completed in all TaskSetManagers for the given stage. + * + * After stage failure and retry, there may be multiple TaskSetManagers for the stage. + * If an earlier attempt of a stage completes a task, we should ensure that the later attempts + * do not also submit those same tasks. That also means that a task completion from an earlier + * attempt can lead to the entire stage getting marked as successful. + */ + private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = { + taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm => + tsm.markPartitionCompleted(partitionId) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 8a96a7692f614..195fc8025e4b5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -73,6 +73,8 @@ private[spark] class TaskSetManager( val ser = env.closureSerializer.newInstance() val tasks = taskSet.tasks + private[scheduler] val partitionToIndex = tasks.zipWithIndex + .map { case (t, idx) => t.partitionId -> idx }.toMap val numTasks = tasks.length val copiesRunning = new Array[Int](numTasks) @@ -153,7 +155,7 @@ private[spark] class TaskSetManager( private[scheduler] val speculatableTasks = new HashSet[Int] // Task index, start and finish time for each task attempt (indexed by task ID) - private val taskInfos = new HashMap[Long, TaskInfo] + private[scheduler] val taskInfos = new HashMap[Long, TaskInfo] // Use a MedianHeap to record durations of successful tasks so we know when to launch // speculative tasks. This is only used when speculation is enabled, to avoid the overhead @@ -754,6 +756,9 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } + // There may be multiple tasksets for this stage -- we let all of them know that the partition + // was completed. This may result in some of the tasksets getting completed. + sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not // "deserialize" the value when holding a lock to avoid blocking other threads. So we call @@ -764,6 +769,19 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } + private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = { + partitionToIndex.get(partitionId).foreach { index => + if (!successful(index)) { + tasksSuccessful += 1 + successful(index) = true + if (tasksSuccessful == numTasks) { + isZombie = true + } + maybeFinishTaskSet() + } + } + } + /** * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the * DAG Scheduler. diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 6003899bb7bef..33f2ea1c94e75 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -917,4 +917,108 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B taskScheduler.initialize(new FakeSchedulerBackend) } } + + test("Completions in zombie tasksets update status of non-zombie taskset") { + val taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + val valueSer = SparkEnv.get.serializer.newInstance() + + def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = { + val indexInTsm = tsm.partitionToIndex(partition) + val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result) + } + + // Submit a task set, have it fail with a fetch failed, and then re-submit the task attempt, + // two times, so we have three active task sets for one stage. (For this to really happen, + // you'd need the previous stage to also get restarted, and then succeed, in between each + // attempt, but that happens outside what we're mocking here.) + val zombieAttempts = (0 until 2).map { stageAttempt => + val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt) + taskScheduler.submitTasks(attempt) + val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get + val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + taskScheduler.resourceOffers(offers) + assert(tsm.runningTasks === 10) + // fail attempt + tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, + FetchFailed(null, 0, 0, 0, "fetch failed")) + // the attempt is a zombie, but the tasks are still running (this could be true even if + // we actively killed those tasks, as killing is best-effort) + assert(tsm.isZombie) + assert(tsm.runningTasks === 9) + tsm + } + + // we've now got 2 zombie attempts, each with 9 tasks still active. Submit the 3rd attempt for + // the stage, but this time with insufficient resources so not all tasks are active. + + val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2) + taskScheduler.submitTasks(finalAttempt) + val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get + val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task => + finalAttempt.tasks(task.index).partitionId + }.toSet + assert(finalTsm.runningTasks === 5) + assert(!finalTsm.isZombie) + + // We simulate late completions from our zombie tasksets, corresponding to all the pending + // partitions in our final attempt. This means we're only waiting on the tasks we've already + // launched. + val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions) + finalAttemptPendingPartitions.foreach { partition => + completeTaskSuccessfully(zombieAttempts(0), partition) + } + + // If there is another resource offer, we shouldn't run anything. Though our final attempt + // used to have pending tasks, now those tasks have been completed by zombie attempts. The + // remaining tasks to compute are already active in the non-zombie attempt. + assert( + taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty) + + val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted + + // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be + // marked as zombie. + // for each of the remaining tasks, find the tasksets with an active copy of the task, and + // finish the task. + remainingTasks.foreach { partition => + val tsm = if (partition == 0) { + // we failed this task on both zombie attempts, this one is only present in the latest + // taskset + finalTsm + } else { + // should be active in every taskset. We choose a zombie taskset just to make sure that + // we transition the active taskset correctly even if the final completion comes + // from a zombie. + zombieAttempts(partition % 2) + } + completeTaskSuccessfully(tsm, partition) + } + + assert(finalTsm.isZombie) + + // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet + verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), anyObject()) + + // finally, lets complete all the tasks. We simulate failures in attempt 1, but everything + // else succeeds, to make sure we get the right updates to the blacklist in all cases. + (zombieAttempts ++ Seq(finalTsm)).foreach { tsm => + val stageAttempt = tsm.taskSet.stageAttemptId + tsm.runningTasksSet.foreach { index => + if (stageAttempt == 1) { + tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost) + } else { + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result) + } + } + + // we update the blacklist for the stage attempts with all successful tasks. Even though + // some tasksets had failures, we still consider them all successful from a blacklisting + // perspective, as the failures weren't from a problem w/ the tasks themselves. + verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject()) + } + } } From e3201e165e41f076ec72175af246d12c0da529cf Mon Sep 17 00:00:00 2001 From: maryannxue Date: Thu, 3 May 2018 17:05:02 -0700 Subject: [PATCH 731/774] [SPARK-24035][SQL] SQL syntax for Pivot ## What changes were proposed in this pull request? Add SQL support for Pivot according to Pivot grammar defined by Oracle (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_clause.htm) with some simplifications, based on our existing functionality and limitations for Pivot at the backend: 1. For pivot_for_clause (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_for_clause.htm), the column list form is not supported, which means the pivot column can only be one single column. 2. For pivot_in_clause (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_in_clause.htm), the sub-query form and "ANY" is not supported (this is only supported by Oracle for XML anyway). 3. For pivot_in_clause, aliases for the constant values are not supported. The code changes are: 1. Add parser support for Pivot. Note that according to https://docs.oracle.com/database/121/SQLRF/statements_10002.htm#i2076542, Pivot cannot be used together with lateral views in the from clause. This restriction has been implemented in the Parser rule. 2. Infer group-by expressions: group-by expressions are not explicitly specified in SQL Pivot clause and need to be deduced based on this rule: https://docs.oracle.com/database/121/SQLRF/statements_10002.htm#CHDFAFIE, so we have to post-fix it at query analysis stage. 3. Override Pivot.resolved as "false": for the reason mentioned in [2] and the fact that output attributes change after Pivot being replaced by Project or Aggregate, we avoid resolving parent references until after Pivot has been resolved and replaced. 4. Verify aggregate expressions: only aggregate expressions with or without aliases can appear in the first part of the Pivot clause, and this check is performed as analysis stage. ## How was this patch tested? A new test suite PivotSuite is added. Author: maryannxue Closes #21187 from maryannxue/spark-24035. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 12 +- .../sql/catalyst/analysis/Analyzer.scala | 35 +++- .../sql/catalyst/parser/AstBuilder.scala | 20 +- .../plans/logical/basicLogicalOperators.scala | 27 ++- .../parser/TableIdentifierParserSuite.scala | 6 +- .../spark/sql/RelationalGroupedDataset.scala | 2 +- .../test/resources/sql-tests/inputs/pivot.sql | 113 ++++++++++ .../resources/sql-tests/results/pivot.sql.out | 194 ++++++++++++++++++ 8 files changed, 386 insertions(+), 23 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/pivot.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/pivot.sql.out diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 5fa75fe348e68..f7f921ec22c35 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -398,7 +398,7 @@ hintStatement ; fromClause - : FROM relation (',' relation)* lateralView* + : FROM relation (',' relation)* (pivotClause | lateralView*)? ; aggregation @@ -413,6 +413,10 @@ groupingSet | expression ; +pivotClause + : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn=identifier IN '(' pivotValues+=constant (',' pivotValues+=constant)* ')' ')' + ; + lateralView : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? ; @@ -725,7 +729,7 @@ nonReserved | ADD | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST | AFTER | MAP | ARRAY | STRUCT - | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER + | PIVOT | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS | GROUPING | CUBE | ROLLUP @@ -745,7 +749,7 @@ nonReserved | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH | ASC | DESC | LIMIT | RENAME | SETS - | AT | NULLS | OVERWRITE | ALL | ALTER | AS | BETWEEN | BY | CREATE | DELETE + | AT | NULLS | OVERWRITE | ALL | ANY | ALTER | AS | BETWEEN | BY | CREATE | DELETE | DESCRIBE | DROP | EXISTS | FALSE | FOR | GROUP | IN | INSERT | INTO | IS |LIKE | NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE | AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN @@ -760,6 +764,7 @@ FROM: 'FROM'; ADD: 'ADD'; AS: 'AS'; ALL: 'ALL'; +ANY: 'ANY'; DISTINCT: 'DISTINCT'; WHERE: 'WHERE'; GROUP: 'GROUP'; @@ -805,6 +810,7 @@ RIGHT: 'RIGHT'; FULL: 'FULL'; NATURAL: 'NATURAL'; ON: 'ON'; +PIVOT: 'PIVOT'; LATERAL: 'LATERAL'; WINDOW: 'WINDOW'; OVER: 'OVER'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e821e96522f7c..dfdcdbc1eb2c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -275,9 +275,9 @@ class Analyzer( case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) => g.copy(aggregations = assignAliases(g.aggregations)) - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) - if child.resolved && hasUnresolvedAlias(groupByExprs) => - Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child) + case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) + if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => + Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child) case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => Project(assignAliases(projectList), child) @@ -504,9 +504,20 @@ class Analyzer( object ResolvePivot extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) - | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => + case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) + || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) + || !p.pivotColumn.resolved => p + case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => + // Check all aggregate expressions. + aggregates.foreach { e => + if (!isAggregateExpression(e)) { + throw new AnalysisException( + s"Aggregate expression required for pivot, found '$e'") + } + } + // Group-by expressions coming from SQL are implicit and need to be deduced. + val groupByExprs = groupByExprsOpt.getOrElse( + (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) val singleAgg = aggregates.size == 1 def outputName(value: Literal, aggregate: Expression): String = { val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) @@ -568,16 +579,20 @@ class Analyzer( // TODO: Don't construct the physical container until after analysis. case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) } - if (filteredAggregate.fastEquals(aggregate)) { - throw new AnalysisException( - s"Aggregate expression required for pivot, found '$aggregate'") - } Alias(filteredAggregate, outputName(value, aggregate))() } } Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) } } + + private def isAggregateExpression(expr: Expression): Boolean = { + expr match { + case Alias(e, _) => isAggregateExpression(e) + case AggregateExpression(_, _, _, _) => true + case _ => false + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bdc357d54a878..64eed23884584 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -503,7 +503,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val join = right.optionalMap(left)(Join(_, _, Inner, None)) withJoinRelations(join, relation) } - ctx.lateralView.asScala.foldLeft(from)(withGenerate) + if (ctx.pivotClause() != null) { + withPivot(ctx.pivotClause, from) + } else { + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } } /** @@ -614,6 +618,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging plan } + /** + * Add a [[Pivot]] to a logical plan. + */ + private def withPivot( + ctx: PivotClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val aggregates = Option(ctx.aggregates).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText) + val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply) + Pivot(None, pivotColumn, pivotValues, aggregates, query) + } + /** * Add a [[Generate]] (Lateral View) to a logical plan. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 10df504795430..3bf32ef7884e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -686,17 +686,34 @@ case class GroupingSets( override lazy val resolved: Boolean = false } +/** + * A constructor for creating a pivot, which will later be converted to a [[Project]] + * or an [[Aggregate]] during the query analysis. + * + * @param groupByExprsOpt A sequence of group by expressions. This field should be None if coming + * from SQL, in which group by expressions are not explicitly specified. + * @param pivotColumn The pivot column. + * @param pivotValues A sequence of values for the pivot column. + * @param aggregates The aggregation expressions, each with or without an alias. + * @param child Child operator + */ case class Pivot( - groupByExprs: Seq[NamedExpression], + groupByExprsOpt: Option[Seq[NamedExpression]], pivotColumn: Expression, pivotValues: Seq[Literal], aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { - case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) - case _ => pivotValues.flatMap{ value => - aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + override lazy val resolved = false // Pivot will be replaced after being resolved. + override def output: Seq[Attribute] = { + val pivotAgg = aggregates match { + case agg :: Nil => + pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) + case _ => + pivotValues.flatMap { value => + aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + } } + groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index cc80a41df998d..89903c2825125 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -41,12 +41,12 @@ class TableIdentifierParserSuite extends SparkFunSuite { "sort", "sorted", "ssl", "statistics", "stored", "streamtable", "string", "struct", "tables", "tblproperties", "temporary", "terminated", "tinyint", "touch", "transactions", "unarchive", "undo", "uniontype", "unlock", "unset", "unsigned", "uri", "use", "utc", "utctimestamp", - "view", "while", "year", "work", "transaction", "write", "isolation", "level", - "snapshot", "autocommit", "all", "alter", "array", "as", "authorization", "between", "bigint", + "view", "while", "year", "work", "transaction", "write", "isolation", "level", "snapshot", + "autocommit", "all", "any", "alter", "array", "as", "authorization", "between", "bigint", "binary", "boolean", "both", "by", "create", "cube", "current_date", "current_timestamp", "cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external", "false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in", - "insert", "int", "into", "is", "lateral", "like", "local", "none", "null", + "insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null", "of", "order", "out", "outer", "partition", "percent", "procedure", "range", "reads", "revoke", "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 7147798d99533..6c2be3610ae30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -73,7 +73,7 @@ class RelationalGroupedDataset protected[sql]( case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) + df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan)) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql new file mode 100644 index 0000000000000..01dea6c81c11b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -0,0 +1,113 @@ +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings); + +create temporary view years as select * from values + (2012, 1), + (2013, 2) + as years(y, s); + +-- pivot courses +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot years with no subquery +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +); + +-- pivot courses with multiple aggregations +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), avg(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with no group by column +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with no group by column and with multiple aggregations on different columns +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), min(year) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on join query with multiple group by columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN (1, 2) +); + +-- pivot on join query with multiple aggregations on different columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings), min(s) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on join query with multiple columns in one aggregation +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings * s) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with aliases and projection +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +); + +-- pivot years with non-aggregate function +SELECT * FROM courseSales +PIVOT ( + abs(earnings) + FOR year IN (2012, 2013) +); + +-- pivot with unresolvable columns +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +); diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out new file mode 100644 index 0000000000000..85e3488990e20 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -0,0 +1,194 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 13 + + +-- !query 0 +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view years as select * from values + (2012, 1), + (2013, 2) + as years(y, s) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 2 schema +struct +-- !query 2 output +2012 15000 20000 +2013 48000 30000 + + +-- !query 3 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +) +-- !query 3 schema +struct +-- !query 3 output +Java 20000 30000 +dotNET 15000 48000 + + +-- !query 4 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), avg(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 4 schema +struct +-- !query 4 output +2012 15000 7500.0 20000 20000.0 +2013 48000 48000.0 30000 30000.0 + + +-- !query 5 +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 5 schema +struct +-- !query 5 output +63000 50000 + + +-- !query 6 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), min(year) + FOR course IN ('dotNET', 'Java') +) +-- !query 6 schema +struct +-- !query 6 output +63000 2012 50000 2012 + + +-- !query 7 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN (1, 2) +) +-- !query 7 schema +struct +-- !query 7 output +Java 2012 20000 NULL +Java 2013 NULL 30000 +dotNET 2012 15000 NULL +dotNET 2013 NULL 48000 + + +-- !query 8 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings), min(s) + FOR course IN ('dotNET', 'Java') +) +-- !query 8 schema +struct +-- !query 8 output +2012 15000 1 20000 1 +2013 48000 2 30000 2 + + +-- !query 9 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings * s) + FOR course IN ('dotNET', 'Java') +) +-- !query 9 schema +struct +-- !query 9 output +2012 15000 20000 +2013 96000 60000 + + +-- !query 10 +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +) +-- !query 10 schema +struct<2012_s:bigint,2013_s:bigint,2012_a:double,2013_a:double,c:string> +-- !query 10 output +15000 48000 7500.0 48000.0 dotNET +20000 30000 20000.0 30000.0 Java + + +-- !query 11 +SELECT * FROM courseSales +PIVOT ( + abs(earnings) + FOR year IN (2012, 2013) +) +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +Aggregate expression required for pivot, found 'abs(earnings#x)'; + + +-- !query 12 +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +) +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 From e646ae67f2e793204bc819ab2b90815214c2bbf3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 17:27:13 -0700 Subject: [PATCH 732/774] [SPARK-24168][SQL] WindowExec should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from #21190 , to make it easier to backport. `WindowExec#createBoundOrdering` is called on executor side, so we can't use `conf.sessionLocalTimezone` there. ## How was this patch tested? tested in #21190 Author: Wenchen Fan Closes #21225 from cloud-fan/minor3. --- .../spark/sql/execution/window/WindowExec.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 800a2ea3f3996..626f39d9e95cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -112,9 +112,11 @@ case class WindowExec( * * @param frame to evaluate. This can either be a Row or Range frame. * @param bound with respect to the row. + * @param timeZone the session local timezone for time related calculations. * @return a bound ordering object. */ - private[this] def createBoundOrdering(frame: FrameType, bound: Expression): BoundOrdering = { + private[this] def createBoundOrdering( + frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { (frame, bound) match { case (RowFrame, CurrentRow) => RowBoundOrdering(0) @@ -144,7 +146,7 @@ case class WindowExec( val boundExpr = (expr.dataType, boundOffset.dataType) match { case (DateType, IntegerType) => DateAdd(expr, boundOffset) case (TimestampType, CalendarIntervalType) => - TimeAdd(expr, boundOffset, Some(conf.sessionLocalTimeZone)) + TimeAdd(expr, boundOffset, Some(timeZone)) case (a, b) if a== b => Add(expr, boundOffset) } val bound = newMutableProjection(boundExpr :: Nil, child.output) @@ -197,6 +199,7 @@ case class WindowExec( // Map the groups to a (unbound) expression and frame factory pair. var numExpressions = 0 + val timeZone = conf.sessionLocalTimeZone framedFunctions.toSeq.map { case (key, (expressions, functionSeq)) => val ordinal = numExpressions @@ -237,7 +240,7 @@ case class WindowExec( new UnboundedPrecedingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, upper)) + createBoundOrdering(frameType, upper, timeZone)) } // Shrinking Frame. @@ -246,7 +249,7 @@ case class WindowExec( new UnboundedFollowingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, lower)) + createBoundOrdering(frameType, lower, timeZone)) } // Moving Frame. @@ -255,8 +258,8 @@ case class WindowExec( new SlidingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, lower), - createBoundOrdering(frameType, upper)) + createBoundOrdering(frameType, lower, timeZone), + createBoundOrdering(frameType, upper, timeZone)) } } From 0c23e254c38d4a9210939e1e1b0074278568abed Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 4 May 2018 09:27:14 +0800 Subject: [PATCH 733/774] [SPARK-24167][SQL] ParquetFilters should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from #21190 , to make it easier to backport. `ParquetFilters` is used in the file scan function, which is executed in executor side, so we can't call `conf.parquetFilterPushDownDate` there. ## How was this patch tested? it's tested in #21190 Author: Wenchen Fan Closes #21224 from cloud-fan/minor2. --- .../datasources/parquet/ParquetFileFormat.scala | 3 ++- .../datasources/parquet/ParquetFilters.scala | 15 +++++++-------- .../datasources/parquet/ParquetFilterSuite.scala | 10 ++++++---- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index d8f47eec952de..d1f9e11ed4225 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -342,6 +342,7 @@ class ParquetFileFormat sparkSession.sessionState.conf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) + val pushDownDate = sqlConf.parquetFilterPushDownDate (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) @@ -352,7 +353,7 @@ class ParquetFileFormat // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` // is used here. - .flatMap(ParquetFilters.createFilter(requiredSchema, _)) + .flatMap(new ParquetFilters(pushDownDate).createFilter(requiredSchema, _)) .reduceOption(FilterApi.and) } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index ccc8306866d68..310626197a763 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -25,14 +25,13 @@ import org.apache.parquet.io.api.Binary import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types._ /** * Some utility function to convert Spark data source filters to Parquet filters. */ -private[parquet] object ParquetFilters { +private[parquet] class ParquetFilters(pushDownDate: Boolean) { private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) @@ -59,7 +58,7 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -85,7 +84,7 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -108,7 +107,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.lt( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -131,7 +130,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.ltEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -154,7 +153,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.gt( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -177,7 +176,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.gtEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 1d3476e747046..667e0b1760e3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -55,6 +55,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} */ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { + private lazy val parquetFilters = new ParquetFilters(conf.parquetFilterPushDownDate) + override def beforeEach(): Unit = { super.beforeEach() // Note that there are many tests here that require record-level filtering set to be true. @@ -99,7 +101,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(selectedFilters.nonEmpty, "No filter is pushed down") selectedFilters.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(df.schema, pred) + val maybeFilter = parquetFilters.createFilter(df.schema, pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) maybeFilter.exists(_.getClass === filterClass) @@ -517,7 +519,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex lt(intColumn("a"), 10: Integer), gt(doubleColumn("c"), 1.5: java.lang.Double))) ) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.And( sources.LessThan("a", 10), @@ -525,7 +527,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } assertResult(None) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.And( sources.LessThan("a", 10), @@ -533,7 +535,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } assertResult(None) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.Not( sources.And( From 7f1b6b182e3cf3cbf29399e7bfbe03fa869e0bc8 Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Fri, 4 May 2018 16:02:21 +0800 Subject: [PATCH 734/774] [SPARK-24136][SS] Fix MemoryStreamDataReader.next to skip sleeping if record is available ## What changes were proposed in this pull request? Avoid unnecessary sleep (10 ms) in each invocation of MemoryStreamDataReader.next. ## How was this patch tested? Ran ContinuousSuite from IDE. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Arun Mahadevan Closes #21207 from arunmahadevan/memorystream. --- .../streaming/sources/ContinuousMemoryStream.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index c28919b8b729b..a8fca3c19a2d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -183,11 +183,10 @@ class ContinuousMemoryStreamDataReader( private var current: Option[Row] = None override def next(): Boolean = { - current = None + current = getRecord while (current.isEmpty) { Thread.sleep(10) - current = endpoint.askSync[Option[Row]]( - GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) + current = getRecord } currentOffset += 1 true @@ -199,6 +198,10 @@ class ContinuousMemoryStreamDataReader( override def getOffset: ContinuousMemoryStreamPartitionOffset = ContinuousMemoryStreamPartitionOffset(partition, currentOffset) + + private def getRecord: Option[Row] = + endpoint.askSync[Option[Row]]( + GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) } case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) From 4d5de4d303a773b1c18c350072344bd7efca9fc4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 4 May 2018 19:20:15 +0800 Subject: [PATCH 735/774] [SPARK-23697][CORE] LegacyAccumulatorWrapper should define isZero correctly ## What changes were proposed in this pull request? It's possible that Accumulators of Spark 1.x may no longer work with Spark 2.x. This is because `LegacyAccumulatorWrapper.isZero` may return wrong answer if `AccumulableParam` doesn't define equals/hashCode. This PR fixes this by using reference equality check in `LegacyAccumulatorWrapper.isZero`. ## How was this patch tested? a new test Author: Wenchen Fan Closes #21229 from cloud-fan/accumulator. --- .../org/apache/spark/util/AccumulatorV2.scala | 6 ++++-- .../spark/util/AccumulatorV2Suite.scala | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 0f84ea9752cf5..2bc84953a56eb 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -486,7 +486,9 @@ class LegacyAccumulatorWrapper[R, T]( param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] { private[spark] var _value = initialValue // Current value on driver - override def isZero: Boolean = _value == param.zero(initialValue) + @transient private lazy val _zero = param.zero(initialValue) + + override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef]) override def copy(): LegacyAccumulatorWrapper[R, T] = { val acc = new LegacyAccumulatorWrapper(initialValue, param) @@ -495,7 +497,7 @@ class LegacyAccumulatorWrapper[R, T]( } override def reset(): Unit = { - _value = param.zero(initialValue) + _value = _zero } override def add(v: T): Unit = _value = param.addAccumulator(_value, v) diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index a04644d57ed88..fe0a9a471a651 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import org.apache.spark._ +import org.apache.spark.serializer.JavaSerializer class AccumulatorV2Suite extends SparkFunSuite { @@ -162,4 +163,22 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc3.isZero) assert(acc3.value === "") } + + test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") { + class MyData(val i: Int) extends Serializable + val param = new AccumulatorParam[MyData] { + override def zero(initialValue: MyData): MyData = new MyData(0) + override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i) + } + + val acc = new LegacyAccumulatorWrapper(new MyData(0), param) + acc.metadata = AccumulatorMetadata( + AccumulatorContext.newId(), + Some("test"), + countFailedValues = false) + AccumulatorContext.register(acc) + + val ser = new JavaSerializer(new SparkConf).newInstance() + ser.serialize(acc) + } } From d04806a23c1843a7f0dcc4fa236ed1b40ae113a5 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 4 May 2018 13:29:47 -0700 Subject: [PATCH 736/774] =?UTF-8?q?[SPARK-24124]=20Spark=20history=20serve?= =?UTF-8?q?r=20should=20create=20spark.history.store.=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …path and set permissions properly ## What changes were proposed in this pull request? Spark history server should create spark.history.store.path and set permissions properly. Note createdDirectories doesn't do anything if the directories are already created. This does not stomp on the permissions if the user had manually created the directory before the history server could. ## How was this patch tested? Manually tested in a 100 node cluster. Ensured directories created with proper permissions. Ensured restarted worked apps/temp directories worked as apps were read. Author: Thomas Graves Closes #21234 from tgravescs/SPARK-24124. --- .../apache/spark/deploy/history/FsHistoryProvider.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 56db9359e033f..bf1eeb0c1bf59 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,6 +18,8 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} +import java.nio.file.Files +import java.nio.file.attribute.PosixFilePermissions import java.util.{Date, ServiceLoader} import java.util.concurrent.{ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -130,8 +132,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => - require(path.isDirectory(), s"Configured store directory ($path) does not exist.") - val dbPath = new File(path, "listing.ldb") + val perms = PosixFilePermissions.fromString("rwx------") + val dbPath = Files.createDirectories(new File(path, "listing.ldb").toPath(), + PosixFilePermissions.asFileAttribute(perms)).toFile() + val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, AppStatusStore.CURRENT_VERSION, logDir.toString()) From af4dc50280ffcdeda208ef2dc5f8b843389732e5 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 4 May 2018 14:14:40 -0700 Subject: [PATCH 737/774] [SPARK-24039][SS] Do continuous processing writes with multiple compute() calls ## What changes were proposed in this pull request? Do continuous processing writes with multiple compute() calls. The current strategy (before this PR) is hacky; we just call next() on an iterator which has already returned hasNext = false, knowing that all the nodes we whitelist handle this properly. This will have to be changed before we can support more complex query plans. (In particular, I have a WIP https://github.com/jose-torres/spark/pull/13 which should be able to support aggregates in a single partition with minimal additional work.) Most of the changes here are just refactoring to accommodate the new model. The behavioral changes are: * The writer now calls prev.compute(split, context) once per epoch within the epoch loop. * ContinuousDataSourceRDD now spawns a ContinuousQueuedDataReader which is shared across multiple calls to compute() for the same partition. ## How was this patch tested? existing unit tests Author: Jose Torres Closes #21200 from jose-torres/noAggr. --- .../datasources/v2/DataSourceV2ScanExec.scala | 6 +- .../continuous/ContinuousDataSourceRDD.scala | 114 +++++++++ .../ContinuousDataSourceRDDIter.scala | 222 ------------------ .../ContinuousQueuedDataReader.scala | 211 +++++++++++++++++ .../continuous/ContinuousWriteRDD.scala | 90 +++++++ .../WriteToContinuousDataSourceExec.scala | 57 +---- .../ContinuousQueuedDataReaderSuite.scala | 167 +++++++++++++ 7 files changed, 592 insertions(+), 275 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 41bdda47c8c3e..77cb707340b0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -96,7 +96,11 @@ case class DataSourceV2ScanExec( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) .askSync[Unit](SetReaderPartitions(readerFactories.size)) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories) + new ContinuousDataSourceRDD( + sparkContext, + sqlContext.conf.continuousStreamingExecutorQueueSize, + sqlContext.conf.continuousStreamingExecutorPollIntervalMs, + readerFactories) .asInstanceOf[RDD[InternalRow]] case r: SupportsScanColumnarBatch if r.enableBatchRead() => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala new file mode 100644 index 0000000000000..0a3b9dcccb6c5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} +import org.apache.spark.util.{NextIterator, ThreadUtils} + +class ContinuousDataSourceRDDPartition( + val index: Int, + val readerFactory: DataReaderFactory[UnsafeRow]) + extends Partition with Serializable { + + // This is semantically a lazy val - it's initialized once the first time a call to + // ContinuousDataSourceRDD.compute() needs to access it, so it can be shared across + // all compute() calls for a partition. This ensures that one compute() picks up where the + // previous one ended. + // We don't make it actually a lazy val because it needs input which isn't available here. + // This will only be initialized on the executors. + private[continuous] var queueReader: ContinuousQueuedDataReader = _ +} + +/** + * The bottom-most RDD of a continuous processing read task. Wraps a [[ContinuousQueuedDataReader]] + * to read from the remote source, and polls that queue for incoming rows. + * + * Note that continuous processing calls compute() multiple times, and the same + * [[ContinuousQueuedDataReader]] instance will/must be shared between each call for the same split. + */ +class ContinuousDataSourceRDD( + sc: SparkContext, + dataQueueSize: Int, + epochPollIntervalMs: Long, + @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + readerFactories.zipWithIndex.map { + case (readerFactory, index) => new ContinuousDataSourceRDDPartition(index, readerFactory) + }.toArray + } + + /** + * Initialize the shared reader for this partition if needed, then read rows from it until + * it returns null to signal the end of the epoch. + */ + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + // If attempt number isn't 0, this is a task retry, which we don't support. + if (context.attemptNumber() != 0) { + throw new ContinuousTaskRetryException() + } + + val readerForPartition = { + val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] + if (partition.queueReader == null) { + partition.queueReader = + new ContinuousQueuedDataReader( + partition.readerFactory, context, dataQueueSize, epochPollIntervalMs) + } + + partition.queueReader + } + + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = { + readerForPartition.next() match { + case null => + finished = true + null + case row => row + } + } + + override def close(): Unit = {} + } + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + split.asInstanceOf[ContinuousDataSourceRDDPartition].readerFactory.preferredLocations() + } +} + +object ContinuousDataSourceRDD { + private[continuous] def getContinuousReader( + reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { + reader match { + case r: ContinuousDataReader[UnsafeRow] => r + case wrapped: RowToUnsafeDataReader => + wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] + case _ => + throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala deleted file mode 100644 index 06754f01657d3..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.continuous - -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicBoolean - -import scala.collection.JavaConverters._ - -import org.apache.spark._ -import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} -import org.apache.spark.util.ThreadUtils - -class ContinuousDataSourceRDD( - sc: SparkContext, - sqlContext: SQLContext, - @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { - - private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize - private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs - - override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { - case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) - }.toArray - } - - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - // If attempt number isn't 0, this is a task retry, which we don't support. - if (context.attemptNumber() != 0) { - throw new ContinuousTaskRetryException() - } - - val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]] - .readerFactory.createDataReader() - - val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) - - // This queue contains two types of messages: - // * (null, null) representing an epoch boundary. - // * (row, off) containing a data row and its corresponding PartitionOffset. - val queue = new ArrayBlockingQueue[(UnsafeRow, PartitionOffset)](dataQueueSize) - - val epochPollFailed = new AtomicBoolean(false) - val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( - s"epoch-poll--$coordinatorId--${context.partitionId()}") - val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed) - epochPollExecutor.scheduleWithFixedDelay( - epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) - - // Important sequencing - we must get start offset before the data reader thread begins - val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset - - val dataReaderFailed = new AtomicBoolean(false) - val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed) - dataReaderThread.setDaemon(true) - dataReaderThread.start() - - context.addTaskCompletionListener(_ => { - dataReaderThread.interrupt() - epochPollExecutor.shutdown() - }) - - val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get) - new Iterator[UnsafeRow] { - private val POLL_TIMEOUT_MS = 1000 - - private var currentEntry: (UnsafeRow, PartitionOffset) = _ - private var currentOffset: PartitionOffset = startOffset - private var currentEpoch = - context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - override def hasNext(): Boolean = { - while (currentEntry == null) { - if (context.isInterrupted() || context.isCompleted()) { - currentEntry = (null, null) - } - if (dataReaderFailed.get()) { - throw new SparkException("data read failed", dataReaderThread.failureReason) - } - if (epochPollFailed.get()) { - throw new SparkException("epoch poll failed", epochPollRunnable.failureReason) - } - currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) - } - - currentEntry match { - // epoch boundary marker - case (null, null) => - epochEndpoint.send(ReportPartitionOffset( - context.partitionId(), - currentEpoch, - currentOffset)) - currentEpoch += 1 - currentEntry = null - false - // real row - case (_, offset) => - currentOffset = offset - true - } - } - - override def next(): UnsafeRow = { - if (currentEntry == null) throw new NoSuchElementException("No current row was set") - val r = currentEntry._1 - currentEntry = null - r - } - } - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations() - } -} - -case class EpochPackedPartitionOffset(epoch: Long) extends PartitionOffset - -class EpochPollRunnable( - queue: BlockingQueue[(UnsafeRow, PartitionOffset)], - context: TaskContext, - failedFlag: AtomicBoolean) - extends Thread with Logging { - private[continuous] var failureReason: Throwable = _ - - private val epochEndpoint = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) - private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - override def run(): Unit = { - try { - val newEpoch = epochEndpoint.askSync[Long](GetCurrentEpoch) - for (i <- currentEpoch to newEpoch - 1) { - queue.put((null, null)) - logDebug(s"Sent marker to start epoch ${i + 1}") - } - currentEpoch = newEpoch - } catch { - case t: Throwable => - failureReason = t - failedFlag.set(true) - throw t - } - } -} - -class DataReaderThread( - reader: DataReader[UnsafeRow], - queue: BlockingQueue[(UnsafeRow, PartitionOffset)], - context: TaskContext, - failedFlag: AtomicBoolean) - extends Thread( - s"continuous-reader--${context.partitionId()}--" + - s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") { - private[continuous] var failureReason: Throwable = _ - - override def run(): Unit = { - TaskContext.setTaskContext(context) - val baseReader = ContinuousDataSourceRDD.getBaseReader(reader) - try { - while (!context.isInterrupted && !context.isCompleted()) { - if (!reader.next()) { - // Check again, since reader.next() might have blocked through an incoming interrupt. - if (!context.isInterrupted && !context.isCompleted()) { - throw new IllegalStateException( - "Continuous reader reported no elements! Reader should have blocked waiting.") - } else { - return - } - } - - queue.put((reader.get().copy(), baseReader.getOffset)) - } - } catch { - case _: InterruptedException if context.isInterrupted() => - // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. - - case t: Throwable => - failureReason = t - failedFlag.set(true) - // Don't rethrow the exception in this thread. It's not needed, and the default Spark - // exception handler will kill the executor. - } finally { - reader.close() - } - } -} - -object ContinuousDataSourceRDD { - private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { - reader match { - case r: ContinuousDataReader[UnsafeRow] => r - case wrapped: RowToUnsafeDataReader => - wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] - case _ => - throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala new file mode 100644 index 0000000000000..01a999f6505fc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.io.Closeable +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean + +import scala.util.control.NonFatal + +import org.apache.spark.{Partition, SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset +import org.apache.spark.util.ThreadUtils + +/** + * A wrapper for a continuous processing data reader, including a reading queue and epoch markers. + * + * This will be instantiated once per partition - successive calls to compute() in the + * [[ContinuousDataSourceRDD]] will reuse the same reader. This is required to get continuity of + * offsets across epochs. Each compute() should call the next() method here until null is returned. + */ +class ContinuousQueuedDataReader( + factory: DataReaderFactory[UnsafeRow], + context: TaskContext, + dataQueueSize: Int, + epochPollIntervalMs: Long) extends Closeable { + private val reader = factory.createDataReader() + + // Important sequencing - we must get our starting point before the provider threads start running + private var currentOffset: PartitionOffset = + ContinuousDataSourceRDD.getContinuousReader(reader).getOffset + private var currentEpoch: Long = + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + /** + * The record types in the read buffer. + */ + sealed trait ContinuousRecord + case object EpochMarker extends ContinuousRecord + case class ContinuousRow(row: UnsafeRow, offset: PartitionOffset) extends ContinuousRecord + + private val queue = new ArrayBlockingQueue[ContinuousRecord](dataQueueSize) + + private val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) + private val epochCoordEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) + + private val epochMarkerExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + s"epoch-poll--$coordinatorId--${context.partitionId()}") + private val epochMarkerGenerator = new EpochMarkerGenerator + epochMarkerExecutor.scheduleWithFixedDelay( + epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) + + private val dataReaderThread = new DataReaderThread + dataReaderThread.setDaemon(true) + dataReaderThread.start() + + context.addTaskCompletionListener(_ => { + this.close() + }) + + private def shouldStop() = { + context.isInterrupted() || context.isCompleted() + } + + /** + * Return the next UnsafeRow to be read in the current epoch, or null if the epoch is done. + * + * After returning null, the [[ContinuousDataSourceRDD]] compute() for the following epoch + * will call next() again to start getting rows. + */ + def next(): UnsafeRow = { + val POLL_TIMEOUT_MS = 1000 + var currentEntry: ContinuousRecord = null + + while (currentEntry == null) { + if (shouldStop()) { + // Force the epoch to end here. The writer will notice the context is interrupted + // or completed and not start a new one. This makes it possible to achieve clean + // shutdown of the streaming query. + // TODO: The obvious generalization of this logic to multiple stages won't work. It's + // invalid to send an epoch marker from the bottom of a task if all its child tasks + // haven't sent one. + currentEntry = EpochMarker + } else { + if (dataReaderThread.failureReason != null) { + throw new SparkException("Data read failed", dataReaderThread.failureReason) + } + if (epochMarkerGenerator.failureReason != null) { + throw new SparkException( + "Epoch marker generation failed", + epochMarkerGenerator.failureReason) + } + currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } + } + + currentEntry match { + case EpochMarker => + epochCoordEndpoint.send(ReportPartitionOffset( + context.partitionId(), currentEpoch, currentOffset)) + currentEpoch += 1 + null + case ContinuousRow(row, offset) => + currentOffset = offset + row + } + } + + override def close(): Unit = { + dataReaderThread.interrupt() + epochMarkerExecutor.shutdown() + } + + /** + * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when + * a new row arrives to the [[DataReader]]. + */ + class DataReaderThread extends Thread( + s"continuous-reader--${context.partitionId()}--" + + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging { + @volatile private[continuous] var failureReason: Throwable = _ + + override def run(): Unit = { + TaskContext.setTaskContext(context) + val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader) + try { + while (!shouldStop()) { + if (!reader.next()) { + // Check again, since reader.next() might have blocked through an incoming interrupt. + if (!shouldStop()) { + throw new IllegalStateException( + "Continuous reader reported no elements! Reader should have blocked waiting.") + } else { + return + } + } + + queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset)) + } + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. + logInfo(s"shutting down interrupted data reader thread $getName") + + case NonFatal(t) => + failureReason = t + logWarning("data reader thread failed", t) + // If we throw from this thread, we may kill the executor. Let the parent thread handle + // it. + + case t: Throwable => + failureReason = t + throw t + } finally { + reader.close() + } + } + } + + /** + * The epoch marker component of [[ContinuousQueuedDataReader]]. Populates the queue with + * EpochMarker when a new epoch marker arrives. + */ + class EpochMarkerGenerator extends Runnable with Logging { + @volatile private[continuous] var failureReason: Throwable = _ + + private val epochCoordEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) + // Note that this is *not* the same as the currentEpoch in [[ContinuousDataQueuedReader]]! That + // field represents the epoch wrt the data being processed. The currentEpoch here is just a + // counter to ensure we send the appropriate number of markers if we fall behind the driver. + private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + override def run(): Unit = { + try { + val newEpoch = epochCoordEndpoint.askSync[Long](GetCurrentEpoch) + // It's possible to fall more than 1 epoch behind if a GetCurrentEpoch RPC ends up taking + // a while. We catch up by injecting enough epoch markers immediately to catch up. This will + // result in some epochs being empty for this partition, but that's fine. + for (i <- currentEpoch to newEpoch - 1) { + queue.put(EpochMarker) + logDebug(s"Sent marker to start epoch ${i + 1}") + } + currentEpoch = newEpoch + } catch { + case t: Throwable => + failureReason = t + throw t + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala new file mode 100644 index 0000000000000..91f1576581511 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.{Partition, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.util.Utils + +/** + * The RDD writing to a sink in continuous processing. + * + * Within each task, we repeatedly call prev.compute(). Each resulting iterator contains the data + * to be written for one epoch, which we commit and forward to the driver. + * + * We keep repeating prev.compute() and writing new epochs until the query is shut down. + */ +class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow]) + extends RDD[Unit](prev) { + + override val partitioner = prev.partitioner + + override def getPartitions: Array[Partition] = prev.partitions + + override def compute(split: Partition, context: TaskContext): Iterator[Unit] = { + val epochCoordinator = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + SparkEnv.get) + var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + while (!context.isInterrupted() && !context.isCompleted()) { + var dataWriter: DataWriter[InternalRow] = null + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + try { + val dataIterator = prev.compute(split, context) + dataWriter = writeTask.createDataWriter( + context.partitionId(), context.attemptNumber(), currentEpoch) + while (dataIterator.hasNext) { + dataWriter.write(dataIterator.next()) + } + logInfo(s"Writer for partition ${context.partitionId()} " + + s"in epoch $currentEpoch is committing.") + val msg = dataWriter.commit() + epochCoordinator.send( + CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + ) + logInfo(s"Writer for partition ${context.partitionId()} " + + s"in epoch $currentEpoch committed.") + currentEpoch += 1 + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt. Just finish the task. + } + })(catchBlock = { + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so compute() will stop executing at this point. + logError(s"Writer for partition ${context.partitionId()} is aborting.") + if (dataWriter != null) dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } + + Iterator() + } + + override def clearDependencies() { + super.clearDependencies() + prev = null + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index ba88ae1af469a..e0af3a2f1b85d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -46,24 +46,19 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) } - val rdd = query.execute() + val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) logInfo(s"Start processing data source writer: $writer. " + - s"The input RDD has ${rdd.getNumPartitions} partitions.") - // Let the epoch coordinator know how many partitions the write RDD has. + s"The input RDD has ${rdd.partitions.length} partitions.") EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - sparkContext.env) + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) try { // Force the RDD to run so continuous processing starts; no data is actually being collected // to the driver, as ContinuousWriteRDD outputs nothing. - sparkContext.runJob( - rdd, - (context: TaskContext, iter: Iterator[InternalRow]) => - WriteToContinuousDataSourceExec.run(writerFactory, context, iter), - rdd.partitions.indices) + rdd.collect() } catch { case _: InterruptedException => // Interruption is how continuous queries are ended, so accept and ignore the exception. @@ -80,45 +75,3 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla sparkContext.emptyRDD } } - -object WriteToContinuousDataSourceExec extends Logging { - def run( - writeTask: DataWriterFactory[InternalRow], - context: TaskContext, - iter: Iterator[InternalRow]): Unit = { - val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - SparkEnv.get) - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - do { - var dataWriter: DataWriter[InternalRow] = null - // write the data and commit this writer. - Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - try { - dataWriter = writeTask.createDataWriter( - context.partitionId(), context.attemptNumber(), currentEpoch) - while (iter.hasNext) { - dataWriter.write(iter.next()) - } - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") - epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) - ) - currentEpoch += 1 - } catch { - case _: InterruptedException => - // Continuous shutdown always involves an interrupt. Just finish the task. - } - })(catchBlock = { - // If there is an error, abort this writer. We enter this callback in the middle of - // rethrowing an exception, so runContinuous will stop executing at this point. - logError(s"Writer for partition ${context.partitionId()} is aborting.") - if (dataWriter != null) dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") - }) - } while (!context.isInterrupted()) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala new file mode 100644 index 0000000000000..e755625d09e0f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} + +import org.mockito.{ArgumentCaptor, Matchers} +import org.mockito.Mockito._ +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType} + +class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { + case class LongPartitionOffset(offset: Long) extends PartitionOffset + + val coordinatorId = s"${getClass.getSimpleName}-epochCoordinatorIdForUnitTest" + val startEpoch = 0 + + var epochEndpoint: RpcEndpointRef = _ + + override def beforeEach(): Unit = { + super.beforeEach() + epochEndpoint = EpochCoordinatorRef.create( + mock[StreamWriter], + mock[ContinuousReader], + mock[ContinuousExecution], + coordinatorId, + startEpoch, + spark, + SparkEnv.get) + } + + override def afterEach(): Unit = { + SparkEnv.get.rpcEnv.stop(epochEndpoint) + epochEndpoint = null + super.afterEach() + } + + + private val mockContext = mock[TaskContext] + when(mockContext.getLocalProperty(ContinuousExecution.START_EPOCH_KEY)) + .thenReturn(startEpoch.toString) + when(mockContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)) + .thenReturn(coordinatorId) + + /** + * Set up a ContinuousQueuedDataReader for testing. The blocking queue can be used to send + * rows to the wrapped data reader. + */ + private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { + val queue = new ArrayBlockingQueue[UnsafeRow](1024) + val factory = new DataReaderFactory[UnsafeRow] { + override def createDataReader() = new ContinuousDataReader[UnsafeRow] { + var index = -1 + var curr: UnsafeRow = _ + + override def next() = { + curr = queue.take() + index += 1 + true + } + + override def get = curr + + override def getOffset = LongPartitionOffset(index) + + override def close() = {} + } + } + val reader = new ContinuousQueuedDataReader( + factory, + mockContext, + dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, + epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) + + (queue, reader) + } + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + test("basic data read") { + val (input, reader) = setup() + + input.add(unsafeRow(12345)) + assert(reader.next().getInt(0) == 12345) + } + + test("basic epoch marker") { + val (input, reader) = setup() + + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + } + + test("new rows after markers") { + val (input, reader) = setup() + + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + assert(reader.next() == null) + assert(reader.next() == null) + input.add(unsafeRow(11111)) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 11111) + assert(reader.next().getInt(0) == 22222) + } + + test("new markers after rows") { + val (input, reader) = setup() + + input.add(unsafeRow(11111)) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 11111) + assert(reader.next().getInt(0) == 22222) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + assert(reader.next() == null) + assert(reader.next() == null) + } + + test("alternating markers and rows") { + val (input, reader) = setup() + + input.add(unsafeRow(11111)) + assert(reader.next().getInt(0) == 11111) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 22222) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + input.add(unsafeRow(33333)) + assert(reader.next().getInt(0) == 33333) + input.add(unsafeRow(44444)) + assert(reader.next().getInt(0) == 44444) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + } +} From 47b5b68528c154d32b3f40f388918836d29462b8 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 4 May 2018 16:35:24 -0700 Subject: [PATCH 738/774] [SPARK-24157][SS] Enabled no-data batches in MicroBatchExecution for streaming aggregation and deduplication. ## What changes were proposed in this pull request? This PR enables the MicroBatchExecution to run no-data batches if some SparkPlan requires running another batch to output results based on updated watermark / processing time. In this PR, I have enabled streaming aggregations and streaming deduplicates to automatically run addition batch even if new data is available. See https://issues.apache.org/jira/browse/SPARK-24156 for more context. Major changes/refactoring done in this PR. - Refactoring MicroBatchExecution - A major point of confusion in MicroBatchExecution control flow was always (at least to me) was that `populateStartOffsets` internally called `constructNextBatch` which was not obvious from just the name "populateStartOffsets" and made the control flow from the main trigger execution loop very confusing (main loop in `runActivatedStream` called `constructNextBatch` but only if `populateStartOffsets` hadn't already called it). Instead, the refactoring makes it cleaner. - `populateStartOffsets` only the updates `availableOffsets` and `committedOffsets`. Does not call `constructNextBatch`. - Main loop in `runActivatedStream` calls `constructNextBatch` which returns true or false reflecting whether the next batch is ready for executing. This method is now idempotent; if a batch has already been constructed, then it will always return true until the batch has been executed. - If next batch is ready then we call `runBatch` or sleep. - That's it. - Refactoring watermark management logic - This has been refactored out from `MicroBatchExecution` in a separate class to simplify `MicroBatchExecution`. - New method `shouldRunAnotherBatch` in `IncrementalExecution` - This returns true if there is any stateful operation in the last execution plan that requires another batch for state cleanup, etc. This is used to decide whether to construct a batch or not in `constructNextBatch`. - Changes to stream testing framework - Many tests used CheckLastBatch to validate answers. This assumed that there will be no more batches after the last set of input has been processed, so the last batch is the one that has output corresponding to the last input. This is not true anymore. To account for that, I made two changes. - `CheckNewAnswer` is a new test action that verifies the new rows generated since the last time the answer was checked by `CheckAnswer`, `CheckNewAnswer` or `CheckLastBatch`. This is agnostic to how many batches occurred between the last check and now. To do make this easier, I added a common trait between MemorySink and MemorySinkV2 to abstract out some common methods. - `assertNumStateRows` has been updated in the same way to be agnostic to batches while checking what the total rows and how many state rows were updated (sums up updates since the last check). ## How was this patch tested? - Changes made to existing tests - Tests have been changed in one of the following patterns. - Tests where the last input was given again to force another batch to be executed and state cleaned up / output generated, they were simplified by removing the extra input. - Tests using aggregation+watermark where CheckLastBatch were replaced with CheckNewAnswer to make them batch agnostic. - New tests added to check whether the flag works for streaming aggregation and deduplication Author: Tathagata Das Closes #21220 from tdas/SPARK-24157. --- .../apache/spark/sql/internal/SQLConf.scala | 11 + .../streaming/IncrementalExecution.scala | 10 + .../streaming/MicroBatchExecution.scala | 231 ++++++++---------- .../streaming/WatermarkTracker.scala | 73 ++++++ .../sql/execution/streaming/memory.scala | 17 +- .../streaming/sources/memoryV2.scala | 8 +- .../streaming/statefulOperators.scala | 16 ++ .../sources/ForeachWriterSuite.scala | 8 +- .../streaming/EventTimeWatermarkSuite.scala | 112 ++++----- .../sql/streaming/FileStreamSinkSuite.scala | 7 +- .../sql/streaming/StateStoreMetricsTest.scala | 52 +++- .../spark/sql/streaming/StreamTest.scala | 56 ++++- ...cala => StreamingDeduplicationSuite.scala} | 94 +++---- .../sql/streaming/StreamingJoinSuite.scala | 18 +- .../sql/streaming/StreamingQuerySuite.scala | 2 +- 15 files changed, 450 insertions(+), 265 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala rename sql/core/src/test/scala/org/apache/spark/sql/streaming/{DeduplicateSuite.scala => StreamingDeduplicationSuite.scala} (80%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3942240c442b2..895e150756567 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -919,6 +919,14 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(10000L) + val STREAMING_NO_DATA_MICRO_BATCHES_ENABLED = + buildConf("spark.sql.streaming.noDataMicroBatchesEnabled") + .doc( + "Whether streaming micro-batch engine will execute batches without data " + + "for eager state management for stateful streaming queries.") + .booleanConf + .createWithDefault(true) + val STREAMING_METRICS_ENABLED = buildConf("spark.sql.streaming.metricsEnabled") .doc("Whether Dropwizard/Codahale metrics will be reported for active streaming queries.") @@ -1313,6 +1321,9 @@ class SQLConf extends Serializable with Logging { def streamingNoDataProgressEventInterval: Long = getConf(STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL) + def streamingNoDataMicroBatchesEnabled: Boolean = + getConf(STREAMING_NO_DATA_MICRO_BATCHES_ENABLED) + def streamingMetricsEnabled: Boolean = getConf(STREAMING_METRICS_ENABLED) def streamingProgressRetention: Int = getConf(STREAMING_PROGRESS_RETENTION) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 1a83c884d55bd..c480b96626f84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -143,4 +143,14 @@ class IncrementalExecution( /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } + + /** + * Should the MicroBatchExecution run another batch based on this execution and the current + * updated metadata. + */ + def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + executedPlan.collect { + case p: StateStoreWriter => p.shouldRunAnotherBatch(newMetadata) + }.exists(_ == true) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 6e231970f4a22..6709e7052f005 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -61,6 +61,8 @@ class MicroBatchExecution( case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } + private val watermarkTracker = new WatermarkTracker() + override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, "logicalPlan must be initialized in QueryExecutionThread " + @@ -128,40 +130,55 @@ class MicroBatchExecution( * Repeatedly attempts to run batches as data arrives. */ protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - triggerExecutor.execute(() => { - startTrigger() + val noDataBatchesEnabled = + sparkSessionForStream.sessionState.conf.streamingNoDataMicroBatchesEnabled + + triggerExecutor.execute(() => { if (isActive) { + var currentBatchIsRunnable = false // Whether the current batch is runnable / has been run + var currentBatchHasNewData = false // Whether the current batch had new data + + startTrigger() + reportTimeTaken("triggerExecution") { + // We'll do this initialization only once every start / restart if (currentBatchId < 0) { - // We'll do this initialization only once populateStartOffsets(sparkSessionForStream) - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - logDebug(s"Stream running from $committedOffsets to $availableOffsets") - } else { - constructNextBatch() + logInfo(s"Stream started from $committedOffsets") } - if (dataAvailable) { - currentStatus = currentStatus.copy(isDataAvailable = true) - updateStatusMessage("Processing new data") + + // Set this before calling constructNextBatch() so any Spark jobs executed by sources + // while getting new data have the correct description + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) + + // Try to construct the next batch. This will return true only if the next batch is + // ready and runnable. Note that the current batch may be runnable even without + // new data to process as `constructNextBatch` may decide to run a batch for + // state cleanup, etc. `isNewDataAvailable` will be updated to reflect whether new data + // is available or not. + currentBatchIsRunnable = constructNextBatch(noDataBatchesEnabled) + + // Remember whether the current batch has data or not. This will be required later + // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed + // to false as the batch would have already processed the available data. + currentBatchHasNewData = isNewDataAvailable + + currentStatus = currentStatus.copy(isDataAvailable = isNewDataAvailable) + if (currentBatchIsRunnable) { + if (currentBatchHasNewData) updateStatusMessage("Processing new data") + else updateStatusMessage("No new data but cleaning up state") runBatch(sparkSessionForStream) + } else { + updateStatusMessage("Waiting for data to arrive") } } - // Report trigger as finished and construct progress object. - finishTrigger(dataAvailable) - if (dataAvailable) { - // Update committed offsets. - commitLog.add(currentBatchId) - committedOffsets ++= availableOffsets - logDebug(s"batch ${currentBatchId} committed") - // We'll increase currentBatchId after we complete processing current batch's data - currentBatchId += 1 - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - } else { - currentStatus = currentStatus.copy(isDataAvailable = false) - updateStatusMessage("Waiting for data to arrive") - Thread.sleep(pollingDelayMs) - } + + finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded + + // If the current batch has been executed, then increment the batch id, else there was + // no data to execute the batch + if (currentBatchIsRunnable) currentBatchId += 1 else Thread.sleep(pollingDelayMs) } updateStatusMessage("Waiting for next trigger") isActive @@ -211,6 +228,7 @@ class MicroBatchExecution( OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) + watermarkTracker.setWatermark(metadata.batchWatermarkMs) } /* identify the current batch id: if commit log indicates we successfully processed the @@ -235,7 +253,6 @@ class MicroBatchExecution( currentBatchId = latestCommittedBatchId + 1 committedOffsets ++= availableOffsets // Construct a new batch be recomputing availableOffsets - constructNextBatch() } else if (latestCommittedBatchId < latestBatchId - 1) { logWarning(s"Batch completion log latest batch id is " + s"${latestCommittedBatchId}, which is not trailing " + @@ -243,19 +260,18 @@ class MicroBatchExecution( } case None => logInfo("no commit log present") } - logDebug(s"Resuming at batch $currentBatchId with committed offsets " + + logInfo(s"Resuming at batch $currentBatchId with committed offsets " + s"$committedOffsets and available offsets $availableOffsets") case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") currentBatchId = 0 - constructNextBatch() } } /** * Returns true if there is any new data available to be processed. */ - private def dataAvailable: Boolean = { + private def isNewDataAvailable: Boolean = { availableOffsets.exists { case (source, available) => committedOffsets @@ -266,93 +282,63 @@ class MicroBatchExecution( } /** - * Queries all of the sources to see if any new data is available. When there is new data the - * batchId counter is incremented and a new log entry is written with the newest offsets. + * Attempts to construct a batch according to: + * - Availability of new data + * - Need for timeouts and state cleanups in stateful operators + * + * Returns true only if the next batch should be executed. + * + * Here is the high-level logic on how this constructs the next batch. + * - Check each source whether new data is available + * - Updated the query's metadata and check using the last execution whether there is any need + * to run another batch (for state clean up, etc.) + * - If either of the above is true, then construct the next batch by committing to the offset + * log that range of offsets that the next batch will process. */ - private def constructNextBatch(): Unit = { - // Check to see what new data is available. - val hasNewData = { - awaitProgressLock.lock() - try { - // Generate a map from each unique source to the next available offset. - val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { - case s: Source => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("getOffset") { - (s, s.getOffset) - } - case s: MicroBatchReader => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("setOffsetRange") { - // Once v1 streaming source execution is gone, we can refactor this away. - // For now, we set the range here to get the source to infer the available end offset, - // get that offset, and then set the range again when we later execute. - s.setOffsetRange( - toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), - Optional.empty()) - } - - val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } - (s, Option(currentOffset)) - }.toMap - availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) - - if (dataAvailable) { - true - } else { - noNewData = true - false + private def constructNextBatch(noDataBatchesEnables: Boolean): Boolean = withProgressLocked { + // If new data is already available that means this method has already been called before + // and it must have already committed the offset range of next batch to the offset log. + // Hence do nothing, just return true. + if (isNewDataAvailable) return true + + // Generate a map from each unique source to the next available offset. + val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { + case s: Source => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("getOffset") { + (s, s.getOffset) } - } finally { - awaitProgressLock.unlock() - } - } - if (hasNewData) { - var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs - // Update the eventTime watermarks if we find any in the plan. - if (lastExecution != null) { - lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec => e - }.zipWithIndex.foreach { - case (e, index) if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") - val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs - val prevWatermarkMs = watermarkMsMap.get(index) - if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { - watermarkMsMap.put(index, newWatermarkMs) - } - - // Populate 0 if we haven't seen any data yet for this watermark node. - case (_, index) => - if (!watermarkMsMap.isDefinedAt(index)) { - watermarkMsMap.put(index, 0) - } + case s: MicroBatchReader => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("setOffsetRange") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. + s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) } - // Update the global watermark to the minimum of all watermark nodes. - // This is the safest option, because only the global watermark is fault-tolerant. Making - // it the minimum of all individual watermarks guarantees it will never advance past where - // any individual watermark operator would be if it were in a plan by itself. - if(!watermarkMsMap.isEmpty) { - val newWatermarkMs = watermarkMsMap.minBy(_._2)._2 - if (newWatermarkMs > batchWatermarkMs) { - logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - batchWatermarkMs = newWatermarkMs - } else { - logDebug( - s"Event time didn't move: $newWatermarkMs < " + - s"$batchWatermarkMs") - } - } - } - offsetSeqMetadata = offsetSeqMetadata.copy( - batchWatermarkMs = batchWatermarkMs, - batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds + val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } + (s, Option(currentOffset)) + }.toMap + availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) + + // Update the query metadata + offsetSeqMetadata = offsetSeqMetadata.copy( + batchWatermarkMs = watermarkTracker.currentWatermark, + batchTimestampMs = triggerClock.getTimeMillis()) + + // Check whether next batch should be constructed + val lastExecutionRequiresAnotherBatch = noDataBatchesEnables && + Option(lastExecution).exists(_.shouldRunAnotherBatch(offsetSeqMetadata)) + val shouldConstructNextBatch = isNewDataAvailable || lastExecutionRequiresAnotherBatch + if (shouldConstructNextBatch) { + // Commit the next batch offset range to the offset log updateStatusMessage("Writing offsets to log") reportTimeTaken("walCommit") { - assert(offsetLog.add( - currentBatchId, + assert(offsetLog.add(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)), s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") logInfo(s"Committed offsets for batch $currentBatchId. " + @@ -373,7 +359,7 @@ class MicroBatchExecution( reader.commit(reader.deserializeOffset(off.json)) } } else { - throw new IllegalStateException(s"batch $currentBatchId doesn't exist") + throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist") } } @@ -384,15 +370,12 @@ class MicroBatchExecution( commitLog.purge(currentBatchId - minLogEntriesToMaintain) } } + noNewData = false } else { - awaitProgressLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. - awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() - } + noNewData = true + awaitProgressLockCondition.signalAll() } + shouldConstructNextBatch } /** @@ -400,6 +383,8 @@ class MicroBatchExecution( * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with. */ private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = { + logDebug(s"Running batch $currentBatchId") + // Request unprocessed data from all sources. newData = reportTimeTaken("getBatch") { availableOffsets.flatMap { @@ -513,17 +498,17 @@ class MicroBatchExecution( } } - awaitProgressLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. + withProgressLocked { + commitLog.add(currentBatchId) + committedOffsets ++= availableOffsets awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() } + watermarkTracker.updateWatermark(lastExecution.executedPlan) + logDebug(s"Completed batch ${currentBatchId}") } /** Execute a function while locking the stream from making an progress */ - private[sql] def withProgressLocked(f: => Unit): Unit = { + private[sql] def withProgressLocked[T](f: => T): T = { awaitProgressLock.lock() try { f diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala new file mode 100644 index 0000000000000..80865669558dd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.SparkPlan + +class WatermarkTracker extends Logging { + private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() + private var watermarkMs: Long = 0 + private var updated = false + + def setWatermark(newWatermarkMs: Long): Unit = synchronized { + watermarkMs = newWatermarkMs + } + + def updateWatermark(executedPlan: SparkPlan): Unit = synchronized { + val watermarkOperators = executedPlan.collect { + case e: EventTimeWatermarkExec => e + } + if (watermarkOperators.isEmpty) return + + + watermarkOperators.zipWithIndex.foreach { + case (e, index) if e.eventTimeStats.value.count > 0 => + logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") + val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs + val prevWatermarkMs = operatorToWatermarkMap.get(index) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + operatorToWatermarkMap.put(index, newWatermarkMs) + } + + // Populate 0 if we haven't seen any data yet for this watermark node. + case (_, index) => + if (!operatorToWatermarkMap.isDefinedAt(index)) { + operatorToWatermarkMap.put(index, 0) + } + } + + // Update the global watermark to the minimum of all watermark nodes. + // This is the safest option, because only the global watermark is fault-tolerant. Making + // it the minimum of all individual watermarks guarantees it will never advance past where + // any individual watermark operator would be if it were in a plan by itself. + val newWatermarkMs = operatorToWatermarkMap.minBy(_._2)._2 + if (newWatermarkMs > watermarkMs) { + logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") + watermarkMs = newWatermarkMs + updated = true + } else { + logDebug(s"Event time didn't move: $newWatermarkMs < $watermarkMs") + updated = false + } + } + + def currentWatermark: Long = synchronized { watermarkMs } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 628923d367ce7..22258274c70c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -222,11 +222,20 @@ class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) } } +/** A common trait for MemorySinks with methods used for testing */ +trait MemorySinkBase extends BaseStreamingSink { + def allData: Seq[Row] + def latestBatchData: Seq[Row] + def dataSinceBatch(sinceBatchId: Long): Seq[Row] + def latestBatchId: Option[Long] +} + /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink + with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -236,7 +245,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { - batches.map(_.data).flatten + batches.flatMap(_.data) } def latestBatchId: Option[Long] = synchronized { @@ -245,6 +254,10 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) } + def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { + batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + } + def toDebugString: String = synchronized { batches.map { case AddedData(batchId, data) => val dataStr = try data.mkString(" ") catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index d871d37ad37c1..0d6c239274dd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} -import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { +class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging { override def createStreamWriter( queryId: String, schema: StructType, @@ -67,6 +67,10 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { batches.lastOption.toSeq.flatten(_.data) } + def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { + batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + } + def toDebugString: String = synchronized { batches.map { case AddedData(batchId, data) => val dataStr = try data.mkString(" ") catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index c9354ac0ec78a..1691a6320a526 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -126,6 +126,12 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => name -> SQLMetrics.createTimingMetric(sparkContext, desc) }.toMap } + + /** + * Should the MicroBatchExecution run another batch based on this stateful operator and the + * current updated metadata. + */ + def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = false } /** An operator that supports watermark. */ @@ -388,6 +394,12 @@ case class StateStoreSaveExec( ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + (outputMode.contains(Append) || outputMode.contains(Update)) && + eventTimeWatermark.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermark.get + } } /** Physical operator for executing streaming Deduplicate. */ @@ -454,6 +466,10 @@ case class StreamingDeduplicateExec( override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + } } object StreamingDeduplicateExec { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index 03bf71b3f4b78..e60c339bc9cc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -211,14 +211,12 @@ class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAnd try { inputData.addData(10, 11, 12) query.processAllAvailable() - inputData.addData(25) // Advance watermark to 15 seconds - query.processAllAvailable() inputData.addData(25) // Evict items less than previous watermark query.processAllAvailable() // There should be 3 batches and only does the last batch contain a value. val allEvents = ForeachWriterSuite.allEvents() - assert(allEvents.size === 3) + assert(allEvents.size === 4) val expectedEvents = Seq( Seq( ForeachWriterSuite.Open(partition = 0, version = 0), @@ -230,6 +228,10 @@ class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAnd ), Seq( ForeachWriterSuite.Open(partition = 0, version = 2), + ForeachWriterSuite.Close(None) + ), + Seq( + ForeachWriterSuite.Open(partition = 0, version = 3), ForeachWriterSuite.Process(value = 3), ForeachWriterSuite.Close(None) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index d6bef9ce07379..7e8fde1ff8e56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { @@ -137,20 +138,12 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche assert(e.get("watermark") === formatTimestamp(5)) }, AddData(inputData2, 25), - CheckAnswer(), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(25)) - assert(e.get("min") === formatTimestamp(25)) - assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(5)) - }, - AddData(inputData2, 25), CheckAnswer((10, 3)), assertEventStats { e => assert(e.get("max") === formatTimestamp(25)) assert(e.get("min") === formatTimestamp(25)) assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(15)) + assert(e.get("watermark") === formatTimestamp(5)) } ) } @@ -167,15 +160,12 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch(), + CheckNewAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(), - assertNumStateRows(3), - AddData(inputData, 25), // Emit items less than watermark and drop their state - CheckLastBatch((10, 5)), + CheckNewAnswer((10, 5)), assertNumStateRows(2), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(2) ) } @@ -193,15 +183,15 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation, OutputMode.Update)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch((10, 5), (15, 1)), + CheckNewAnswer((10, 5), (15, 1)), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch((25, 1)), - assertNumStateRows(3), + CheckNewAnswer((25, 1)), + assertNumStateRows(2), AddData(inputData, 10, 25), // Ignore 10 as its less than watermark - CheckLastBatch((25, 2)), + CheckNewAnswer((25, 2)), assertNumStateRows(2), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(2) ) } @@ -251,56 +241,25 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(df)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch(), + CheckAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - StopStream, - StartStream(), - CheckLastBatch(), - AddData(inputData, 25), // Evict items less than previous watermark. - CheckLastBatch((10, 5)), + CheckAnswer((10, 5)), StopStream, AssertOnQuery { q => // purge commit and clear the sink - val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + 1L + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) q.commitLog.purge(commit) q.sink.asInstanceOf[MemorySink].clear() true }, StartStream(), - CheckLastBatch((10, 5)), // Recompute last batch and re-evict timestamp 10 - AddData(inputData, 30), // Advance watermark to 20 seconds - CheckLastBatch(), + AddData(inputData, 10, 27, 30), // Advance watermark to 20 seconds, 10 should be ignored + CheckAnswer((15, 1)), StopStream, - StartStream(), // Watermark should still be 15 seconds - AddData(inputData, 17), - CheckLastBatch(), // We still do not see next batch - AddData(inputData, 30), // Advance watermark to 20 seconds - CheckLastBatch(), - AddData(inputData, 30), // Evict items less than previous watermark. - CheckLastBatch((15, 2)) // Ensure we see next window - ) - } - - test("dropping old data") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) - - testStream(windowedAggregation)( - AddData(inputData, 10, 11, 12), - CheckAnswer(), - AddData(inputData, 25), // Advance watermark to 15 seconds - CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. - CheckAnswer((10, 3)), - AddData(inputData, 10), // 10 is later than 15 second watermark - CheckAnswer((10, 3)), - AddData(inputData, 25), - CheckAnswer((10, 3)) // Should not emit an incorrect partial result. + StartStream(), + AddData(inputData, 17), // Watermark should still be 20 seconds, 17 should be ignored + CheckAnswer((15, 1)), + AddData(inputData, 40), // Advance watermark to 30 seconds, emit first data 25 + CheckNewAnswer((25, 2)) ) } @@ -421,8 +380,6 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche AddData(inputData, 10), CheckAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. CheckAnswer((10, 1)) ) } @@ -501,8 +458,35 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche } } + test("test no-data flag") { + val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key + + def testWithFlag(flag: Boolean): Unit = withClue(s"with $flagKey = $flag") { + val inputData = MemoryStream[Int] + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + StartStream(additionalConfs = Map(flagKey -> flag.toString)), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckNewAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + // Check if there is new answer if flag is set, no new answer otherwise + if (flag) CheckNewAnswer((10, 5)) else CheckNewAnswer() + ) + } + + testWithFlag(true) + testWithFlag(false) + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get + q.processAllAvailable() + val progressWithData = q.recentProgress.lastOption.get assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows) true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index cf41d7e0e4fe1..ed53def556cb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -279,13 +279,10 @@ class FileStreamSinkSuite extends StreamTest { check() // nothing emitted yet addTimestamp(104, 123) // watermark = 90 before this, watermark = 123 - 10 = 113 after this - check() // nothing emitted yet + check((100L, 105L) -> 2L) // no-data-batch emits results on 100-105, addTimestamp(140) // wm = 113 before this, emit results on 100-105, wm = 130 after this - check((100L, 105L) -> 2L) - - addTimestamp(150) // wm = 130s before this, emit results on 120-125, wm = 150 after this - check((100L, 105L) -> 2L, (120L, 125L) -> 1L) + check((100L, 105L) -> 2L, (120L, 125L) -> 1L) // no-data-batch emits results on 120-125 } finally { if (query != null) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala index 368c4604dfca8..e45f9d3e2e97b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -17,20 +17,58 @@ package org.apache.spark.sql.streaming +import org.apache.spark.sql.execution.streaming.StreamExecution + trait StateStoreMetricsTest extends StreamTest { + private var lastCheckedRecentProgressIndex = -1 + private var lastQuery: StreamExecution = null + + override def beforeEach(): Unit = { + super.beforeEach() + lastCheckedRecentProgressIndex = -1 + } + def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery = AssertOnQuery(s"Check total state rows = $total, updated state rows = $updated") { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get - assert( - progressWithData.stateOperators.map(_.numRowsTotal) === total, - "incorrect total rows") - assert( - progressWithData.stateOperators.map(_.numRowsUpdated) === updated, - "incorrect updates rows") + val recentProgress = q.recentProgress + require(recentProgress.nonEmpty, "No progress made, cannot check num state rows") + require(recentProgress.length < spark.sessionState.conf.streamingProgressRetention, + "This test assumes that all progresses are present in q.recentProgress but " + + "some may have been dropped due to retention limits") + + if (q.ne(lastQuery)) lastCheckedRecentProgressIndex = -1 + lastQuery = q + + val numStateOperators = recentProgress.last.stateOperators.length + val progressesSinceLastCheck = recentProgress + .slice(lastCheckedRecentProgressIndex + 1, recentProgress.length) + .filter(_.stateOperators.length == numStateOperators) + + val allNumUpdatedRowsSinceLastCheck = + progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsUpdated)) + + lazy val debugString = "recent progresses:\n" + + progressesSinceLastCheck.map(_.prettyJson).mkString("\n\n") + + val numTotalRows = recentProgress.last.stateOperators.map(_.numRowsTotal) + assert(numTotalRows === total, s"incorrect total rows, $debugString") + + val numUpdatedRows = arraySum(allNumUpdatedRowsSinceLastCheck, numStateOperators) + assert(numUpdatedRows === updated, s"incorrect updates rows, $debugString") + + lastCheckedRecentProgressIndex = recentProgress.length - 1 true } def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = assertNumStateRows(Seq(total), Seq(updated)) + + def arraySum(arraySeq: Seq[Array[Long]], arrayLength: Int): Seq[Long] = { + if (arraySeq.isEmpty) return Seq.fill(arrayLength)(0L) + + assert(arraySeq.forall(_.length == arrayLength), + "Arrays are of different lengths:\n" + arraySeq.map(_.toSeq).mkString("\n")) + (0 until arrayLength).map { index => arraySeq.map(_.apply(index)).sum } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index af0268fa47871..9d139a927bea5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -192,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class CheckAnswerRowsContains(expectedAnswer: Seq[Row], lastOnly: Boolean = false) extends StreamAction with StreamMustBeRunning { override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" - private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" + private def operatorName = if (lastOnly) "CheckLastBatchContains" else "CheckAnswerContains" } case class CheckAnswerRowsByFunc( @@ -202,6 +203,23 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" } + case class CheckNewAnswerRows(expectedAnswer: Seq[Row]) + extends StreamAction with StreamMustBeRunning { + override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" + + private def operatorName = "CheckNewAnswer" + } + + object CheckNewAnswer { + def apply(): CheckNewAnswerRows = CheckNewAnswerRows(Seq.empty) + + def apply[A: Encoder](data: A, moreData: A*): CheckNewAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() + CheckNewAnswerRows((data +: moreData).map(d => toExternalRow.fromRow(encoder.toRow(d)))) + } + } + /** Stops the stream. It must currently be running. */ case object StopStream extends StreamAction with StreamMustBeRunning @@ -435,13 +453,24 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be """.stripMargin) } - def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { + var lastFetchedMemorySinkLastBatchId: Long = -1 + + def fetchStreamAnswer( + currentStream: StreamExecution, + lastOnly: Boolean = false, + sinceLastFetchOnly: Boolean = false) = { + verify( + !(lastOnly && sinceLastFetchOnly), "both lastOnly and sinceLastFetchOnly cannot be true") verify(currentStream != null, "stream not running") // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { currentStream.awaitOffset(sourceIndex, offset) + // Make sure all processing including no-data-batches have been executed + if (!currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + currentStream.processAllAvailable() + } } } @@ -463,14 +492,21 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } - val (latestBatchData, allData) = sink match { - case s: MemorySink => (s.latestBatchData, s.allData) - case s: MemorySinkV2 => (s.latestBatchData, s.allData) - } - try if (lastOnly) latestBatchData else allData catch { + val rows = try { + if (sinceLastFetchOnly) { + if (sink.latestBatchId.getOrElse(-1L) < lastFetchedMemorySinkLastBatchId) { + failTest("MemorySink was probably cleared since last fetch. Use CheckAnswer instead.") + } + sink.dataSinceBatch(lastFetchedMemorySinkLastBatchId) + } else { + if (lastOnly) sink.latestBatchData else sink.allData + } + } catch { case e: Exception => failTest("Exception while getting data from sink", e) } + lastFetchedMemorySinkLastBatchId = sink.latestBatchId.getOrElse(-1L) + rows } def executeAction(action: StreamAction): Unit = { @@ -704,6 +740,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } catch { case e: Throwable => failTest(e.toString) } + + case CheckNewAnswerRows(expectedAnswer) => + val sparkAnswer = fetchStreamAnswer(currentStream, sinceLastFetchOnly = true) + QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } } pos += 1 } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala similarity index 80% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index 0088b64d6195e..42ffd472eb843 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec} import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf -class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class StreamingDeduplicationSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -97,28 +98,20 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { testStream(result, Append)( AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*), - CheckLastBatch(10 to 15: _*), + CheckAnswer(10 to 15: _*), assertNumStateRows(total = 6, updated = 6), - AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(25), - assertNumStateRows(total = 7, updated = 1), - - AddData(inputData, 25), // Drop states less than watermark - CheckLastBatch(), - assertNumStateRows(total = 1, updated = 0), + AddData(inputData, 25), // Advance watermark to 15 secs, no-data-batch drops rows <= 15 + CheckNewAnswer(25), + assertNumStateRows(total = 1, updated = 1), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(total = 1, updated = 0), - AddData(inputData, 45), // Advance watermark to 35 seconds - CheckLastBatch(45), - assertNumStateRows(total = 2, updated = 1), - - AddData(inputData, 45), // Drop states less than watermark - CheckLastBatch(), - assertNumStateRows(total = 1, updated = 0) + AddData(inputData, 45), // Advance watermark to 35 seconds, no-data-batch drops row 25 + CheckNewAnswer(45), + assertNumStateRows(total = 1, updated = 1) ) } @@ -141,33 +134,20 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { assertNumStateRows(total = Seq(2L, 6L), updated = Seq(2L, 6L)), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(), - // states in aggregate in [10, 14), [15, 20) and [25, 30) (3 windows) - // states in deduplicate is 10 to 15 and 25 - assertNumStateRows(total = Seq(3L, 7L), updated = Seq(1L, 1L)), - - AddData(inputData, 25), // Emit items less than watermark and drop their state - CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate - // states in aggregate in [15, 20) and [25, 30) (2 windows, note aggregate uses the end of - // window to evict items, so [15, 20) is still in the state store) - // states in deduplicate is 25 - assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), + CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate, emitted with no-data-batch + // states in aggregate in [15, 20) and [25, 30); no-data-batch removed [10, 14) + // states in deduplicate is 25, no-data-batch removed 10 to 14 + assertNumStateRows(total = Seq(2L, 1L), updated = Seq(1L, 1L)), AddData(inputData, 10), // Should not emit anything as data less than watermark CheckLastBatch(), assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), AddData(inputData, 40), // Advance watermark to 30 seconds - CheckLastBatch(), - // states in aggregate in [15, 20), [25, 30) and [40, 45) - // states in deduplicate is 25 and 40, - assertNumStateRows(total = Seq(3L, 2L), updated = Seq(1L, 1L)), - - AddData(inputData, 40), // Emit items less than watermark and drop their state CheckLastBatch((15 -> 1), (25 -> 1)), - // states in aggregate in [40, 45) - // states in deduplicate is 40, - assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)) + // states in aggregate is [40, 45); no-data-batch removed [15, 20) and [25, 30) + // states in deduplicate is 40; no-data-batch removed 25 + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)) ) } @@ -260,13 +240,13 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { .select($"id") testStream(df)( AddData(input, 1 -> 1, 1 -> 1, 1 -> 2), - CheckLastBatch(1, 2), + CheckAnswer(1, 2), AddData(input, 1 -> 1, 2 -> 3, 2 -> 4), - CheckLastBatch(3, 4), + CheckNewAnswer(3, 4), AddData(input, 1 -> 0, 1 -> 1, 3 -> 5, 3 -> 6), // Drop (1 -> 0, 1 -> 1) due to watermark - CheckLastBatch(5, 6), + CheckNewAnswer(5, 6), AddData(input, 1 -> 0, 4 -> 7), // Drop (1 -> 0) due to watermark - CheckLastBatch(7) + CheckNewAnswer(7) ) } @@ -279,7 +259,37 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { .select($"id", $"time".cast("long")) testStream(df)( AddData(input, 1 -> 1, 1 -> 2, 2 -> 2), - CheckLastBatch(1 -> 1, 2 -> 2) + CheckAnswer(1 -> 1, 2 -> 2) ) } + + test("test no-data flag") { + val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key + + def testWithFlag(flag: Boolean): Unit = withClue(s"with $flagKey = $flag") { + val inputData = MemoryStream[Int] + val result = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates() + .select($"eventTime".cast("long").as[Long]) + + testStream(result, Append)( + StartStream(additionalConfs = Map(flagKey -> flag.toString)), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckAnswer(10, 11, 12, 13, 14, 15), + assertNumStateRows(total = 6, updated = 6), + + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckNewAnswer(25), + { // State should have been cleaned if flag is set, otherwise should not have been cleaned + if (flag) assertNumStateRows(total = 1, updated = 1) + else assertNumStateRows(total = 7, updated = 1) + } + ) + } + + testWithFlag(true) + testWithFlag(false) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 11bdd13942dcb..da8f9608c1e9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -192,7 +192,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch((1, 5, 11)), AddData(rightInput, (1, 10)), CheckLastBatch(), // no match as neither 5, nor 10 from leftTime is less than rightTime 10 - 5 - assertNumStateRows(total = 3, updated = 1), + assertNumStateRows(total = 3, updated = 3), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 3), (1, 30)), @@ -276,14 +276,14 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with CheckAnswer(), AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), - assertNumStateRows(total = 7, updated = 6), + assertNumStateRows(total = 7, updated = 7), // If rightTime = 60, then it matches only leftTime = [50, 65] AddData(rightInput, (1, 60)), CheckLastBatch(), // matches with nothing on the left AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), CheckLastBatch((1, 50, 60), (1, 65, 60)), - assertNumStateRows(total = 12, updated = 4), + assertNumStateRows(total = 12, updated = 5), // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) @@ -573,7 +573,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with // nulls won't show up until the next batch after the watermark advances. MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), - assertNumStateRows(total = 12, updated = 2), + assertNumStateRows(total = 12, updated = 12), AddData(leftInput, 22), CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 3, updated = 1) @@ -591,7 +591,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with // nulls won't show up until the next batch after the watermark advances. MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), - assertNumStateRows(total = 12, updated = 2), + assertNumStateRows(total = 12, updated = 12), AddData(leftInput, 22), CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), assertNumStateRows(total = 3, updated = 1) @@ -630,7 +630,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch((1, 1, 5, 10)), AddData(rightInput, (1, 11)), CheckLastBatch(), // no match as left time is too low - assertNumStateRows(total = 5, updated = 1), + assertNumStateRows(total = 5, updated = 5), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 7), (1, 30)), @@ -668,7 +668,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), - assertNumStateRows(total = 5, updated = 2), + assertNumStateRows(total = 5, updated = 5), // 1...3 added, but 20 and 21 not added AddData(rightInput, 20), CheckLastBatch( Row(20, 30, 40, 60)), @@ -678,7 +678,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), MultiAddData(leftInput, 70)(rightInput, 71), CheckLastBatch(), - assertNumStateRows(total = 6, updated = 2), + assertNumStateRows(total = 6, updated = 6), // all inputs added since last check AddData(rightInput, 70), CheckLastBatch((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), @@ -687,7 +687,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch(), MultiAddData(leftInput, 1000)(rightInput, 1001), CheckLastBatch(), - assertNumStateRows(total = 8, updated = 2), + assertNumStateRows(total = 8, updated = 5), // 101...103 added, but 1000 and 1001 not added AddData(rightInput, 1000), CheckLastBatch( Row(1000, 1010, 2000, 3000), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 390d67d1feb27..0cb2375e0a49a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -334,7 +334,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") - assert(progress.sources(0).startOffset === null) + assert(progress.sources(0).startOffset === "0") assert(progress.sources(0).endOffset !== null) assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms From dd4b1b9c7ccad3363a6a21524aed047fcd282f68 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 6 May 2018 10:25:01 +0800 Subject: [PATCH 739/774] [SPARK-24185][SPARKR][SQL] add flatten function to SparkR ## What changes were proposed in this pull request? add array flatten function to SparkR ## How was this patch tested? Unit tests were added in R/pkg/tests/fulltests/test_sparkSQL.R Author: Huaxin Gao Closes #21244 from huaxingao/spark-24185. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 14 ++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 6 ++++++ 4 files changed, 25 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index f36d462a83cb0..8cd00352d1956 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -258,6 +258,7 @@ exportMethods("%<=>%", "expr", "factorial", "first", + "flatten", "floor", "format_number", "format_string", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index ec4bd4e73c7e5..0ec99d19e21e4 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -208,6 +208,7 @@ NULL #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21))) +#' head(select(tmp, flatten(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -3035,6 +3036,19 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{flatten}: Transforms an array of arrays into a single array. +#' +#' @rdname column_collection_functions +#' @aliases flatten flatten,Column-method +#' @note flatten since 2.4.0 +setMethod("flatten", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "flatten", x@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 562d3399ee9c8..4ef12d19b3575 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -918,6 +918,10 @@ setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("flatten", function(x) { standardGeneric("flatten") }) + #' @rdname column_datetime_diff_functions #' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 8cc2db7a140f9..3a8866bf2a88a 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1502,6 +1502,12 @@ test_that("column functions", { result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + # Test flattern + df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), + list(list(list(5L, 6L), list(7L, 8L))))) + result <- collect(select(df, flatten(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L))) + # Test map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) result <- collect(select(df, map_keys(df$map)))[[1]] From f38ea00e83099a5ae8d3afdec2e896e43c2db612 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 6 May 2018 20:41:32 -0700 Subject: [PATCH 740/774] [SPARK-24017][SQL] Refactor ExternalCatalog to be an interface ## What changes were proposed in this pull request? This refactors the external catalog to be an interface. It can be easier for the future work in the catalog federation. After the refactoring, `ExternalCatalog` is much cleaner without mixing the listener event generation logic. ## How was this patch tested? The existing tests Author: gatorsmile Closes #21122 from gatorsmile/refactorExternalCatalog. --- .../catalyst/catalog/ExternalCatalog.scala | 134 ++------ .../catalog/ExternalCatalogWithListener.scala | 298 ++++++++++++++++++ .../catalyst/catalog/InMemoryCatalog.scala | 26 +- .../catalog/ExternalCatalogEventSuite.scala | 2 +- .../spark/sql/internal/SharedState.scala | 9 +- .../sql/hive/thriftserver/SparkSQLEnv.scala | 2 +- .../spark/sql/hive/HiveExternalCatalog.scala | 26 +- .../spark/sql/hive/HiveSessionCatalog.scala | 4 +- .../sql/hive/HiveSessionStateBuilder.scala | 7 +- .../sql/hive/execution/SaveAsHiveFile.scala | 2 +- .../apache/spark/sql/hive/test/TestHive.scala | 11 +- .../HiveExternalSessionCatalogSuite.scala | 2 +- .../sql/hive/HiveSchemaInferenceSuite.scala | 3 +- .../sql/hive/HiveSessionStateSuite.scala | 3 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 5 +- .../spark/sql/hive/ShowCreateTableSuite.scala | 3 +- .../spark/sql/hive/client/VersionsSuite.scala | 4 +- .../sql/hive/execution/HiveDDLSuite.scala | 6 +- .../sql/hive/execution/SQLQuerySuite.scala | 3 +- .../sql/hive/test/TestHiveSingleton.scala | 2 +- 20 files changed, 384 insertions(+), 168 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 45b4f013620c1..1a145c24d78cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.catalog -import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException} +import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchPartitionException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ListenerBus /** * Interface for the system catalog (of functions, partitions, tables, and databases). @@ -31,10 +30,13 @@ import org.apache.spark.util.ListenerBus * * Implementations should throw [[NoSuchDatabaseException]] when databases don't exist. */ -abstract class ExternalCatalog - extends ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { +trait ExternalCatalog { import CatalogTypes.TablePartitionSpec + // -------------------------------------------------------------------------- + // Utils + // -------------------------------------------------------------------------- + protected def requireDbExists(db: String): Unit = { if (!databaseExists(db)) { throw new NoSuchDatabaseException(db) @@ -63,22 +65,9 @@ abstract class ExternalCatalog // Databases // -------------------------------------------------------------------------- - final def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { - val db = dbDefinition.name - postToAll(CreateDatabasePreEvent(db)) - doCreateDatabase(dbDefinition, ignoreIfExists) - postToAll(CreateDatabaseEvent(db)) - } + def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - protected def doCreateDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - - final def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { - postToAll(DropDatabasePreEvent(db)) - doDropDatabase(db, ignoreIfNotExists, cascade) - postToAll(DropDatabaseEvent(db)) - } - - protected def doDropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit /** * Alter a database whose name matches the one specified in `dbDefinition`, @@ -87,14 +76,7 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - final def alterDatabase(dbDefinition: CatalogDatabase): Unit = { - val db = dbDefinition.name - postToAll(AlterDatabasePreEvent(db)) - doAlterDatabase(dbDefinition) - postToAll(AlterDatabaseEvent(db)) - } - - protected def doAlterDatabase(dbDefinition: CatalogDatabase): Unit + def alterDatabase(dbDefinition: CatalogDatabase): Unit def getDatabase(db: String): CatalogDatabase @@ -110,41 +92,15 @@ abstract class ExternalCatalog // Tables // -------------------------------------------------------------------------- - final def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { - val db = tableDefinition.database - val name = tableDefinition.identifier.table - val tableDefinitionWithVersion = - tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION) - postToAll(CreateTablePreEvent(db, name)) - doCreateTable(tableDefinitionWithVersion, ignoreIfExists) - postToAll(CreateTableEvent(db, name)) - } - - protected def doCreateTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - - final def dropTable( - db: String, - table: String, - ignoreIfNotExists: Boolean, - purge: Boolean): Unit = { - postToAll(DropTablePreEvent(db, table)) - doDropTable(db, table, ignoreIfNotExists, purge) - postToAll(DropTableEvent(db, table)) - } + def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - protected def doDropTable( + def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit - final def renameTable(db: String, oldName: String, newName: String): Unit = { - postToAll(RenameTablePreEvent(db, oldName, newName)) - doRenameTable(db, oldName, newName) - postToAll(RenameTableEvent(db, oldName, newName)) - } - - protected def doRenameTable(db: String, oldName: String, newName: String): Unit + def renameTable(db: String, oldName: String, newName: String): Unit /** * Alter a table whose database and name match the ones specified in `tableDefinition`, assuming @@ -154,15 +110,7 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - final def alterTable(tableDefinition: CatalogTable): Unit = { - val db = tableDefinition.database - val name = tableDefinition.identifier.table - postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) - doAlterTable(tableDefinition) - postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) - } - - protected def doAlterTable(tableDefinition: CatalogTable): Unit + def alterTable(tableDefinition: CatalogTable): Unit /** * Alter the data schema of a table identified by the provided database and table name. The new @@ -173,22 +121,10 @@ abstract class ExternalCatalog * @param table Name of table to alter schema for * @param newDataSchema Updated data schema to be used for the table. */ - final def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { - postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) - doAlterTableDataSchema(db, table, newDataSchema) - postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) - } - - protected def doAlterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit + def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ - final def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit = { - postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) - doAlterTableStats(db, table, stats) - postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) - } - - protected def doAlterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit + def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit def getTable(db: String, table: String): CatalogTable @@ -340,49 +276,17 @@ abstract class ExternalCatalog // Functions // -------------------------------------------------------------------------- - final def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { - val name = funcDefinition.identifier.funcName - postToAll(CreateFunctionPreEvent(db, name)) - doCreateFunction(db, funcDefinition) - postToAll(CreateFunctionEvent(db, name)) - } + def createFunction(db: String, funcDefinition: CatalogFunction): Unit - protected def doCreateFunction(db: String, funcDefinition: CatalogFunction): Unit + def dropFunction(db: String, funcName: String): Unit - final def dropFunction(db: String, funcName: String): Unit = { - postToAll(DropFunctionPreEvent(db, funcName)) - doDropFunction(db, funcName) - postToAll(DropFunctionEvent(db, funcName)) - } + def alterFunction(db: String, funcDefinition: CatalogFunction): Unit - protected def doDropFunction(db: String, funcName: String): Unit - - final def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { - val name = funcDefinition.identifier.funcName - postToAll(AlterFunctionPreEvent(db, name)) - doAlterFunction(db, funcDefinition) - postToAll(AlterFunctionEvent(db, name)) - } - - protected def doAlterFunction(db: String, funcDefinition: CatalogFunction): Unit - - final def renameFunction(db: String, oldName: String, newName: String): Unit = { - postToAll(RenameFunctionPreEvent(db, oldName, newName)) - doRenameFunction(db, oldName, newName) - postToAll(RenameFunctionEvent(db, oldName, newName)) - } - - protected def doRenameFunction(db: String, oldName: String, newName: String): Unit + def renameFunction(db: String, oldName: String, newName: String): Unit def getFunction(db: String, funcName: String): CatalogFunction def functionExists(db: String, funcName: String): Boolean def listFunctions(db: String, pattern: String): Seq[String] - - override protected def doPostEvent( - listener: ExternalCatalogEventListener, - event: ExternalCatalogEvent): Unit = { - listener.onEvent(event) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala new file mode 100644 index 0000000000000..2f009be5816fa --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ListenerBus + +/** + * Wraps an ExternalCatalog to provide listener events. + */ +class ExternalCatalogWithListener(delegate: ExternalCatalog) + extends ExternalCatalog + with ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { + import CatalogTypes.TablePartitionSpec + + def unwrapped: ExternalCatalog = delegate + + override protected def doPostEvent( + listener: ExternalCatalogEventListener, + event: ExternalCatalogEvent): Unit = { + listener.onEvent(event) + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { + val db = dbDefinition.name + postToAll(CreateDatabasePreEvent(db)) + delegate.createDatabase(dbDefinition, ignoreIfExists) + postToAll(CreateDatabaseEvent(db)) + } + + override def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { + postToAll(DropDatabasePreEvent(db)) + delegate.dropDatabase(db, ignoreIfNotExists, cascade) + postToAll(DropDatabaseEvent(db)) + } + + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = { + val db = dbDefinition.name + postToAll(AlterDatabasePreEvent(db)) + delegate.alterDatabase(dbDefinition) + postToAll(AlterDatabaseEvent(db)) + } + + override def getDatabase(db: String): CatalogDatabase = { + delegate.getDatabase(db) + } + + override def databaseExists(db: String): Boolean = { + delegate.databaseExists(db) + } + + override def listDatabases(): Seq[String] = { + delegate.listDatabases() + } + + override def listDatabases(pattern: String): Seq[String] = { + delegate.listDatabases(pattern) + } + + override def setCurrentDatabase(db: String): Unit = { + delegate.setCurrentDatabase(db) + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + override def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + val tableDefinitionWithVersion = + tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION) + postToAll(CreateTablePreEvent(db, name)) + delegate.createTable(tableDefinitionWithVersion, ignoreIfExists) + postToAll(CreateTableEvent(db, name)) + } + + override def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = { + postToAll(DropTablePreEvent(db, table)) + delegate.dropTable(db, table, ignoreIfNotExists, purge) + postToAll(DropTableEvent(db, table)) + } + + override def renameTable(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameTablePreEvent(db, oldName, newName)) + delegate.renameTable(db, oldName, newName) + postToAll(RenameTableEvent(db, oldName, newName)) + } + + override def alterTable(tableDefinition: CatalogTable): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) + delegate.alterTable(tableDefinition) + postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) + } + + override def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) + delegate.alterTableDataSchema(db, table, newDataSchema) + postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) + } + + override def alterTableStats( + db: String, + table: String, + stats: Option[CatalogStatistics]): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) + delegate.alterTableStats(db, table, stats) + postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) + } + + override def getTable(db: String, table: String): CatalogTable = { + delegate.getTable(db, table) + } + + override def tableExists(db: String, table: String): Boolean = { + delegate.tableExists(db, table) + } + + override def listTables(db: String): Seq[String] = { + delegate.listTables(db) + } + + override def listTables(db: String, pattern: String): Seq[String] = { + delegate.listTables(db, pattern) + } + + override def loadTable( + db: String, + table: String, + loadPath: String, + isOverwrite: Boolean, + isSrcLocal: Boolean): Unit = { + delegate.loadTable(db, table, loadPath, isOverwrite, isSrcLocal) + } + + override def loadPartition( + db: String, + table: String, + loadPath: String, + partition: TablePartitionSpec, + isOverwrite: Boolean, + inheritTableSpecs: Boolean, + isSrcLocal: Boolean): Unit = { + delegate.loadPartition( + db, table, loadPath, partition, isOverwrite, inheritTableSpecs, isSrcLocal) + } + + override def loadDynamicPartitions( + db: String, + table: String, + loadPath: String, + partition: TablePartitionSpec, + replace: Boolean, + numDP: Int): Unit = { + delegate.loadDynamicPartitions(db, table, loadPath, partition, replace, numDP) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + override def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = { + delegate.createPartitions(db, table, parts, ignoreIfExists) + } + + override def dropPartitions( + db: String, + table: String, + partSpecs: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean, + purge: Boolean, + retainData: Boolean): Unit = { + delegate.dropPartitions(db, table, partSpecs, ignoreIfNotExists, purge, retainData) + } + + override def renamePartitions( + db: String, + table: String, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = { + delegate.renamePartitions(db, table, specs, newSpecs) + } + + override def alterPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition]): Unit = { + delegate.alterPartitions(db, table, parts) + } + + override def getPartition( + db: String, + table: String, + spec: TablePartitionSpec): CatalogTablePartition = { + delegate.getPartition(db, table, spec) + } + + override def getPartitionOption( + db: String, + table: String, + spec: TablePartitionSpec): Option[CatalogTablePartition] = { + delegate.getPartitionOption(db, table, spec) + } + + override def listPartitionNames( + db: String, + table: String, + partialSpec: Option[TablePartitionSpec] = None): Seq[String] = { + delegate.listPartitionNames(db, table, partialSpec) + } + + override def listPartitions( + db: String, + table: String, + partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = { + delegate.listPartitions(db, table, partialSpec) + } + + override def listPartitionsByFilter( + db: String, + table: String, + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = { + delegate.listPartitionsByFilter(db, table, predicates, defaultTimeZoneId) + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + override def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(CreateFunctionPreEvent(db, name)) + delegate.createFunction(db, funcDefinition) + postToAll(CreateFunctionEvent(db, name)) + } + + override def dropFunction(db: String, funcName: String): Unit = { + postToAll(DropFunctionPreEvent(db, funcName)) + delegate.dropFunction(db, funcName) + postToAll(DropFunctionEvent(db, funcName)) + } + + override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(AlterFunctionPreEvent(db, name)) + delegate.alterFunction(db, funcDefinition) + postToAll(AlterFunctionEvent(db, name)) + } + + override def renameFunction(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameFunctionPreEvent(db, oldName, newName)) + delegate.renameFunction(db, oldName, newName) + postToAll(RenameFunctionEvent(db, oldName, newName)) + } + + override def getFunction(db: String, funcName: String): CatalogFunction = { + delegate.getFunction(db, funcName) + } + + override def functionExists(db: String, funcName: String): Boolean = { + delegate.functionExists(db, funcName) + } + + override def listFunctions(db: String, pattern: String): Seq[String] = { + delegate.listFunctions(db, pattern) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 8eacfa058bd52..741dc46b07382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -98,7 +98,7 @@ class InMemoryCatalog( // Databases // -------------------------------------------------------------------------- - override protected def doCreateDatabase( + override def createDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { @@ -119,7 +119,7 @@ class InMemoryCatalog( } } - override protected def doDropDatabase( + override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = synchronized { @@ -152,7 +152,7 @@ class InMemoryCatalog( } } - override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { requireDbExists(dbDefinition.name) catalog(dbDefinition.name).db = dbDefinition } @@ -180,7 +180,7 @@ class InMemoryCatalog( // Tables // -------------------------------------------------------------------------- - override protected def doCreateTable( + override def createTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) @@ -221,7 +221,7 @@ class InMemoryCatalog( } } - override protected def doDropTable( + override def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -264,7 +264,7 @@ class InMemoryCatalog( } } - override protected def doRenameTable( + override def renameTable( db: String, oldName: String, newName: String): Unit = synchronized { @@ -294,7 +294,7 @@ class InMemoryCatalog( catalog(db).tables.remove(oldName) } - override def doAlterTable(tableDefinition: CatalogTable): Unit = synchronized { + override def alterTable(tableDefinition: CatalogTable): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -303,7 +303,7 @@ class InMemoryCatalog( catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition } - override def doAlterTableDataSchema( + override def alterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = synchronized { @@ -313,7 +313,7 @@ class InMemoryCatalog( catalog(db).tables(table).table = origTable.copy(schema = newSchema) } - override def doAlterTableStats( + override def alterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = synchronized { @@ -564,24 +564,24 @@ class InMemoryCatalog( // Functions // -------------------------------------------------------------------------- - override protected def doCreateFunction(db: String, func: CatalogFunction): Unit = synchronized { + override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionNotExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override protected def doDropFunction(db: String, funcName: String): Unit = synchronized { + override def dropFunction(db: String, funcName: String): Unit = synchronized { requireFunctionExists(db, funcName) catalog(db).functions.remove(funcName) } - override protected def doAlterFunction(db: String, func: CatalogFunction): Unit = synchronized { + override def alterFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override protected def doRenameFunction( + override def renameFunction( db: String, oldName: String, newName: String): Unit = synchronized { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala index 1acbe34d9a075..2fcaeca34db3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -36,7 +36,7 @@ class ExternalCatalogEventSuite extends SparkFunSuite { private def testWithCatalog( name: String)( f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) { - val catalog = newCatalog + val catalog = new ExternalCatalogWithListener(newCatalog) val recorder = mutable.Buffer.empty[ExternalCatalogEvent] catalog.addListener(new ExternalCatalogEventListener { override def onEvent(event: ExternalCatalogEvent): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index baea4ceebf8e3..5b6160e2b408f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -99,7 +99,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A catalog that interacts with external systems. */ - lazy val externalCatalog: ExternalCatalog = { + lazy val externalCatalog: ExternalCatalogWithListener = { val externalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( SharedState.externalCatalogClassName(sparkContext.conf), sparkContext.conf, @@ -117,14 +117,17 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true) } + // Wrap to provide catalog events + val wrapped = new ExternalCatalogWithListener(externalCatalog) + // Make sure we propagate external catalog events to the spark listener bus - externalCatalog.addListener(new ExternalCatalogEventListener { + wrapped.addListener(new ExternalCatalogEventListener { override def onEvent(event: ExternalCatalogEvent): Unit = { sparkContext.listenerBus.post(event) } }) - externalCatalog + wrapped } /** diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index cbd75ad12d430..8980bcf885589 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -50,7 +50,7 @@ private[hive] object SparkSQLEnv extends Logging { sqlContext = sparkSession.sqlContext val metadataHive = sparkSession - .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + .sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 28c340a176d91..011a3ba553cb2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -158,13 +158,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Databases // -------------------------------------------------------------------------- - override protected def doCreateDatabase( + override def createDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = withClient { client.createDatabase(dbDefinition, ignoreIfExists) } - override protected def doDropDatabase( + override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = withClient { @@ -177,7 +177,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * * Note: As of now, this only supports altering database properties! */ - override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { val existingDb = getDatabase(dbDefinition.name) if (existingDb.properties == dbDefinition.properties) { logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " + @@ -211,7 +211,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Tables // -------------------------------------------------------------------------- - override protected def doCreateTable( + override def createTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) @@ -480,7 +480,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override protected def doDropTable( + override def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -489,7 +489,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropTable(db, table, ignoreIfNotExists, purge) } - override protected def doRenameTable( + override def renameTable( db: String, oldName: String, newName: String): Unit = withClient { @@ -540,7 +540,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * Note: As of now, this doesn't support altering table schema, partition column names and bucket * specification. We will ignore them even if users do specify different values for these fields. */ - override def doAlterTable(tableDefinition: CatalogTable): Unit = withClient { + override def alterTable(tableDefinition: CatalogTable): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -624,7 +624,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * data schema should not have conflict column names with the existing partition columns, and * should still contain all the existing data columns. */ - override def doAlterTableDataSchema( + override def alterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = withClient { @@ -656,7 +656,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ - override def doAlterTableStats( + override def alterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = withClient { @@ -1208,7 +1208,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Functions // -------------------------------------------------------------------------- - override protected def doCreateFunction( + override def createFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) @@ -1221,12 +1221,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override protected def doDropFunction(db: String, name: String): Unit = withClient { + override def dropFunction(db: String, name: String): Unit = withClient { requireFunctionExists(db, name) client.dropFunction(db, name) } - override protected def doAlterFunction( + override def alterFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) val functionName = funcDefinition.identifier.funcName.toLowerCase(Locale.ROOT) @@ -1235,7 +1235,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.alterFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override protected def doRenameFunction( + override def renameFunction( db: String, oldName: String, newName: String): Unit = withClient { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index e5aff3b99d0b9..94ddeae1bf547 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, Gener import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, ExternalCatalog, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{DecimalType, DoubleType} private[sql] class HiveSessionCatalog( - externalCatalogBuilder: () => HiveExternalCatalog, + externalCatalogBuilder: () => ExternalCatalog, globalTempViewManagerBuilder: () => GlobalTempViewManager, val metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 40b9bb51ca9a0..2882672f327c4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Analyzer +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner @@ -35,14 +36,14 @@ import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLo class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None) extends BaseSessionStateBuilder(session, parentState) { - private def externalCatalog: HiveExternalCatalog = - session.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private def externalCatalog: ExternalCatalogWithListener = session.sharedState.externalCatalog /** * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - new HiveSessionResourceLoader(session, () => externalCatalog.client) + new HiveSessionResourceLoader( + session, () => externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client) } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 6a7b25b36d9a5..e0f7375387d24 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -122,7 +122,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { allSupportedHiveVersions) val externalCatalog = sparkSession.sharedState.externalCatalog - val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version + val hiveVersion = externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.version val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging") val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 965aea2b61456..ee3f99ab7e9bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, ExternalCatalogWithListener} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.CacheTableCommand @@ -83,11 +84,11 @@ private[hive] class TestHiveSharedState( hiveClient: Option[HiveClient] = None) extends SharedState(sc) { - override lazy val externalCatalog: TestHiveExternalCatalog = { - new TestHiveExternalCatalog( + override lazy val externalCatalog: ExternalCatalogWithListener = { + new ExternalCatalogWithListener(new TestHiveExternalCatalog( sc.conf, sc.hadoopConfiguration, - hiveClient) + hiveClient)) } } @@ -208,7 +209,9 @@ private[hive] class TestHiveSparkSession( new TestHiveSessionStateBuilder(this, parentSessionState).build() } - lazy val metadataHive: HiveClient = sharedState.externalCatalog.client.newSession() + lazy val metadataHive: HiveClient = { + sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.newSession() + } override def newSession(): TestHiveSparkSession = { new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala index 285f35b0b0eac..fd5f47e428239 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala @@ -26,7 +26,7 @@ class HiveExternalSessionCatalogSuite extends SessionCatalogSuite with TestHiveS private val externalCatalog = { val catalog = spark.sharedState.externalCatalog - catalog.asInstanceOf[HiveExternalCatalog].client.reset() + catalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.reset() catalog } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index f2d27671094d7..51a48a20daaa2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -50,7 +50,8 @@ class HiveSchemaInferenceSuite FileStatusCache.resetForTesting() } - private val externalCatalog = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private val externalCatalog = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] private val client = externalCatalog.client // Return a copy of the given schema with all field names converted to lower case. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index ecc09cdcdbeaf..a3579862c9e59 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -44,7 +44,8 @@ class HiveSessionStateSuite extends SessionStateSuite with TestHiveSingleton { val conf = sparkSession.sparkContext.hadoopConfiguration val oldValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) sparkSession.cloneSession() - sparkSession.sharedState.externalCatalog.client.newSession() + sparkSession.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.newSession() val newValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) assert(oldValue == newValue, "cloneSession and then newSession should not affect the Derby directory") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 079fe45860544..aa5b531992613 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -354,7 +354,7 @@ object SetMetastoreURLTest extends Logging { // HiveExternalCatalog is used when Hive support is enabled. val actualMetastoreURL = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client .getConf("javax.jdo.option.ConnectionURL", "this_is_a_wrong_URL") logInfo(s"javax.jdo.option.ConnectionURL is $actualMetastoreURL") @@ -780,7 +780,8 @@ object SPARK_18360 { val defaultDbLocation = spark.catalog.getDatabase("default").locationUri assert(new Path(defaultDbLocation) == new Path(spark.sharedState.warehousePath)) - val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val hiveClient = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client try { val tableMeta = CatalogTable( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala index fad81c7e9474e..473bbced41b31 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala @@ -289,7 +289,8 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSing } private def createRawHiveTable(ddl: String): Unit = { - hiveContext.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client.runSqlHive(ddl) + hiveContext.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.runSqlHive(ddl) } private def checkCreateTable(table: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 6176273c88db1..dc96ec416afd8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -134,8 +134,8 @@ class VersionsSuite extends SparkFunSuite with Logging { client = buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) if (versionSpark != null) versionSpark.reset() versionSpark = TestHiveVersion(client) - assert(versionSpark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - .version.fullVersion.startsWith(version)) + assert(versionSpark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.version.fullVersion.startsWith(version)) } def table(database: String, tableName: String): CatalogTable = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index daac6af9b557f..0341c3b378918 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1355,7 +1355,8 @@ class HiveDDLSuite val indexName = tabName + "_index" withTable(tabName) { // Spark SQL does not support creating index. Thus, we have to use Hive client. - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client sql(s"CREATE TABLE $tabName(a int)") try { @@ -1393,7 +1394,8 @@ class HiveDDLSuite val tabName = "tab1" withTable(tabName) { // Spark SQL does not support creating skewed table. Thus, we have to use Hive client. - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client client.runSqlHive( s""" |CREATE Table $tabName(col1 int, col2 int) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 704a410b6a37b..828c18a770c80 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2099,7 +2099,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { Seq("orc", "parquet").foreach { format => test(s"SPARK-18355 Read data from a hive table with a new column - $format") { - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client Seq("true", "false").foreach { value => withSQLConf( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index d3fff37c3424d..d50bf0b8fd603 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -30,7 +30,7 @@ trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive protected val hiveClient: HiveClient = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client protected override def afterAll(): Unit = { try { From a634d66ce767bd5e1d8553d1a2c32e2b1a80f642 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 7 May 2018 13:00:18 +0800 Subject: [PATCH 741/774] [SPARK-24126][PYSPARK] Use build-specific temp directory for pyspark tests. This avoids polluting and leaving garbage behind in /tmp, and allows the usual build tools to clean up any leftover files. Author: Marcelo Vanzin Closes #21198 from vanzin/SPARK-24126. --- python/pyspark/sql/tests.py | 4 ++-- python/pyspark/streaming/tests.py | 6 ++++-- python/pyspark/tests.py | 33 ++++++++++++++++++++----------- python/run-tests.py | 29 +++++++++++++++++++++++++-- 4 files changed, 54 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cc6acfdb07d99..16aa9378ad8ee 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3092,8 +3092,8 @@ def test_hivecontext(self): |print(hive_context.sql("show databases").collect()) """) proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", - "--driver-class-path", hive_site_dir, script], + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", + "--driver-class-path", hive_site_dir, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index d77f1baa1f344..e4a428a0b27e7 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -63,7 +63,7 @@ def setUpClass(cls): class_name = cls.__name__ conf = SparkConf().set("spark.default.parallelism", 1) cls.sc = SparkContext(appName=class_name, conf=conf) - cls.sc.setCheckpointDir("/tmp") + cls.sc.setCheckpointDir(tempfile.mkdtemp()) @classmethod def tearDownClass(cls): @@ -1549,7 +1549,9 @@ def search_kinesis_asl_assembly_jar(): kinesis_jar_present = True jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) - os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars + existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + jars_args = "--jars %s" % jars + os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, StreamingListenerTests] diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 8392d7f29af53..7b8ce2c6b799f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1951,7 +1951,12 @@ class SparkSubmitTests(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp() - self.sparkSubmit = os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit") + tmp_dir = tempfile.gettempdir() + self.sparkSubmit = [ + os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"), + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + ] def tearDown(self): shutil.rmtree(self.programDir) @@ -2017,7 +2022,7 @@ def test_single_script(self): |sc = SparkContext() |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) """) - proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 4, 6]", out.decode('utf-8')) @@ -2033,7 +2038,7 @@ def test_script_with_local_functions(self): |sc = SparkContext() |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) - proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[3, 6, 9]", out.decode('utf-8')) @@ -2051,7 +2056,7 @@ def test_module_dependency(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, script], + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2070,7 +2075,7 @@ def test_module_dependency_on_cluster(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master", + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() @@ -2087,8 +2092,10 @@ def test_package_dependency(self): |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, script], stdout=subprocess.PIPE) + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -2103,9 +2110,11 @@ def test_package_dependency_on_cluster(self): |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, "--master", - "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, "--master", "local-cluster[1,1,1024]", + script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -2124,7 +2133,7 @@ def test_single_script_on_cluster(self): # this will fail if you have different spark.executor.memory # in conf/spark-defaults.conf proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", script], + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2144,7 +2153,7 @@ def test_user_configuration(self): | sc.stop() """) proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local", script], + self.sparkSubmit + ["--master", "local", script], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) out, err = proc.communicate() diff --git a/python/run-tests.py b/python/run-tests.py index f408fc5082b3d..4c90926cfa350 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -22,11 +22,13 @@ from optparse import OptionParser import os import re +import shutil import subprocess import sys import tempfile from threading import Thread, Lock import time +import uuid if sys.version < '3': import Queue else: @@ -68,7 +70,7 @@ def print_red(text): raise Exception("Cannot find assembly build directory, please build Spark first.") -def run_individual_python_test(test_name, pyspark_python): +def run_individual_python_test(target_dir, test_name, pyspark_python): env = dict(os.environ) env.update({ 'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH, @@ -77,6 +79,23 @@ def run_individual_python_test(test_name, pyspark_python): 'PYSPARK_PYTHON': which(pyspark_python), 'PYSPARK_DRIVER_PYTHON': which(pyspark_python) }) + + # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is + # recognized by the tempfile module to override the default system temp directory. + tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) + while os.path.isdir(tmp_dir): + tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) + os.mkdir(tmp_dir) + env["TMPDIR"] = tmp_dir + + # Also override the JVM's temp directory by setting driver and executor options. + spark_args = [ + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "pyspark-shell" + ] + env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args) + LOGGER.info("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: @@ -84,6 +103,7 @@ def run_individual_python_test(test_name, pyspark_python): retcode = subprocess.Popen( [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], stderr=per_test_output, stdout=per_test_output, env=env).wait() + shutil.rmtree(tmp_dir, ignore_errors=True) except: LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python) # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if @@ -238,6 +258,11 @@ def main(): priority = 100 task_queue.put((priority, (python_exec, test_goal))) + # Create the target directory before starting tasks to avoid races. + target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target')) + if not os.path.isdir(target_dir): + os.mkdir(target_dir) + def process_queue(task_queue): while True: try: @@ -245,7 +270,7 @@ def process_queue(task_queue): except Queue.Empty: break try: - run_individual_python_test(test_goal, python_exec) + run_individual_python_test(target_dir, test_goal, python_exec) finally: task_queue.task_done() From 889f6cc10cbd7781df04f468674a61f0ac5a870b Mon Sep 17 00:00:00 2001 From: jinxing Date: Mon, 7 May 2018 14:16:27 +0800 Subject: [PATCH 742/774] [SPARK-24143] filter empty blocks when convert mapstatus to (blockId, size) pair MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In current code(`MapOutputTracker.convertMapStatuses`), mapstatus are converted to (blockId, size) pair for all blocks – no matter the block is empty or not, which result in OOM when there are lots of consecutive empty blocks, especially when adaptive execution is enabled. (blockId, size) pair is only used in `ShuffleBlockFetcherIterator` to control shuffle-read and only non-empty block request is sent. Can we just filter out the empty blocks in MapOutputTracker.convertMapStatuses and save memory? ## How was this patch tested? not added yet. Author: jinxing Closes #21212 from jinxing64/SPARK-24143. --- .../org/apache/spark/MapOutputTracker.scala | 31 +++++++++------- .../storage/ShuffleBlockFetcherIterator.scala | 35 +++++++++++-------- .../apache/spark/MapOutputTrackerSuite.scala | 31 +++++++++++++++- .../BlockStoreShuffleReaderSuite.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 19 +++++----- 5 files changed, 80 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 195fd4f818b36..73646051f264c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -22,7 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.collection.mutable.{HashMap, HashSet, ListBuffer, Map} import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.reflect.ClassTag @@ -282,7 +282,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) } @@ -296,7 +296,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -632,9 +632,10 @@ private[spark] class MapOutputTrackerMaster( } } + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -642,7 +643,7 @@ private[spark] class MapOutputTrackerMaster( MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } case None => - Seq.empty + Iterator.empty } } @@ -669,8 +670,9 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr /** Remembers which map output locations are currently being fetched on an executor. */ private val fetching = new HashSet[Int] + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { @@ -841,6 +843,7 @@ private[spark] object MapOutputTracker extends Logging { * Given an array of map statuses and a range of map output partitions, returns a sequence that, * for each block manager ID, lists the shuffle block IDs and corresponding shuffle block sizes * stored at that block manager. + * Note that empty blocks are filtered in the result. * * If any of the statuses is null (indicating a missing location due to a failed mapper), * throws a FetchFailedException. @@ -857,22 +860,24 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] - for ((status, mapId) <- statuses.zipWithIndex) { + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] + for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" logError(errorMessage) throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) } else { for (part <- startPartition until endPartition) { - splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += - ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part))) + val size = status.getSizeForBlock(part) + if (size != 0) { + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), size)) + } } } } - - splitsByAddress.toSeq + splitsByAddress.iterator } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index dd9df74689a13..6971efd2504c2 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -48,7 +48,9 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in - * order to throttle the memory usage. + * order to throttle the memory usage. Note that zero-sized blocks are + * already excluded, which happened in + * [[MapOutputTracker.convertMapStatuses]]. * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. @@ -62,7 +64,7 @@ final class ShuffleBlockFetcherIterator( context: TaskContext, shuffleClient: ShuffleClient, blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])], streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, @@ -74,8 +76,8 @@ final class ShuffleBlockFetcherIterator( import ShuffleBlockFetcherIterator._ /** - * Total number of blocks to fetch. This can be smaller than the total number of blocks - * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. + * Total number of blocks to fetch. This should be equal to the total number of blocks + * in [[blocksByAddress]] because we already filter out zero-sized blocks in [[blocksByAddress]]. * * This should equal localBlocks.size + remoteBlocks.size. */ @@ -267,13 +269,16 @@ final class ShuffleBlockFetcherIterator( // at most maxBytesInFlight in order to limit the amount of data in flight. val remoteRequests = new ArrayBuffer[FetchRequest] - // Tracks total number of blocks (including zero sized blocks) - var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { - totalBlocks += blockInfos.size if (address.executorId == blockManager.blockManagerId.executorId) { - // Filter out zero-sized blocks - localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + blockInfos.find(_._2 <= 0) match { + case Some((blockId, size)) if size < 0 => + throw new BlockException(blockId, "Negative block size " + size) + case Some((blockId, size)) if size == 0 => + throw new BlockException(blockId, "Zero-sized blocks should be excluded.") + case None => // do nothing. + } + localBlocks ++= blockInfos.map(_._1) numBlocksToFetch += localBlocks.size } else { val iterator = blockInfos.iterator @@ -281,14 +286,15 @@ final class ShuffleBlockFetcherIterator( var curBlocks = new ArrayBuffer[(BlockId, Long)] while (iterator.hasNext) { val (blockId, size) = iterator.next() - // Skip empty blocks - if (size > 0) { + if (size < 0) { + throw new BlockException(blockId, "Negative block size " + size) + } else if (size == 0) { + throw new BlockException(blockId, "Zero-sized blocks should be excluded.") + } else { curBlocks += ((blockId, size)) remoteBlocks += blockId numBlocksToFetch += 1 curRequestSize += size - } else if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) } if (curRequestSize >= targetRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) { @@ -306,7 +312,8 @@ final class ShuffleBlockFetcherIterator( } } } - logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks") + logInfo(s"Getting $numBlocksToFetch non-empty blocks including ${localBlocks.size}" + + s" local blocks and ${remoteBlocks.size} remote blocks") remoteRequests } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 50b8ea754d8d9..21f481d477242 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -147,7 +147,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) @@ -298,4 +298,33 @@ class MapOutputTrackerSuite extends SparkFunSuite { } } + test("zero-sized blocks should be excluded when getMapSizesByExecutorId") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(10, 2) + + val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L)) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(size0, size1000, size0, size10000))) + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(size10000, size0, size1000, size0))) + assert(tracker.containsShuffle(10)) + assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === + Seq( + (BlockManagerId("a", "hostA", 1000), + Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), + (BlockManagerId("b", "hostB", 1000), + Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) + ) + ) + + tracker.unregisterShuffle(10) + tracker.stop() + rpcEnv.shutdown() + } + } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index dba1172d5fdbd..2d8a83c6fabed 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -108,7 +108,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) (shuffleBlockId, byteOutputStream.size().toLong) } - Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator } // Create a mocked shuffle handle to pass into HashShuffleReader. diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 692ae3bf597e0..cefebfa51b8b9 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -99,7 +99,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) - ) + ).toIterator val iterator = new ShuffleBlockFetcherIterator( TaskContext.empty(), @@ -176,7 +176,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -244,7 +244,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -310,7 +310,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -378,7 +378,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlockLengths), (remoteBmId, remoteBlockLengths) - ) + ).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -437,7 +437,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -495,7 +495,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { + def fetchShuffleBlock( + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. @@ -513,14 +514,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. assert(tempFileManager == null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. From 7564a9a70695dac2f0b5f51493d37cbc93691663 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 7 May 2018 15:22:23 +0900 Subject: [PATCH 743/774] [SPARK-23921][SQL] Add array_sort function ## What changes were proposed in this pull request? The PR adds the SQL function `array_sort`. The behavior of the function is based on Presto's one. The function sorts the input array in ascending order. The elements of the input array must be orderable. Null elements will be placed at the end of the returned array. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21021 from kiszk/SPARK-23921. --- python/pyspark/sql/functions.py | 26 +- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 240 ++++++++++++++---- .../CollectionExpressionsSuite.scala | 34 ++- .../org/apache/spark/sql/functions.scala | 12 + .../spark/sql/DataFrameFunctionsSuite.scala | 34 ++- 6 files changed, 292 insertions(+), 55 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ad4bd6f5089e9..bd55b5f73b4d0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2183,20 +2183,38 @@ def array_max(col): def sort_array(col, asc=True): """ Collection function: sorts the input array in ascending or descending order according - to the natural ordering of the array elements. + to the natural ordering of the array elements. Null elements will be placed at the beginning + of the returned array in ascending order or at the end of the returned array in descending + order. :param col: name of column or expression - >>> df = spark.createDataFrame([([2, 1, 3],),([1],),([],)], ['data']) + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) >>> df.select(sort_array(df.data).alias('r')).collect() - [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] + [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() - [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])] + [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(2.4) +def array_sort(col): + """ + Collection function: sorts the input array in ascending order. The elements of the input array + must be orderable. Null elements will be placed at the end of the returned array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) + >>> df.select(array_sort(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_sort(_to_java_column(col))) + + @since(1.5) @ignore_unicode_prefix def reverse(col): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 51bb6b0abe408..01776b85e6f53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -403,6 +403,7 @@ object FunctionRegistry { expression[ArrayContains]("array_contains"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), + expression[ArraySort]("array_sort"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6d63a531e3b74..23c09bc3b49d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,6 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ @@ -119,47 +120,16 @@ case class MapValues(child: Expression) } /** - * Sorts the input array in ascending / descending order according to the natural ordering of - * the array elements and returns it. + * Common base class for [[SortArray]] and [[ArraySort]]. */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order according to the natural ordering of the array elements.", - examples = """ - Examples: - > SELECT _FUNC_(array('b', 'd', 'c', 'a'), true); - ["a","b","c","d"] - """) -// scalastyle:on line.size.limit -case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ExpectsInputTypes with CodegenFallback { +trait ArraySortLike extends ExpectsInputTypes { + protected def arrayExpression: Expression - def this(e: Expression) = this(e, Literal(true)) - - override def left: Expression = base - override def right: Expression = ascendingOrder - override def dataType: DataType = base.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) - - override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - ascendingOrder match { - case Literal(_: Boolean, BooleanType) => - TypeCheckResult.TypeCheckSuccess - case _ => - TypeCheckResult.TypeCheckFailure( - "Sort order in second argument requires a boolean literal.") - } - case ArrayType(dt, _) => - TypeCheckResult.TypeCheckFailure( - s"$prettyName does not support sorting array of type ${dt.simpleString}") - case _ => - TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") - } + protected def nullOrder: NullOrder @transient private lazy val lt: Comparator[Any] = { - val ordering = base.dataType match { + val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -170,9 +140,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) if (o1 == null && o2 == null) { 0 } else if (o1 == null) { - -1 + nullOrder } else if (o2 == null) { - 1 + -nullOrder } else { ordering.compare(o1, o2) } @@ -182,7 +152,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) @transient private lazy val gt: Comparator[Any] = { - val ordering = base.dataType match { + val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -193,9 +163,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) if (o1 == null && o2 == null) { 0 } else if (o1 == null) { - 1 + -nullOrder } else if (o2 == null) { - -1 + nullOrder } else { -ordering.compare(o1, o2) } @@ -203,18 +173,200 @@ case class SortArray(base: Expression, ascendingOrder: Expression) } } - override def nullSafeEval(array: Any, ascending: Any): Any = { - val elementType = base.dataType.asInstanceOf[ArrayType].elementType + def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType + def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull + + def sortEval(array: Any, ascending: Boolean): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) if (elementType != NullType) { - java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) + java.util.Arrays.sort(data, if (ascending) lt else gt) } new GenericArrayData(data.asInstanceOf[Array[Any]]) } + def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = { + val arrayData = classOf[ArrayData].getName + val genericArrayData = classOf[GenericArrayData].getName + val unsafeArrayData = classOf[UnsafeArrayData].getName + val array = ctx.freshName("array") + val c = ctx.freshName("c") + if (elementType == NullType) { + s"${ev.value} = $base.copy();" + } else { + val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType) + val sortOrder = ctx.freshName("sortOrder") + val o1 = ctx.freshName("o1") + val o2 = ctx.freshName("o2") + val jt = CodeGenerator.javaType(elementType) + val comp = if (CodeGenerator.isPrimitiveType(elementType)) { + val bt = CodeGenerator.boxedType(elementType) + val v1 = ctx.freshName("v1") + val v2 = ctx.freshName("v2") + s""" + |$jt $v1 = (($bt) $o1).${jt}Value(); + |$jt $v2 = (($bt) $o2).${jt}Value(); + |int $c = ${ctx.genComp(elementType, v1, v2)}; + """.stripMargin + } else { + s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};" + } + val nonNullPrimitiveAscendingSort = + if (CodeGenerator.isPrimitiveType(elementType) && !containsNull) { + val javaType = CodeGenerator.javaType(elementType) + val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($order) { + | $javaType[] $array = $base.to${primitiveTypeName}Array(); + | java.util.Arrays.sort($array); + | ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array); + |} else + """.stripMargin + } else { + "" + } + s""" + |$nonNullPrimitiveAscendingSort + |{ + | Object[] $array = $base.toObjectArray($elementTypeTerm); + | final int $sortOrder = $order ? 1 : -1; + | java.util.Arrays.sort($array, new java.util.Comparator() { + | @Override public int compare(Object $o1, Object $o2) { + | if ($o1 == null && $o2 == null) { + | return 0; + | } else if ($o1 == null) { + | return $sortOrder * $nullOrder; + | } else if ($o2 == null) { + | return -$sortOrder * $nullOrder; + | } + | $comp + | return $sortOrder * $c; + | } + | }); + | ${ev.value} = new $genericArrayData($array); + |} + """.stripMargin + } + } + +} + +object ArraySortLike { + type NullOrder = Int + // Least: place null element at the first of the array for ascending order + // Greatest: place null element at the end of the array for ascending order + object NullOrder { + val Least: NullOrder = -1 + val Greatest: NullOrder = 1 + } +} + +/** + * Sorts the input array in ascending / descending order according to the natural ordering of + * the array elements and returns it. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order + according to the natural ordering of the array elements. Null elements will be placed + at the beginning of the returned array in ascending order or at the end of the returned + array in descending order. + """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true); + [null,"a","b","c","d"] + """) +// scalastyle:on line.size.limit +case class SortArray(base: Expression, ascendingOrder: Expression) + extends BinaryExpression with ArraySortLike { + + def this(e: Expression) = this(e, Literal(true)) + + override def left: Expression = base + override def right: Expression = ascendingOrder + override def dataType: DataType = base.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) + + override def arrayExpression: Expression = base + override def nullOrder: NullOrder = NullOrder.Least + + override def checkInputDataTypes(): TypeCheckResult = base.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + ascendingOrder match { + case Literal(_: Boolean, BooleanType) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + "Sort order in second argument requires a boolean literal.") + } + case ArrayType(dt, _) => + val dtSimple = dt.simpleString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + + override def nullSafeEval(array: Any, ascending: Any): Any = { + sortEval(array, ascending.asInstanceOf[Boolean]) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order)) + } + override def prettyName: String = "sort_array" } + +/** + * Sorts the input array in ascending order according to the natural ordering of + * the array elements and returns it. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must + be orderable. Null elements will be placed at the end of the returned array. + """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); + ["a","b","c","d",null] + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike { + + override def dataType: DataType = child.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def arrayExpression: Expression = child + override def nullOrder: NullOrder = NullOrder.Greatest + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + TypeCheckResult.TypeCheckSuccess + case ArrayType(dt, _) => + val dtSimple = dt.simpleString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + + override def nullSafeEval(array: Any): Any = { + sortEval(array, true) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true")) + } + + override def prettyName: String = "array_sort" +} + /** * Returns a reversed string or an array with reverse order of elements. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 7048d93fd5649..749374f1a14a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -61,28 +61,58 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) - val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) + val d1 = new Decimal().set(10) + val d2 = new Decimal().set(100) + val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0))) + val a5 = Literal.create(Seq(null, null), ArrayType(NullType)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) checkEvaluation(new SortArray(a2), Seq("a", "b")) checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) + checkEvaluation(new SortArray(a4), Seq(d1, d2)) checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b")) checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b")) + checkEvaluation(SortArray(a4, Literal(true)), Seq(d1, d2)) checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1)) checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a")) checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) + checkEvaluation(SortArray(a4, Literal(false)), Seq(d2, d1)) checkEvaluation(Literal.create(null, ArrayType(StringType)), null) - checkEvaluation(new SortArray(a4), Seq(null, null)) + checkEvaluation(new SortArray(a5), Seq(null, null)) val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2))) + + val typeAA = ArrayType(ArrayType(IntegerType)) + val aa1 = Array[java.lang.Integer](1, 2) + val aa2 = Array[java.lang.Integer](3, null, 4) + val arrayArray = Literal.create(Seq(aa2, aa1), typeAA) + + checkEvaluation(new SortArray(arrayArray), Seq(aa1, aa2)) + + val typeAAS = ArrayType(ArrayType(StructType(StructField("a", IntegerType) :: Nil))) + val aas1 = Array(create_row(1)) + val aas2 = Array(create_row(2)) + val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS) + + checkEvaluation(new SortArray(arrayArrayStruct), Seq(aas1, aas2)) + + checkEvaluation(ArraySort(a0), Seq(1, 2, 3)) + checkEvaluation(ArraySort(a1), Seq[Integer]()) + checkEvaluation(ArraySort(a2), Seq("a", "b")) + checkEvaluation(ArraySort(a3), Seq("a", "b", null)) + checkEvaluation(ArraySort(a4), Seq(d1, d2)) + checkEvaluation(ArraySort(a5), Seq(null, null)) + checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2))) + checkEvaluation(ArraySort(arrayArray), Seq(aa1, aa2)) + checkEvaluation(ArraySort(arrayArrayStruct), Seq(aas1, aas2)) } test("Array contains") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d2e22fa355514..10b6dcc0608c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3093,6 +3093,15 @@ object functions { ElementAt(column.expr, Literal(value)) } + /** + * Sorts the input array in ascending order. The elements of the input array must be orderable. + * Null elements will be placed at the end of the returned array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } + /** * Creates a new row for each element in the given array or map column. * @@ -3332,6 +3341,7 @@ object functions { /** * Sorts the input array for the given column in ascending order, * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array. * * @group collection_funcs * @since 1.5.0 @@ -3341,6 +3351,8 @@ object functions { /** * Sorts the input array for the given column in ascending or descending order, * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array in ascending order or + * at the end of the returned array in descending order. * * @group collection_funcs * @since 1.5.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index a5163accb1bb3..ae21cbc802d0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -276,7 +276,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - test("sort_array function") { + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), (Array.empty[Int], Array.empty[String]), @@ -286,28 +286,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(sort_array($"a"), sort_array($"b")), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.select(sort_array($"a", false), sort_array($"b", false)), Seq( Row(Seq(3, 2, 1), Seq("c", "b", "a")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.selectExpr("sort_array(a)", "sort_array(b)"), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.selectExpr("sort_array(a, true)", "sort_array(b, false)"), Seq( Row(Seq(1, 2, 3), Seq("c", "b", "a")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) @@ -324,6 +324,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(intercept[AnalysisException] { df3.selectExpr("sort_array(a)").collect() }.getMessage().contains("only supports array input")) + + checkAnswer( + df.select(array_sort($"a"), array_sort($"b")), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("array_sort(a)", "array_sort(b)"), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + + checkAnswer( + df2.selectExpr("array_sort(a)"), + Seq(Row(Seq[Seq[Int]](Seq(1), Seq(2), Seq(2, 4), null))) + ) + + assert(intercept[AnalysisException] { + df3.selectExpr("array_sort(a)").collect() + }.getMessage().contains("only supports array input")) } test("array size function") { From d2aa859b4faeda03e32a7574dd0c5b4ed367fae4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 7 May 2018 14:34:03 +0800 Subject: [PATCH 744/774] [SPARK-24160] ShuffleBlockFetcherIterator should fail if it receives zero-size blocks ## What changes were proposed in this pull request? This patch modifies `ShuffleBlockFetcherIterator` so that the receipt of zero-size blocks is treated as an error. This is done as a preventative measure to guard against a potential source of data loss bugs. In the shuffle layer, we guarantee that zero-size blocks will never be requested (a block containing zero records is always 0 bytes in size and is marked as empty such that it will never be legitimately requested by executors). However, the existing code does not fully take advantage of this invariant in the shuffle-read path: the existing code did not explicitly check whether blocks are non-zero-size. Additionally, our decompression and deserialization streams treat zero-size inputs as empty streams rather than errors (EOF might actually be treated as "end-of-stream" in certain layers (longstanding behavior dating to earliest versions of Spark) and decompressors like Snappy may be tolerant to zero-size inputs). As a result, if some other bug causes legitimate buffers to be replaced with zero-sized buffers (due to corruption on either the send or receive sides) then this would translate into silent data loss rather than an explicit fail-fast error. This patch addresses this problem by adding a `buf.size != 0` check. See code comments for pointers to tests which guarantee the invariants relied on here. ## How was this patch tested? Existing tests (which required modifications, since some were creating empty buffers in mocks). I also added a test to make sure we fail on zero-size blocks. To test that the zero-size blocks are indeed a potential corruption source, I manually ran a workload in `spark-shell` with a modified build which replaces all buffers with zero-size buffers in the receive path. Author: Josh Rosen Closes #21219 from JoshRosen/SPARK-24160. --- .../storage/ShuffleBlockFetcherIterator.scala | 19 +++++ .../ShuffleBlockFetcherIteratorSuite.scala | 71 +++++++++++++------ 2 files changed, 70 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 6971efd2504c2..b31862323a895 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -414,6 +414,25 @@ final class ShuffleBlockFetcherIterator( logDebug("Number of requests in flight " + reqsInFlight) } + if (buf.size == 0) { + // We will never legitimately receive a zero-size block. All blocks with zero records + // have zero size and all zero-size blocks have no records (and hence should never + // have been requested in the first place). This statement relies on behaviors of the + // shuffle writers, which are guaranteed by the following test cases: + // + // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions" + // - UnsafeShuffleWriterSuite: "writeEmptyIterator" + // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing" + // + // There is not an explicit test for SortShuffleWriter but the underlying APIs that + // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter + // which returns a zero-size from commitAndGet() in case no records were written + // since the last call. + val msg = s"Received a zero-size buffer for block $blockId from $address " + + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" + throwFetchFailedException(blockId, address, new IOException(msg)) + } + val in = try { buf.createInputStream() } catch { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index cefebfa51b8b9..8e9374b768adc 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -65,12 +65,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } // Create a mock managed buffer for testing - def createMockManagedBuffer(): ManagedBuffer = { + def createMockManagedBuffer(size: Int = 1): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) val in = mock(classOf[InputStream]) when(in.read(any())).thenReturn(1) when(in.read(any(), any(), any())).thenReturn(1) when(mockManagedBuffer.createInputStream()).thenReturn(in) + when(mockManagedBuffer.size()).thenReturn(size) mockManagedBuffer } @@ -269,6 +270,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } + private def mockCorruptBuffer(size: Long = 1L): ManagedBuffer = { + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.size()).thenReturn(size) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + corruptBuffer + } + test("retry corrupt blocks") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -284,11 +294,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) @@ -301,7 +306,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) listener.onBlockFetchSuccess( ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer) sem.release() @@ -339,7 +344,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Future { // Return the first block, and then fail. listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) sem.release() } } @@ -353,11 +358,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("big blocks are not checked for corruption") { - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) - doReturn(10000L).when(corruptBuffer).size() + val corruptBuffer = mockCorruptBuffer(10000L) val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -413,11 +414,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) - val transfer = mock(classOf[BlockTransferService]) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { @@ -428,9 +424,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 2, 0).toString, corruptBuffer) + ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer()) sem.release() } } @@ -527,4 +523,39 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // shuffle block to disk. assert(tempFileManager != null) } + + test("fail zero-size blocks") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer() + ) + + val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0))) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => in, + 48 * 1024 * 1024, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true) + + // All blocks fetched return zero length and should trigger a receive-side error: + val e = intercept[FetchFailedException] { iterator.next() } + assert(e.getMessage.contains("Received a zero-size buffer")) + } } From c5981976f1d514a3ad8a684b9a21cebe38b786fa Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 7 May 2018 14:45:14 +0800 Subject: [PATCH 745/774] [SPARK-23775][TEST] Make DataFrameRangeSuite not flaky ## What changes were proposed in this pull request? DataFrameRangeSuite.test("Cancelling stage in a query with Range.") stays sometimes in an infinite loop and times out the build. There were multiple issues with the test: 1. The first valid stageId is zero when the test started alone and not in a suite and the following code waits until timeout: ``` eventually(timeout(10.seconds), interval(1.millis)) { assert(DataFrameRangeSuite.stageToKill > 0) } ``` 2. The `DataFrameRangeSuite.stageToKill` was overwritten by the task's thread after the reset which ended up in canceling the same stage 2 times. This caused the infinite wait. This PR solves this mentioned flakyness by removing the shared `DataFrameRangeSuite.stageToKill` and using `onTaskStart` where stage ID is provided. In order to make sure cancelStage called for all stages `waitUntilEmpty` is called on `ListenerBus`. In [PR20888](https://github.com/apache/spark/pull/20888) this tried to get solved by: * Stopping the executor thread with `wait` * Wait for all `cancelStage` called * Kill the executor thread by setting `SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL` but the thread killing left the shared `SparkContext` sometimes in a state where further jobs can't be submitted. As a result DataFrameRangeSuite.test("Cancelling stage in a query with Range.") test passed properly but the next test inside the suite was hanging. ## How was this patch tested? Existing unit test executed 10k times. Author: Gabor Somogyi Closes #21214 from gaborgsomogyi/SPARK-23775_1. --- .../spark/sql/DataFrameRangeSuite.scala | 24 +++++++------------ 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 57a930dfaf320..b0b46640ff317 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -23,8 +23,8 @@ import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.SparkException +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -153,23 +153,17 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall test("Cancelling stage in a query with Range.") { val listener = new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - eventually(timeout(10.seconds), interval(1.millis)) { - assert(DataFrameRangeSuite.stageToKill > 0) - } - sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sparkContext.cancelStage(taskStart.stageId) } } sparkContext.addSparkListener(listener) for (codegen <- Seq(true, false)) { withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - DataFrameRangeSuite.stageToKill = -1 val ex = intercept[SparkException] { - spark.range(0, 100000000000L, 1, 1).map { x => - DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() - x - }.toDF("id").agg(sum("id")).collect() + spark.range(0, 100000000000L, 1, 1) + .toDF("id").agg(sum("id")).collect() } ex.getCause() match { case null => @@ -180,6 +174,8 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } } + // Wait until all ListenerBus events consumed to make sure cancelStage called for all stages + sparkContext.listenerBus.waitUntilEmpty(20.seconds.toMillis) eventually(timeout(20.seconds)) { assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } @@ -204,7 +200,3 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } - -object DataFrameRangeSuite { - @volatile var stageToKill = -1 -} From f06528015d5856d6dc5cce00309bc2ae985e080f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 7 May 2018 15:42:10 +0800 Subject: [PATCH 746/774] [SPARK-24160][FOLLOWUP] Fix compilation failure ## What changes were proposed in this pull request? SPARK-24160 is causing a compilation failure (after SPARK-24143 was merged). This fixes the issue. ## How was this patch tested? building successfully Author: Marco Gaido Closes #21256 from mgaido91/SPARK-24160_FOLLOWUP. --- .../apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 8e9374b768adc..a2997dbd1b1ac 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -546,7 +546,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext, transfer, blockManager, - blocksByAddress, + blocksByAddress.toIterator, (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, From e35ad3caddeaa4b0d4c8524dcfb9e9f56dc7fe3d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 7 May 2018 16:57:37 +0900 Subject: [PATCH 747/774] [SPARK-23930][SQL] Add slice function ## What changes were proposed in this pull request? The PR add the `slice` function. The behavior of the function is based on Presto's one. The function slices an array according to the requested start index and length. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21040 from mgaido91/SPARK-23930. --- python/pyspark/sql/functions.py | 13 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 34 ++++ .../expressions/collectionOperations.scala | 163 ++++++++++++++---- .../CollectionExpressionsSuite.scala | 28 +++ .../expressions/ExpressionEvalHelper.scala | 6 + .../expressions/ObjectExpressionsSuite.scala | 1 - .../org/apache/spark/sql/functions.scala | 10 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 16 ++ 9 files changed, 233 insertions(+), 39 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bd55b5f73b4d0..ac3c79766702c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1834,6 +1834,19 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def slice(x, start, length): + """ + Collection function: returns an array containing all the elements in `x` from index `start` + (or starting from the end if `start` is negative) with the specified `length`. + >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x']) + >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() + [Row(sliced=[2, 3]), Row(sliced=[5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.slice(_to_java_column(x), start, length)) + + @ignore_unicode_prefix @since(2.4) def array_join(col, delimiter, null_replacement=None): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 01776b85e6f53..87b0911e150c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -410,6 +410,7 @@ object FunctionRegistry { expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), + expression[Slice]("slice"), expression[Size]("cardinality"), expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cf0a91ff00626..4dda525294259 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types._ import org.apache.spark.util.{ParentClassLoader, Utils} @@ -730,6 +731,39 @@ class CodegenContext { """.stripMargin } + /** + * Generates code creating a [[UnsafeArrayData]]. + * + * @param arrayName name of the array to create + * @param numElements code representing the number of elements the array should contain + * @param elementType data type of the elements in the array + * @param additionalErrorMessage string to include in the error message + */ + def createUnsafeArray( + arrayName: String, + numElements: String, + elementType: DataType, + additionalErrorMessage: String): String = { + val arraySize = freshName("size") + val arrayBytes = freshName("arrayBytes") + + s""" + |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElements, + | ${elementType.defaultSize}); + |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try create array with " + $arraySize + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." + + | "$additionalErrorMessage"); + |} + |byte[] $arrayBytes = new byte[(int)$arraySize]; + |UnsafeArrayData $arrayName = new UnsafeArrayData(); + |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); + |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 23c09bc3b49d7..12b9ab2b272ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -530,6 +529,129 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def children: Seq[Expression] = Seq(x, start, length) + + lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { + val startInt = startVal.asInstanceOf[Int] + val lengthInt = lengthVal.asInstanceOf[Int] + val arr = xVal.asInstanceOf[ArrayData] + val startIndex = if (startInt == 0) { + throw new RuntimeException( + s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") + } else if (startInt < 0) { + startInt + arr.numElements() + } else { + startInt - 1 + } + if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + + "length must be greater than or equal to 0.") + } + // startIndex can be negative if start is negative and its absolute value is greater than the + // number of elements in the array + if (startIndex < 0 || startIndex >= arr.numElements()) { + return new GenericArrayData(Array.empty[AnyRef]) + } + val data = arr.toSeq[AnyRef](elementType) + new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (x, start, length) => { + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + | + "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + | + "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |${genCodeForResult(ctx, ev, x, startIdx, resLength)} + """.stripMargin + }) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + startIdx: String, + resLength: String): String = { + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |Object[] $values; + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + | $values[$i] = $getValue; + | } + |} + |${ev.value} = new $arrayClass($values); + """.stripMargin + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $resLength = 0; + |} + |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")} + |for (int $i = 0; $i < $resLength; $i ++) { + | if ($inputArray.isNullAt($i + $startIdx)) { + | $values.setNullAt($i); + | } else { + | $values.set$primitiveValueTypeName($i, $getValue); + | } + |} + |${ev.value} = $values; + """.stripMargin + } + } +} + /** * Creates a String containing all the elements of the input array separated by the delimiter. */ @@ -1127,24 +1249,11 @@ case class Concat(children: Seq[Expression]) extends Expression { } private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val arrayName = ctx.freshName("array") - val arraySizeName = ctx.freshName("size") val counter = ctx.freshName("counter") val arrayData = ctx.freshName("arrayData") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - val unsafeArraySizeInBytes = s""" - |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElemName, - | ${elementType.defaultSize}); - |if ($arraySizeName > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName + - | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" + - | " for UnsafeArrayData."); - |} - """.stripMargin - val baseOffset = Platform.BYTE_ARRAY_OFFSET val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) s""" @@ -1152,11 +1261,7 @@ case class Concat(children: Seq[Expression]) extends Expression { | public ArrayData concat($javaType[] args) { | ${nullArgumentProtection()} | $numElemCode - | $unsafeArraySizeInBytes - | byte[] $arrayName = new byte[(int)$arraySizeName]; - | UnsafeArrayData $arrayData = new UnsafeArrayData(); - | Platform.putLong($arrayName, $baseOffset, $numElemName); - | $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} | int $counter = 0; | for (int y = 0; y < ${children.length}; y++) { | for (int z = 0; z < args[y].numElements(); z++) { @@ -1308,34 +1413,16 @@ case class Flatten(child: Expression) extends UnaryExpression { ctx: CodegenContext, childVariableName: String, arrayDataName: String): String = { - val arrayName = ctx.freshName("array") - val arraySizeName = ctx.freshName("size") val counter = ctx.freshName("counter") val tempArrayDataName = ctx.freshName("tempArrayData") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) - val unsafeArraySizeInBytes = s""" - |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElemName, - | ${elementType.defaultSize}); - |if ($arraySizeName > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - | $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" + - | " bytes for UnsafeArrayData."); - |} - """.stripMargin - val baseOffset = Platform.BYTE_ARRAY_OFFSET - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) s""" |$numElemCode - |$unsafeArraySizeInBytes - |byte[] $arrayName = new byte[(int)$arraySizeName]; - |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); - |Platform.putLong($arrayName, $baseOffset, $numElemName); - |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")} |int $counter = 0; |for (int k = 0; k < $childVariableName.numElements(); k++) { | ArrayData arr = $childVariableName.getArray(k); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 749374f1a14a1..a2851d071c7c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -136,6 +136,34 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("Slice") { + val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) + val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType)) + + checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2)) + checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5)) + checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6)) + checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6)) + checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)), + "Unexpected value for length") + checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)), + "Unexpected value for start") + checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int]) + checkEvaluation(Slice(a1, Literal(-20), Literal(1)), Seq.empty[String]) + checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null) + checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null) + checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)), + null) + + checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b")) + checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null)) + checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int]) + checkEvaluation(Slice(a1, Literal(10), Literal(1)), Seq.empty[String]) + checkEvaluation(Slice(a3, Literal(2), Literal(3)), Seq(2, null, 4)) + } + test("ArrayJoin") { def testArrays( arrays: Seq[Expression], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4bf6d7107d7e..a22e9d4655e8c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -104,6 +104,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( + expression: => Expression, + expectedErrMsg: String): Unit = { + checkExceptionInExpression[T](expression, InternalRow.empty, expectedErrMsg) + } + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( expression: => Expression, inputRow: InternalRow, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 730b36c32333c..77ca640f2e0bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -223,7 +223,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal.fromObject(new java.util.LinkedList[Int]), Map("nonexisting" -> Literal(1))) checkExceptionInExpression[Exception](initializeWithNonexistingMethod, - InternalRow.fromSeq(Seq()), """A method named "nonexisting" is not declared in any enclosing class """ + "nor any supertype") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 10b6dcc0608c2..8f9e4ae18b3f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3039,6 +3039,16 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Returns an array containing all the elements in `x` from index `start` (or starting from the + * end if `start` is negative) with the specified `length`. + * @group collection_funcs + * @since 2.4.0 + */ + def slice(x: Column, start: Int, length: Int): Column = withExpr { + Slice(x.expr, Literal(start), Literal(length)) + } + /** * Concatenates the elements of `column` using the `delimiter`. Null values are replaced with * `nullReplacement`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ae21cbc802d0a..ecce06f4c0755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -442,6 +442,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("slice function") { + val df = Seq( + Seq(1, 2, 3), + Seq(4, 5) + ).toDF("x") + + val answer = Seq(Row(Seq(2, 3)), Row(Seq(5))) + + checkAnswer(df.select(slice(df("x"), 2, 2)), answer) + checkAnswer(df.selectExpr("slice(x, 2, 2)"), answer) + + val answerNegative = Seq(Row(Seq(3)), Row(Seq(5))) + checkAnswer(df.select(slice(df("x"), -1, 1)), answerNegative) + checkAnswer(df.selectExpr("slice(x, -1, 1)"), answerNegative) + } + test("array_join function") { val df = Seq( (Seq[String]("a", "b"), ","), From 4e861db5f149e10fd8dfe6b3c1484821a590b1e8 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 7 May 2018 11:21:22 +0200 Subject: [PATCH 748/774] [SPARK-16406][SQL] Improve performance of LogicalPlan.resolve ## What changes were proposed in this pull request? `LogicalPlan.resolve(...)` uses linear searches to find an attribute matching a name. This is fine in normal cases, but gets problematic when you try to resolve a large number of columns on a plan with a large number of attributes. This PR adds an indexing structure to `resolve(...)` in order to find potential matches quicker. This PR improves the reference resolution time for the following code by 4x (11.8s -> 2.4s): ``` scala val n = 4000 val values = (1 to n).map(_.toString).mkString(", ") val columns = (1 to n).map("column" + _).mkString(", ") val query = s""" |SELECT $columns |FROM VALUES ($values) T($columns) |WHERE 1=2 AND 1 IN ($columns) |GROUP BY $columns |ORDER BY $columns |""".stripMargin spark.time(sql(query)) ``` ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #14083 from hvanhovell/SPARK-16406. --- .../sql/catalyst/expressions/package.scala | 86 ++++++++++++++ .../catalyst/plans/logical/LogicalPlan.scala | 108 ++---------------- 2 files changed, 93 insertions(+), 101 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1a48995358af7..8a06daa37132d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.catalyst +import java.util.Locale + import com.google.common.collect.Maps +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -138,6 +142,88 @@ package object expressions { def indexOf(exprId: ExprId): Int = { Option(exprIdToOrdinal.get(exprId)).getOrElse(-1) } + + private def unique[T](m: Map[T, Seq[Attribute]]): Map[T, Seq[Attribute]] = { + m.mapValues(_.distinct).map(identity) + } + + /** Map to use for direct case insensitive attribute lookups. */ + @transient private lazy val direct: Map[String, Seq[Attribute]] = { + unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT))) + } + + /** Map to use for qualified case insensitive attribute lookups. */ + @transient private val qualified: Map[(String, String), Seq[Attribute]] = { + val grouped = attrs.filter(_.qualifier.isDefined).groupBy { a => + (a.qualifier.get.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT)) + } + unique(grouped) + } + + /** Perform attribute resolution given a name and a resolver. */ + def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { + // Collect matching attributes given a name and a lookup. + def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = { + candidates.toSeq.flatMap(_.collect { + case a if resolver(a.name, name) => a.withName(name) + }) + } + + // Find matches for the given name assuming that the 1st part is a qualifier (i.e. table name, + // alias, or subquery alias) and the 2nd part is the actual name. This returns a tuple of + // matched attributes and a list of parts that are to be resolved. + // + // For example, consider an example where "a" is the table name, "b" is the column name, + // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", + // and the second element will be List("c"). + val matches = nameParts match { + case qualifier +: name +: nestedFields => + val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) + val attributes = collectMatches(name, qualified.get(key)).filter { a => + resolver(qualifier, a.qualifier.get) + } + (attributes, nestedFields) + case all => + (Nil, all) + } + + // If none of attributes match `table.column` pattern, we try to resolve it as a column. + val (candidates, nestedFields) = matches match { + case (Seq(), _) => + val name = nameParts.head + val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT))) + (attributes, nameParts.tail) + case _ => matches + } + + def name = UnresolvedAttribute(nameParts).name + candidates match { + case Seq(a) if nestedFields.nonEmpty => + // One match, but we also need to extract the requested nested field. + // The foldLeft adds ExtractValues for every remaining parts of the identifier, + // and aliased it with the last part of the name. + // For example, consider "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final + // expression as "c". + val fieldExprs = nestedFields.foldLeft(a: Expression) { (e, name) => + ExtractValue(e, Literal(name), resolver) + } + Some(Alias(fieldExprs, nestedFields.last)()) + + case Seq(a) => + // One match, no nested fields, use it. + Some(a) + + case Seq() => + // No matches. + None + + case ambiguousReferences => + // More than one match. + val referenceNames = ambiguousReferences.map(_.qualifiedName).mkString(", ") + throw new AnalysisException(s"Reference '$name' is ambiguous, could be: $referenceNames.") + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 42034403d6d03..e487693927ab6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -86,6 +86,10 @@ abstract class LogicalPlan } } + private[this] lazy val childAttributes = AttributeSeq(children.flatMap(_.output)) + + private[this] lazy val outputAttributes = AttributeSeq(output) + /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as @@ -94,7 +98,7 @@ abstract class LogicalPlan def resolveChildren( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, children.flatMap(_.output), resolver) + childAttributes.resolve(nameParts, resolver) /** * Optionally resolves the given strings to a [[NamedExpression]] based on the output of this @@ -104,7 +108,7 @@ abstract class LogicalPlan def resolve( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, output, resolver) + outputAttributes.resolve(nameParts, resolver) /** * Given an attribute name, split it to name parts by dot, but @@ -114,105 +118,7 @@ abstract class LogicalPlan def resolveQuoted( name: String, resolver: Resolver): Option[NamedExpression] = { - resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver) - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * This assumes `name` has multiple parts, where the 1st part is a qualifier - * (i.e. table name, alias, or subquery alias). - * See the comment above `candidates` variable in resolve() for semantics the returned data. - */ - private def resolveAsTableColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - assert(nameParts.length > 1) - if (attribute.qualifier.exists(resolver(_, nameParts.head))) { - // At least one qualifier matches. See if remaining parts match. - val remainingParts = nameParts.tail - resolveAsColumn(remainingParts, resolver, attribute) - } else { - None - } - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * Different from resolveAsTableColumn, this assumes `name` does NOT start with a qualifier. - * See the comment above `candidates` variable in resolve() for semantics the returned data. - */ - private def resolveAsColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - if (resolver(attribute.name, nameParts.head)) { - Option((attribute.withName(nameParts.head), nameParts.tail.toList)) - } else { - None - } - } - - /** Performs attribute resolution given a name and a sequence of possible attributes. */ - protected def resolve( - nameParts: Seq[String], - input: Seq[Attribute], - resolver: Resolver): Option[NamedExpression] = { - - // A sequence of possible candidate matches. - // Each candidate is a tuple. The first element is a resolved attribute, followed by a list - // of parts that are to be resolved. - // For example, consider an example where "a" is the table name, "b" is the column name, - // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", - // and the second element will be List("c"). - var candidates: Seq[(Attribute, List[String])] = { - // If the name has 2 or more parts, try to resolve it as `table.column` first. - if (nameParts.length > 1) { - input.flatMap { option => - resolveAsTableColumn(nameParts, resolver, option) - } - } else { - Seq.empty - } - } - - // If none of attributes match `table.column` pattern, we try to resolve it as a column. - if (candidates.isEmpty) { - candidates = input.flatMap { candidate => - resolveAsColumn(nameParts, resolver, candidate) - } - } - - def name = UnresolvedAttribute(nameParts).name - - candidates.distinct match { - // One match, no nested fields, use it. - case Seq((a, Nil)) => Some(a) - - // One match, but we also need to extract the requested nested field. - case Seq((a, nestedFields)) => - // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and aliased it with the last part of the name. - // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final - // expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => - ExtractValue(expr, Literal(fieldName), resolver)) - Some(Alias(fieldExprs, nestedFields.last)()) - - // No matches. - case Seq() => - logTrace(s"Could not find $name in ${input.mkString(", ")}") - None - - // More than one match. - case ambiguousReferences => - val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ") - throw new AnalysisException( - s"Reference '$name' is ambiguous, could be: $referenceNames.") - } + outputAttributes.resolve(UnresolvedAttribute.parseAttributeName(name), resolver) } /** From d83e9637246b05eea202add07a168688f6c0481b Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 7 May 2018 17:54:39 +0200 Subject: [PATCH 749/774] [SPARK-24043][SQL] Interpreted Predicate should initialize nondeterministic expressions ## What changes were proposed in this pull request? When creating an InterpretedPredicate instance, initialize any Nondeterministic expressions in the expression tree to avoid java.lang.IllegalArgumentException on later call to eval(). ## How was this patch tested? - sbt SQL tests - python SQL tests - new unit test Author: Bruce Robbins Closes #21144 from bersprockets/interpretedpredicate. --- .../spark/sql/catalyst/expressions/predicates.scala | 8 ++++++++ .../spark/sql/catalyst/expressions/PredicateSuite.scala | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e195ec17f3bcf..f8c6dc4e6adc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -36,6 +36,14 @@ object InterpretedPredicate { case class InterpretedPredicate(expression: Expression) extends BasePredicate { override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean] + + override def initialize(partitionIndex: Int): Unit = { + super.initialize(partitionIndex) + expression.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + } + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 1bfd180ae4393..ac76b17ef4761 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -449,4 +449,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(EqualNullSafe(Literal(null, DoubleType), Literal(-1.0d)), false) checkEvaluation(EqualNullSafe(Literal(-1.0d), Literal(null, DoubleType)), false) } + + test("Interpreted Predicate should initialize nondeterministic expressions") { + val interpreted = InterpretedPredicate.create(LessThan(Rand(7), Literal(1.0))) + interpreted.initialize(0) + assert(interpreted.eval(new UnsafeRow())) + } } From 56a52e0a58fc82ea69e47d0d8c4f905565be7c8b Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Mon, 7 May 2018 14:47:58 -0700 Subject: [PATCH 750/774] [SPARK-15750][MLLIB][PYSPARK] Constructing FPGrowth fails when no numPartitions specified in pyspark ## What changes were proposed in this pull request? Change FPGrowth from private to private[spark]. If no numPartitions is specified, then default value -1 is used. But -1 is only valid in the construction function of FPGrowth, but not in setNumPartitions. So I make this change and use the constructor directly rather than using set method. ## How was this patch tested? Unit test is added Author: Jeff Zhang Closes #13493 from zjffdu/SPARK-15750. --- .../spark/mllib/api/python/PythonMLLibAPI.scala | 5 +---- .../scala/org/apache/spark/mllib/fpm/FPGrowth.scala | 2 +- python/pyspark/mllib/tests.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index b32d3f252ae59..db3f074ecfbac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -572,10 +572,7 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[java.lang.Iterable[Any]], minSupport: Double, numPartitions: Int): FPGrowthModel[Any] = { - val fpg = new FPGrowth() - .setMinSupport(minSupport) - .setNumPartitions(numPartitions) - + val fpg = new FPGrowth(minSupport, numPartitions) val model = fpg.run(data.rdd.map(_.asScala.toArray)) new FPGrowthModelWrapper(model) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index f6b1143272d16..4f2b7e6f0764e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -162,7 +162,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { * */ @Since("1.3.0") -class FPGrowth private ( +class FPGrowth private[spark] ( private var minSupport: Double, private var numPartitions: Int) extends Logging with Serializable { diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 14d788b0bef60..4c2ce137e331c 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -57,6 +57,7 @@ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.linalg.distributed import RowMatrix from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.fpm import FPGrowth from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs @@ -1762,6 +1763,17 @@ def test_pca(self): self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) +class FPGrowthTest(MLlibTestCase): + + def test_fpgrowth(self): + data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + rdd = self.sc.parallelize(data, 2) + model1 = FPGrowth.train(rdd, 0.6, 2) + # use default data partition number when numPartitions is not specified + model2 = FPGrowth.train(rdd, 0.6) + self.assertEqual(sorted(model1.freqItemsets().collect()), + sorted(model2.freqItemsets().collect())) + if __name__ == "__main__": from pyspark.mllib.tests import * if not _have_scipy: From 1c9c5de951ed86290bcd7d8edaab952b8cacd290 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 7 May 2018 14:52:14 -0700 Subject: [PATCH 751/774] [SPARK-23291][SPARK-23291][R][FOLLOWUP] Update SparkR migration note for ## What changes were proposed in this pull request? This PR fixes the migration note for SPARK-23291 since it's going to backport to 2.3.1. See the discussion in https://issues.apache.org/jira/browse/SPARK-23291 ## How was this patch tested? N/A Author: hyukjinkwon Closes #21249 from HyukjinKwon/SPARK-23291. --- docs/sparkr.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sparkr.md b/docs/sparkr.md index 7fabab5d38f16..4faad2c4c1824 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -664,6 +664,6 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma - For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`. - A warning can be raised if versions of SparkR package and the Spark JVM do not match. -## Upgrading to Spark 2.4.0 +## Upgrading to SparkR 2.3.1 and above - - The `start` parameter of `substr` method was wrongly subtracted by one, previously. In other words, the index specified by `start` parameter was considered as 0-base. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. It has been fixed so the `start` parameter of `substr` method is now 1-base, e.g., therefore to get the same result as `substr(df$a, 2, 5)`, it should be changed to `substr(df$a, 1, 4)`. + - In SparkR 2.3.0 and earlier, the `start` parameter of `substr` method was wrongly subtracted by one and considered as 0-based. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of `substr` method is now 1-base. As an example, `substr(lit('abcdef'), 2, 4))` would result to `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1. From f48bd6bdc5aefd9ec43e2d0ee648d17add7ef554 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 7 May 2018 14:55:41 -0700 Subject: [PATCH 752/774] [SPARK-22885][ML][TEST] ML test for StructuredStreaming: spark.ml.tuning ## What changes were proposed in this pull request? ML test for StructuredStreaming: spark.ml.tuning ## How was this patch tested? N/A Author: WeichenXu Closes #20261 from WeichenXu123/ml_stream_tuning_test. --- .../spark/ml/tuning/CrossValidatorSuite.scala | 15 +++++++++++---- .../ml/tuning/TrainValidationSplitSuite.scala | 15 +++++++++++---- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 15dade2627090..e6ee7220d2279 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -25,17 +25,17 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator} import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} -import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class CrossValidatorSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -66,6 +66,13 @@ class CrossValidatorSuite assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) assert(cvModel.avgMetrics.length === lrParamMaps.length) + + val result = cvModel.transform(dataset).select("prediction").as[Double].collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), cvModel, "prediction") { + rows => + val result2 = rows.map(_.getDouble(0)) + assert(result === result2) + } } test("cross validation with linear regression") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 9024342d9c831..cd76acf9c67bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -24,17 +24,17 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} -import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class TrainValidationSplitSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -64,6 +64,13 @@ class TrainValidationSplitSuite assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) assert(tvsModel.validationMetrics.length === lrParamMaps.length) + + val result = tvsModel.transform(dataset).select("prediction").as[Double].collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), tvsModel, "prediction") { + rows => + val result2 = rows.map(_.getDouble(0)) + assert(result === result2) + } } test("train validation with linear regression") { From 76ecd095024a658bf68e5db658e4416565b30c17 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 7 May 2018 14:57:14 -0700 Subject: [PATCH 753/774] [SPARK-20114][ML] spark.ml parity for sequential pattern mining - PrefixSpan ## What changes were proposed in this pull request? PrefixSpan API for spark.ml. New implementation instead of #20810 ## How was this patch tested? TestSuite added. Author: WeichenXu Closes #20973 from WeichenXu123/prefixSpan2. --- .../org/apache/spark/ml/fpm/PrefixSpan.scala | 96 +++++++++++++ .../apache/spark/mllib/fpm/PrefixSpan.scala | 3 +- .../apache/spark/ml/fpm/PrefixSpanSuite.scala | 136 ++++++++++++++++++ 3 files changed, 233 insertions(+), 2 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala new file mode 100644 index 0000000000000..02168fee16dbf --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.fpm + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} + +/** + * :: Experimental :: + * A parallel PrefixSpan algorithm to mine frequent sequential patterns. + * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + * Efficiently by Prefix-Projected Pattern Growth + * (see here). + * + * @see Sequential Pattern Mining + * (Wikipedia) + */ +@Since("2.4.0") +@Experimental +object PrefixSpan { + + /** + * :: Experimental :: + * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + * + * @param dataset A dataset or a dataframe containing a sequence column which is + * {{{Seq[Seq[_]]}}} type + * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column + * are ignored + * @param minSupport the minimal support level of the sequential pattern, any pattern that + * appears more than (minSupport * size-of-the-dataset) times will be output + * (recommended value: `0.1`). + * @param maxPatternLength the maximal length of the sequential pattern + * (recommended value: `10`). + * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the + * internal storage format) allowed in a projected database before + * local processing. If a projected database exceeds this size, another + * iteration of distributed prefix growth is run + * (recommended value: `32000000`). + * @return A `DataFrame` that contains columns of sequence and corresponding frequency. + * The schema of it will be: + * - `sequence: Seq[Seq[T]]` (T is the item type) + * - `freq: Long` + */ + @Since("2.4.0") + def findFrequentSequentialPatterns( + dataset: Dataset[_], + sequenceCol: String, + minSupport: Double, + maxPatternLength: Int, + maxLocalProjDBSize: Long): DataFrame = { + + val inputType = dataset.schema(sequenceCol).dataType + require(inputType.isInstanceOf[ArrayType] && + inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType], + s"The input column must be ArrayType and the array element type must also be ArrayType, " + + s"but got $inputType.") + + + val data = dataset.select(sequenceCol) + val sequences = data.where(col(sequenceCol).isNotNull).rdd + .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray) + + val mllibPrefixSpan = new mllibPrefixSpan() + .setMinSupport(minSupport) + .setMaxPatternLength(maxPatternLength) + .setMaxLocalProjDBSize(maxLocalProjDBSize) + + val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq)) + val schema = StructType(Seq( + StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false), + StructField("freq", LongType, nullable = false))) + val freqSequences = dataset.sparkSession.createDataFrame(rows, schema) + + freqSequences + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 3f8d65a378e2c..7aed2f3bd8a61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -49,8 +49,7 @@ import org.apache.spark.storage.StorageLevel * * @param minSupport the minimal support level of the sequential pattern, any pattern that appears * more than (minSupport * size-of-the-dataset) times will be output - * @param maxPatternLength the maximal length of the sequential pattern, any pattern that appears - * less than maxPatternLength will be output + * @param maxPatternLength the maximal length of the sequential pattern * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal * storage format) allowed in a projected database before local * processing. If a projected database exceeds this size, another diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala new file mode 100644 index 0000000000000..9e538696cbcf7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.fpm + +import org.apache.spark.ml.util.MLTest +import org.apache.spark.sql.DataFrame + +class PrefixSpanSuite extends MLTest { + + import testImplicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + } + + test("PrefixSpan projections with multiple partial starts") { + val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence", + minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + val expected = Array( + (Seq(Seq(1)), 1L), + (Seq(Seq(1, 2)), 1L), + (Seq(Seq(1), Seq(1)), 1L), + (Seq(Seq(1), Seq(2)), 1L), + (Seq(Seq(1), Seq(3)), 1L), + (Seq(Seq(1, 3)), 1L), + (Seq(Seq(2)), 1L), + (Seq(Seq(2, 3)), 1L), + (Seq(Seq(2), Seq(1)), 1L), + (Seq(Seq(2), Seq(2)), 1L), + (Seq(Seq(2), Seq(3)), 1L), + (Seq(Seq(3)), 1L)) + compareResults[Int](expected, result) + } + + /* + To verify expected results for `smallTestData`, create file "prefixSpanSeqs2" with content + (format = (transactionID, idxInTransaction, numItemsinItemset, itemset)): + 1 1 2 1 2 + 1 2 1 3 + 2 1 1 1 + 2 2 2 3 2 + 2 3 2 1 2 + 3 1 2 1 2 + 3 2 1 5 + 4 1 1 6 + In R, run: + library("arulesSequences") + prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE")) + freqItemSeq = cspade(prefixSpanSeqs, + parameter = 0.5, maxlen = 5 )) + resSeq = as(freqItemSeq, "data.frame") + resSeq + + sequence support + 1 <{1}> 0.75 + 2 <{2}> 0.75 + 3 <{3}> 0.50 + 4 <{1},{3}> 0.50 + 5 <{1,2}> 0.75 + */ + val smallTestData = Seq( + Seq(Seq(1, 2), Seq(3)), + Seq(Seq(1), Seq(3, 2), Seq(1, 2)), + Seq(Seq(1, 2), Seq(5)), + Seq(Seq(6))) + + val smallTestDataExpectedResult = Array( + (Seq(Seq(1)), 3L), + (Seq(Seq(2)), 3L), + (Seq(Seq(3)), 2L), + (Seq(Seq(1), Seq(3)), 2L), + (Seq(Seq(1, 2)), 3L) + ) + + test("PrefixSpan Integer type, variable-size itemsets") { + val df = smallTestData.toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + + compareResults[Int](smallTestDataExpectedResult, result) + } + + test("PrefixSpan input row with nulls") { + val df = (smallTestData :+ null).toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + + compareResults[Int](smallTestDataExpectedResult, result) + } + + test("PrefixSpan String type, variable-size itemsets") { + val intToString = (1 to 6).zip(Seq("a", "b", "c", "d", "e", "f")).toMap + val df = smallTestData + .map(seq => seq.map(itemSet => itemSet.map(intToString))) + .toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[String]], Long)].collect() + + val expected = smallTestDataExpectedResult.map { case (seq, freq) => + (seq.map(itemSet => itemSet.map(intToString)), freq) + } + compareResults[String](expected, result) + } + + private def compareResults[Item]( + expectedValue: Array[(Seq[Seq[Item]], Long)], + actualValue: Array[(Seq[Seq[Item]], Long)]): Unit = { + val expectedSet = expectedValue.map { x => + (x._1.map(_.toSet), x._2) + }.toSet + val actualSet = actualValue.map { x => + (x._1.map(_.toSet), x._2) + }.toSet + assert(expectedSet === actualSet) + } +} + From 0d63eb8888d17df747fb41d7ba254718bb7af3ae Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 7 May 2018 20:08:41 -0700 Subject: [PATCH 754/774] [SPARK-23975][ML] Add support of array input for all clustering methods ## What changes were proposed in this pull request? Add support for all of the clustering methods ## How was this patch tested? unit tests added Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21195 from ludatabricks/SPARK-23975-1. --- .../spark/ml/clustering/BisectingKMeans.scala | 21 ++++---- .../spark/ml/clustering/GaussianMixture.scala | 12 +++-- .../apache/spark/ml/clustering/KMeans.scala | 31 +++--------- .../org/apache/spark/ml/clustering/LDA.scala | 9 ++-- .../apache/spark/ml/util/DatasetUtils.scala | 13 ++++- .../apache/spark/ml/util/SchemaUtils.scala | 16 ++++++- .../ml/clustering/BisectingKMeansSuite.scala | 21 +++++++- .../ml/clustering/GaussianMixtureSuite.scala | 21 +++++++- .../spark/ml/clustering/KMeansSuite.scala | 48 ++++++------------- .../apache/spark/ml/clustering/LDASuite.scala | 20 +++++++- .../apache/spark/ml/util/MLTestingUtils.scala | 23 ++++++++- 11 files changed, 147 insertions(+), 88 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index addc12ac52ec1..438e53ba6197c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -22,17 +22,15 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} -import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -75,7 +73,7 @@ private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -113,7 +111,8 @@ class BisectingKMeansModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val predictUDF = udf((vector: Vector) => predict(vector)) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + dataset.withColumn($(predictionCol), + predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("2.0.0") @@ -132,9 +131,9 @@ class BisectingKMeansModel private[ml] ( */ @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } - parentModel.computeCost(data.map(OldVectors.fromML)) + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) + parentModel.computeCost(data) } @Since("2.0.0") @@ -260,9 +259,7 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): BisectingKMeansModel = { transformSchema(dataset.schema, logging = true) - val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) val instr = Instrumentation.create(this, rdd) instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index b5804900c0358..88d618c3a03a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatr Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -63,7 +63,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT) } @@ -109,8 +109,9 @@ class GaussianMixtureModel private[ml] ( transformSchema(dataset.schema, logging = true) val predUDF = udf((vector: Vector) => predict(vector)) val probUDF = udf((vector: Vector) => predictProbability(vector)) - dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) - .withColumn($(probabilityCol), probUDF(col($(featuresCol)))) + dataset + .withColumn($(predictionCol), predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) + .withColumn($(probabilityCol), probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("2.0.0") @@ -340,7 +341,8 @@ class GaussianMixture @Since("2.0.0") ( val sc = dataset.sparkSession.sparkContext val numClusters = $(k) - val instances: RDD[Vector] = dataset.select(col($(featuresCol))).rdd.map { + val instances: RDD[Vector] = dataset + .select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map { case Row(features: Vector) => features }.cache() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index de61c9c089a36..97f246fbfd859 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, PipelineStage} -import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ @@ -34,7 +34,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils.majorVersion @@ -86,24 +86,13 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitSteps: Int = $(initSteps) - /** - * Validates the input schema. - * @param schema input schema - */ - private[clustering] def validateSchema(schema: StructType): Unit = { - val typeCandidates = List( new VectorUDT, - new ArrayType(DoubleType, false), - new ArrayType(FloatType, false)) - - SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates) - } /** * Validates and transforms the input schema. * @param schema input schema * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - validateSchema(schema) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -160,12 +149,8 @@ class KMeansModel private[ml] ( // TODO: Replace the temp fix when we have proper evaluators defined for clustering. @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - validateSchema(dataset.schema) - - val data: RDD[OldVector] = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol)) - .rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) parentModel.computeCost(data) } @@ -351,11 +336,7 @@ class KMeans @Since("1.5.0") ( transformSchema(dataset.schema, logging = true) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val instances: RDD[OldVector] = dataset.select( - DatasetUtils.columnToVector(dataset, getFeaturesCol)) - .rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + val instances = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) if (handlePersistence) { instances.persist(StorageLevel.MEMORY_AND_DISK) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 47077230fac0a..afe599cd167cb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -43,7 +43,7 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, StructType} import org.apache.spark.util.PeriodicCheckpointer import org.apache.spark.util.VersionUtils @@ -345,7 +345,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM s" must be >= 1. Found value: $getTopicConcentration") } } - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) } @@ -461,7 +461,8 @@ abstract class LDAModel private[ml] ( val transformer = oldLocalModel.getTopicDistributionMethod val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } - dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() + dataset.withColumn($(topicDistributionCol), + t(DatasetUtils.columnToVector(dataset, getFeaturesCol))).toDF() } else { logWarning("LDAModel.transform was called without any output columns. Set an output column" + " such as topicDistributionCol to produce results.") @@ -938,7 +939,7 @@ object LDA extends MLReadable[LDA] { featuresCol: String): RDD[(Long, OldVector)] = { dataset .withColumn("docId", monotonically_increasing_id()) - .select("docId", featuresCol) + .select(col("docId"), DatasetUtils.columnToVector(dataset, featuresCol)) .rdd .map { case Row(docId: Long, features: Vector) => (docId, OldVectors.fromML(features)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala index 52619cb65489a..6af4b3ebc2cc2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala @@ -17,8 +17,10 @@ package org.apache.spark.ml.util -import org.apache.spark.ml.linalg.{Vectors, VectorUDT} -import org.apache.spark.sql.{Column, Dataset} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Column, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType} @@ -60,4 +62,11 @@ private[spark] object DatasetUtils { throw new IllegalArgumentException(s"$other column cannot be cast to Vector") } } + + def columnToOldVector(dataset: Dataset[_], colName: String): RDD[OldVector] = { + dataset.select(columnToVector(dataset, colName)) + .rdd.map { + case Row(point: Vector) => OldVectors.fromML(point) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 334410c9620de..d9a3f85ef9a24 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,7 +17,8 @@ package org.apache.spark.ml.util -import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType} +import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.sql.types._ /** @@ -101,4 +102,17 @@ private[spark] object SchemaUtils { require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.") StructType(schema.fields :+ col) } + + /** + * Check whether the given column in the schema is one of the supporting vector type: Vector, + * Array[Float]. Array[Double] + * @param schema input schema + * @param colName column name + */ + def validateVectorCompatibleColumn(schema: StructType, colName: String): Unit = { + val typeCandidates = List( new VectorUDT, + new ArrayType(DoubleType, false), + new ArrayType(FloatType, false)) + checkColumnTypes(schema, colName, typeCandidates) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 02880f96ae6d9..f3ff2afcad2cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -17,13 +17,16 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.clustering.DistanceMeasure import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{DataFrame, Dataset} class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -182,6 +185,22 @@ class BisectingKMeansSuite model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) } + + test("BisectingKMeans with Array input") { + def trainAndComputeCost(dataset: Dataset[_]): Double = { + val model = new BisectingKMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.computeCost(dataset) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueCost ~== doubleArrayCost absTol 1e-6) + assert(trueCost ~== floatArrayCost absTol 1e-6) + } } object BisectingKMeansSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 08b800b7e4183..d0d461a42711a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.ml.param.ParamMap @@ -24,8 +26,7 @@ import org.apache.spark.ml.stat.distribution.MultivariateGaussian import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} - +import org.apache.spark.sql.{DataFrame, Dataset, Row} class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -256,6 +257,22 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext val expectedMatrix = GaussianMixture.unpackUpperTriangularMatrix(4, triangularValues) assert(symmetricMatrix === expectedMatrix) } + + test("GaussianMixture with Array input") { + def trainAndComputlogLikelihood(dataset: Dataset[_]): Double = { + val model = new GaussianMixture().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.summary.logLikelihood + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueLikelihood = trainAndComputlogLikelihood(newDataset) + val doubleLikelihood = trainAndComputlogLikelihood(newDatasetD) + val floatLikelihood = trainAndComputlogLikelihood(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueLikelihood ~== doubleLikelihood absTol 1e-6) + assert(trueLikelihood ~== floatLikelihood absTol 1e-6) + } } object GaussianMixtureSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 5445ebe5c95eb..680a7c2034083 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.clustering +import scala.language.existentials import scala.util.Random import org.dmg.pmml.{ClusteringModel, PMML} @@ -25,13 +26,11 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util._ -import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, - KMeansModel => MLlibKMeansModel} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} private[clustering] case class TestRow(features: Vector) @@ -202,38 +201,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } test("KMean with Array input") { - val featuresColNameD = "array_double_features" - val featuresColNameF = "array_float_features" - - val doubleUDF = udf { (features: Vector) => - val featureArray = Array.fill[Double](features.size)(0.0) - features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) - featureArray - } - val floatUDF = udf { (features: Vector) => - val featureArray = Array.fill[Float](features.size)(0.0f) - features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) - featureArray + def trainAndComputeCost(dataset: Dataset[_]): Double = { + val model = new KMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.computeCost(dataset) } - val newdatasetD = dataset.withColumn(featuresColNameD, doubleUDF(col("features"))) - .drop("features") - val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features"))) - .drop("features") - assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false))) - assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false))) - - val kmeansD = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameD).setSeed(1) - val kmeansF = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameF).setSeed(1) - val modelD = kmeansD.fit(newdatasetD) - val modelF = kmeansF.fit(newdatasetF) - val transformedD = modelD.transform(newdatasetD) - val transformedF = modelF.transform(newdatasetF) - - val predictDifference = transformedD.select("prediction") - .except(transformedF.select("prediction")) - assert(predictDifference.count() == 0) - assert(modelD.computeCost(newdatasetD) == modelF.computeCost(newdatasetF) ) + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueCost ~== doubleArrayCost absTol 1e-6) + assert(trueCost ~== floatArrayCost absTol 1e-6) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index e73bbc18d76bd..8d728f063dd8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.hadoop.fs.Path import org.apache.spark.SparkFunSuite @@ -26,7 +28,6 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ - object LDASuite { def generateLDAData( spark: SparkSession, @@ -323,4 +324,21 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.getOptimizer === optimizer) } } + + test("LDA with Array input") { + def trainAndLogLikelihoodAndPerplexity(dataset: Dataset[_]): (Double, Double) = { + val model = new LDA().setK(k).setOptimizer("online").setMaxIter(1).setSeed(1).fit(dataset) + (model.logLikelihood(dataset), model.logPerplexity(dataset)) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val (ll, lp) = trainAndLogLikelihoodAndPerplexity(newDataset) + val (llD, lpD) = trainAndLogLikelihoodAndPerplexity(newDatasetD) + val (llF, lpF) = trainAndLogLikelihoodAndPerplexity(newDatasetF) + // TODO: need to compare the results once we fix the seed issue for LDA (SPARK-22210) + assert(llD <= 0.0 && llD != Double.NegativeInfinity) + assert(llF <= 0.0 && llF != Double.NegativeInfinity) + assert(lpD >= 0.0 && lpD != Double.NegativeInfinity) + assert(lpF >= 0.0 && lpF != Double.NegativeInfinity) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index c328d81b4bc3a..5e72b4d864c1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol} import org.apache.spark.ml.recommendation.{ALS, ALSModel} @@ -247,4 +247,25 @@ object MLTestingUtils extends SparkFunSuite { } models.sliding(2).foreach { case Seq(m1, m2) => modelEquals(m1, m2)} } + + /** + * Helper function for testing different input types for "features" column. Given a DataFrame, + * generate three output DataFrames: one having vector "features" column with float precision, + * one having double array "features" column with float precision, and one having float array + * "features" column. + */ + def generateArrayFeatureDataset(dataset: Dataset[_], + featuresColName: String = "features"): (Dataset[_], Dataset[_], Dataset[_]) = { + val toFloatVectorUDF = udf { (features: Vector) => + Vectors.dense(features.toArray.map(_.toFloat.toDouble))} + val toDoubleArrayUDF = udf { (features: Vector) => features.toArray} + val toFloatArrayUDF = udf { (features: Vector) => features.toArray.map(_.toFloat)} + val newDataset = dataset.withColumn(featuresColName, toFloatVectorUDF(col(featuresColName))) + val newDatasetD = newDataset.withColumn(featuresColName, toDoubleArrayUDF(col(featuresColName))) + val newDatasetF = newDataset.withColumn(featuresColName, toFloatArrayUDF(col(featuresColName))) + assert(newDataset.schema(featuresColName).dataType.equals(new VectorUDT)) + assert(newDatasetD.schema(featuresColName).dataType.equals(new ArrayType(DoubleType, false))) + assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false))) + (newDataset, newDatasetD, newDatasetF) + } } From cd12c5c3ecf28f7b04f566c2057f9b65eb456b7d Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Tue, 8 May 2018 12:21:33 +0800 Subject: [PATCH 755/774] [SPARK-24128][SQL] Mention configuration option in implicit CROSS JOIN error ## What changes were proposed in this pull request? Mention `spark.sql.crossJoin.enabled` in error message when an implicit `CROSS JOIN` is detected. ## How was this patch tested? `CartesianProductSuite` and `JoinSuite`. Author: Henry Robinson Closes #21201 from henryr/spark-24128. --- R/pkg/tests/fulltests/test_sparkSQL.R | 4 ++-- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 6 ++++-- .../catalyst/optimizer/CheckCartesianProductsSuite.scala | 2 +- .../src/test/scala/org/apache/spark/sql/JoinSuite.scala | 4 ++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 3a8866bf2a88a..43725e0ebd3bf 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2210,8 +2210,8 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { expect_equal(count(where(join(df, df2), df$name == df2$name)), 3) # cartesian join expect_error(tryCatch(count(join(df, df2)), error = function(e) { stop(e) }), - paste0(".*(org.apache.spark.sql.AnalysisException: Detected cartesian product for", - " INNER join between logical plans).*")) + paste0(".*(org.apache.spark.sql.AnalysisException: Detected implicit cartesian", + " product for INNER join between logical plans).*")) joined <- crossJoin(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 45f13956a0a85..bfa61116a6658 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1182,12 +1182,14 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _) if isCartesianProduct(j) => throw new AnalysisException( - s"""Detected cartesian product for ${j.joinType.sql} join between logical plans + s"""Detected implicit cartesian product for ${j.joinType.sql} join between logical plans |${left.treeString(false).trim} |and |${right.treeString(false).trim} |Join condition is missing or trivial. - |Use the CROSS JOIN syntax to allow cartesian products between these relations.""" + |Either: use the CROSS JOIN syntax to allow cartesian products between these + |relations, or: enable implicit cartesian products by setting the configuration + |variable spark.sql.crossJoin.enabled=true""" .stripMargin) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala index 21220b38968e8..788fedb3c8e8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala @@ -56,7 +56,7 @@ class CheckCartesianProductsSuite extends PlanTest { val thrownException = the [AnalysisException] thrownBy { performCartesianProductCheck(joinType) } - assert(thrownException.message.contains("Detected cartesian product")) + assert(thrownException.message.contains("Detected implicit cartesian product")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 771e1186e63ab..8fa747465cb1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -239,7 +239,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(2, 2, 1, null) :: Row(2, 2, 2, 2) :: Nil) } - assert(e.getMessage.contains("Detected cartesian product for INNER join " + + assert(e.getMessage.contains("Detected implicit cartesian product for INNER join " + "between logical plans")) } } @@ -611,7 +611,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val e = intercept[Exception] { checkAnswer(sql(query), Nil); } - assert(e.getMessage.contains("Detected cartesian product")) + assert(e.getMessage.contains("Detected implicit cartesian product")) } cartesianQueries.foreach(checkCartesianDetection) From 05eb19b6e09065265358eec2db2ff3b42806dfc9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 8 May 2018 14:32:04 +0800 Subject: [PATCH 756/774] [SPARK-24188][CORE] Restore "/version" API endpoint. It was missing the jax-rs annotation. Author: Marcelo Vanzin Closes #21245 from vanzin/SPARK-24188. Change-Id: Ib338e34b363d7c729cc92202df020dc51033b719 --- .../org/apache/spark/status/api/v1/ApiRootResource.scala | 1 + .../org/apache/spark/deploy/history/HistoryServerSuite.scala | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 7127397f6205c..d121068718b8a 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -49,6 +49,7 @@ private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications/{appId}") def application(): Class[OneApplicationResource] = classOf[OneApplicationResource] + @GET @Path("version") def version(): VersionInfo = new VersionInfo(org.apache.spark.SPARK_VERSION) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 87f12f303cd5e..a871b1c717837 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -296,6 +296,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (siteRelativeLinks) should startWith (uiRoot) } + test("/version api endpoint") { + val response = getUrl("version") + assert(response.contains(SPARK_VERSION)) + } + test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = "/testwebproxybase" System.setProperty("spark.ui.proxyBase", uiRoot) From e17567ca78dbb416039c17da212957c8955bfa65 Mon Sep 17 00:00:00 2001 From: yucai Date: Tue, 8 May 2018 11:34:27 +0200 Subject: [PATCH 757/774] [SPARK-24076][SQL] Use different seed in HashAggregate to avoid hash conflict ## What changes were proposed in this pull request? HashAggregate uses the same hash algorithm and seed as ShuffleExchange, it may lead to bad hash conflict when shuffle.partitions=8192*n. Considering below example: ``` SET spark.sql.shuffle.partitions=8192; INSERT OVERWRITE TABLE target_xxx SELECT item_id, auct_end_dt FROM from source_xxx GROUP BY item_id, auct_end_dt; ``` In the shuffle stage, if user sets the shuffle.partition = 8192, all tuples in the same partition will meet the following relationship: ``` hash(tuple x) = hash(tuple y) + n * 8192 ``` Then in the next HashAggregate stage, all tuples from the same partition need be put into a 16K BytesToBytesMap (unsafeRowAggBuffer). Here, the HashAggregate uses the same hash algorithm on the same expression as shuffle, and uses the same seed, and 16K = 8192 * 2, so actually, all tuples in the same parititon will only be hashed to 2 different places in the BytesToBytesMap. It is bad hash conflict. With BytesToBytesMap growing, the conflict will always exist. Before change: hash_conflict After change: no_hash_conflict ## How was this patch tested? Unit tests and production cases. Author: yucai Closes #21149 from yucai/SPARK-24076. --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index a5dc6ebf2b0f2..6a8ec4f722aea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -755,7 +755,10 @@ case class HashAggregateExec( } // generate hash code for key - val hashExpr = Murmur3Hash(groupingExpressions, 42) + // SPARK-24076: HashAggregate uses the same hash algorithm on the same expressions + // as ShuffleExchange, it may lead to bad hash conflict when shuffle.partitions=8192*n, + // pick a different seed to avoid this conflict + val hashExpr = Murmur3Hash(groupingExpressions, 48) val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx) val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, From b54bbe57b33b00063596cd9588fa2461745ed571 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 May 2018 21:22:54 +0800 Subject: [PATCH 758/774] [SPARK-24131][PYSPARK][FOLLOWUP] Add majorMinorVersion API to PySpark for determining Spark versions ## What changes were proposed in this pull request? More close to Scala API behavior when can't parse input by throwing exception. Add tests. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #21211 from viirya/SPARK-24131-followup. --- python/pyspark/tests.py | 4 ++++ python/pyspark/util.py | 37 ++++++++++++++++++++++--------------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7b8ce2c6b799f..498d6b57e4353 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2312,6 +2312,10 @@ def test_py4j_exception_message(self): self.assertTrue('NullPointerException' in _exception_message(context.exception)) + def test_parsing_version_string(self): + from pyspark.util import VersionUtils + self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 04df835bf6717..59cc2a6329350 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -62,24 +62,31 @@ def _get_argspec(f): return argspec -def majorMinorVersion(version): +class VersionUtils(object): """ - Get major and minor version numbers for given Spark version string. - - >>> version = "2.4.0" - >>> majorMinorVersion(version) - (2, 4) + Provides utility method to determine Spark versions with given input string. + """ + @staticmethod + def majorMinorVersion(sparkVersion): + """ + Given a Spark version string, return the (major version number, minor version number). + E.g., for 2.0.1-SNAPSHOT, return (2, 0). - >>> version = "abc" - >>> majorMinorVersion(version) is None - True + >>> sparkVersion = "2.4.0" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 4) + >>> sparkVersion = "2.3.0-SNAPSHOT" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 3) - """ - m = re.search('^(\d+)\.(\d+)(\..*)?$', version) - if m is None: - return None - else: - return (int(m.group(1)), int(m.group(2))) + """ + m = re.search('^(\d+)\.(\d+)(\..*)?$', sparkVersion) + if m is not None: + return (int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Spark tried to parse '%s' as a Spark" % sparkVersion + + " version string, but it could not find the major and minor" + + " version numbers.") if __name__ == "__main__": From 2f6fe7d679a878ffd103cac6f06081c5b3888744 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 8 May 2018 21:24:35 +0800 Subject: [PATCH 759/774] [SPARK-23094][SPARK-23723][SPARK-23724][SQL][FOLLOW-UP] Support custom encoding for json files ## What changes were proposed in this pull request? This is to add a test case to check the behaviors when users write json in the specified UTF-16/UTF-32 encoding with multiline off. ## How was this patch tested? N/A Author: gatorsmile Closes #21254 from gatorsmile/followupSPARK-23094. --- .../spark/sql/catalyst/json/JSONOptions.scala | 9 +++++---- .../datasources/json/JsonSuite.scala | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5f130af606e19..2579374e3f4e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -110,11 +110,12 @@ private[sql] class JSONOptions( val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32")) val isBlacklisted = blacklist.contains(Charset.forName(enc)) require(multiLine || !isBlacklisted, - s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: - | ${blacklist.mkString(", ")}""".stripMargin) + s"""The $enc encoding in the blacklist is not allowed when multiLine is disabled. + |Blacklist: ${blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = + multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty - val isLineSepRequired = !(multiLine == false && - Charset.forName(enc) != StandardCharsets.UTF_8 && lineSeparator.isEmpty) require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") enc diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 0db688fec9a67..4b3921c61a000 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2313,6 +2313,25 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } + test("SPARK-23723: write json in UTF-16/32 with multiline off") { + Seq("UTF-16", "UTF-32").foreach { encoding => + withTempPath { path => + val ds = spark.createDataset(Seq( + ("a", 1), ("b", 2), ("c", 3)) + ).repartition(2) + val e = intercept[IllegalArgumentException] { + ds.write + .option("encoding", encoding) + .option("multiline", "false") + .format("json").mode("overwrite") + .save(path.getCanonicalPath) + }.getMessage + assert(e.contains( + s"$encoding encoding in the blacklist is not allowed when multiLine is disabled")) + } + } + } + def checkReadJson(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { test(s"SPARK-23724: checks reading json in ${encoding} #${id}") { val schema = new StructType().add("f1", StringType).add("f2", IntegerType) From 487faf17ab96c8edb729501dfb1ff82f7b2c6031 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 8 May 2018 23:43:02 +0800 Subject: [PATCH 760/774] [SPARK-24117][SQL] Unified the getSizePerRow ## What changes were proposed in this pull request? This pr unified the `getSizePerRow` because `getSizePerRow` is used in many places. For example: 1. [LocalRelation.scala#L80](https://github.com/wangyum/spark/blob/f70f46d1e5bc503e9071707d837df618b7696d32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala#L80) 2. [SizeInBytesOnlyStatsPlanVisitor.scala#L36](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala#L36) ## How was this patch tested? Exist tests Author: Yuming Wang Closes #21189 from wangyum/SPARK-24117. --- .../sql/catalyst/plans/logical/LocalRelation.scala | 3 ++- .../logical/statsEstimation/EstimationUtils.scala | 14 ++++++++------ .../SizeInBytesOnlyStatsPlanVisitor.scala | 4 ++-- .../spark/sql/execution/streaming/memory.scala | 10 ++++------ .../sql/execution/streaming/sources/memoryV2.scala | 3 ++- .../spark/sql/StatisticsCollectionSuite.scala | 2 +- .../sql/execution/streaming/MemorySinkSuite.scala | 4 ++-- 7 files changed, 21 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 720d42ab409a0..8c4828a4cef23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -77,7 +78,7 @@ case class LocalRelation( } override def computeStats(): Statistics = - Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) + Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * data.length) def toSQL(inlineTableName: String): String = { require(data.nonEmpty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 0f147f0ffb135..211a2a0717371 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode @@ -25,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.{DecimalType, _} - object EstimationUtils { /** Check if each plan has rowCount in its statistics. */ @@ -73,13 +71,12 @@ object EstimationUtils { AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _))) } - def getOutputSize( + def getSizePerRow( attributes: Seq[Attribute], - outputRowCount: BigInt, attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // We assign a generic overhead for a Row object, the actual overhead is different for different // Row format. - val sizePerRow = 8 + attributes.map { attr => + 8 + attributes.map { attr => if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) { attr.dataType match { case StringType => @@ -92,10 +89,15 @@ object EstimationUtils { attr.dataType.defaultSize } }.sum + } + def getOutputSize( + attributes: Seq[Attribute], + outputRowCount: BigInt, + attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // Output size can't be zero, or sizeInBytes of BinaryNode will also be zero // (simple computation of statistics returns product of children). - if (outputRowCount > 0) outputRowCount * sizePerRow else 1 + if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes, attrStats) else 1 } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 85f67c7d66075..ee43f9126386b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -33,8 +33,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { private def visitUnaryNode(p: UnaryNode): Statistics = { // There should be some overhead in Row object, the size should not be zero when there is // no columns, this help to prevent divide-by-zero error. - val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8 - val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8 + val childRowSize = EstimationUtils.getSizePerRow(p.child.output) + val outputRowSize = EstimationUtils.getSizePerRow(p.output) // Assume there will be the same number of rows as child has. var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize if (sizeInBytes == 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 22258274c70c1..6720cdd24b1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -24,23 +24,21 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.streaming.{OutputMode, Trigger} +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils - object MemoryStream { protected val currentBlockId = new AtomicInteger(0) protected val memoryStreamId = new AtomicInteger(0) @@ -307,7 +305,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode { def this(sink: MemorySink) = this(sink, sink.schema.toAttributes) - private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum + private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 0d6c239274dd8..468313bfe8c3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} @@ -182,7 +183,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode) * Used to query the data that has been written into a [[MemorySinkV2]]. */ case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { - private val sizePerRow = output.map(_.dataType.defaultSize).sum + private val sizePerRow = EstimationUtils.getSizePerRow(output) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index b91712f4cc25d..60fa951e23178 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -50,7 +50,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") - assert(sizes.head === BigInt(96), + assert(sizes.head === BigInt(128), s"expected exact size 96 for table 'test', got: ${sizes.head}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index e8420eee7fe9d..3bc36ce55d902 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -220,11 +220,11 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { sink.addBatch(0, 1 to 3) plan.invalidateStatsCache() - assert(plan.stats.sizeInBytes === 12) + assert(plan.stats.sizeInBytes === 36) sink.addBatch(1, 4 to 6) plan.invalidateStatsCache() - assert(plan.stats.sizeInBytes === 24) + assert(plan.stats.sizeInBytes === 72) } ignore("stress test") { From ec7854a8504ec08485b3536ea71483cce46f9500 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Mon, 21 May 2018 19:30:10 +0200 Subject: [PATCH 761/774] re-raising StopIteration in client code --- python/pyspark/rdd.py | 35 ++++++++++++------ python/pyspark/shuffle.py | 19 ++++++++-- python/pyspark/sql/tests.py | 13 +++++++ python/pyspark/tests.py | 71 +++++++++++++++++++++++++++++++++++++ 4 files changed, 124 insertions(+), 14 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4b44f76747264..257b435ea7e77 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -49,7 +49,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, ExternalMerger, \ - get_used_memory, ExternalSorter, ExternalGroupBy + get_used_memory, ExternalSorter, ExternalGroupBy, safe_iter from pyspark.traceback_utils import SCCallSiteSync @@ -173,6 +173,7 @@ def ignore_unicode_prefix(f): return f + class Partitioner(object): def __init__(self, numPartitions, partitionFunc): self.numPartitions = numPartitions @@ -332,7 +333,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(f, iterator) + return map(safe_iter(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -347,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(f, iterator)) + return chain.from_iterable(map(safe_iter(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -410,7 +411,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(f, iterator) + return filter(safe_iter(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -791,9 +792,11 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ + safe_f = safe_iter(f) + def processPartition(iterator): for x in iterator: - f(x) + safe_f(x) return iter([]) self.mapPartitions(processPartition).count() # Force evaluation @@ -840,13 +843,15 @@ def reduce(self, f): ... ValueError: Can not reduce() empty RDD """ + safe_f = safe_iter(f) + def func(iterator): iterator = iter(iterator) try: initial = next(iterator) except StopIteration: return - yield reduce(f, iterator, initial) + yield reduce(safe_f, iterator, initial) vals = self.mapPartitions(func).collect() if vals: @@ -911,10 +916,12 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ + safe_op = safe_iter(op) + def func(iterator): acc = zeroValue for obj in iterator: - acc = op(acc, obj) + acc = safe_op(acc, obj) yield acc # collecting result of mapPartitions here ensures that the copy of # zeroValue provided to each partition is unique from the one provided @@ -943,16 +950,19 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ + safe_seqOp = safe_iter(seqOp) + safe_combOp = safe_iter(combOp) + def func(iterator): acc = zeroValue for obj in iterator: - acc = seqOp(acc, obj) + acc = safe_seqOp(acc, obj) yield acc # collecting result of mapPartitions here ensures that the copy of # zeroValue provided to each partition is unique from the one provided # to the final reduce call vals = self.mapPartitions(func).collect() - return reduce(combOp, vals, zeroValue) + return reduce(safe_combOp, vals, zeroValue) def treeAggregate(self, zeroValue, seqOp, combOp, depth=2): """ @@ -1636,15 +1646,17 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ + safe_func = safe_iter(func) + def reducePartition(iterator): m = {} for k, v in iterator: - m[k] = func(m[k], v) if k in m else v + m[k] = safe_func(m[k], v) if k in m else v yield m def mergeMaps(m1, m2): for k, v in m2.items(): - m1[k] = func(m1[k], v) if k in m1 else v + m1[k] = safe_func(m1[k], v) if k in m1 else v return m1 return self.mapPartitions(reducePartition).reduce(mergeMaps) @@ -1846,6 +1858,7 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, >>> sorted(x.combineByKey(to_list, append, extend).collect()) [('a', [1, 2]), ('b', [1])] """ + if numPartitions is None: numPartitions = self._defaultReducePartitions() diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 02c773302e9da..7445f66714f03 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -67,6 +67,19 @@ def get_used_memory(): return 0 +def safe_iter(f): + """ wraps f to make it safe (= does not lead to data loss) to use inside a for loop + make StopIteration's raised inside f explicit + """ + def wrapper(*args, **kwargs): + try: + return f(*args, **kwargs) + except StopIteration as exc: + raise RuntimeError('StopIteration in client code', exc) + + return wrapper + + def _get_local_dirs(sub): """ Get all the directories """ path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp") @@ -94,9 +107,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = createCombiner - self.mergeValue = mergeValue - self.mergeCombiners = mergeCombiners + self.createCombiner = safe_iter(createCombiner) + self.mergeValue = safe_iter(mergeValue) + self.mergeCombiners = safe_iter(mergeCombiners) class SimpleAggregator(Aggregator): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 16aa9378ad8ee..f651f2b486ca4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -900,6 +900,19 @@ def __call__(self, x): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) + def test_stopiteration_in_udf(self): + # test for SPARK-23754 + from pyspark.sql.functions import udf + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + with self.assertRaises(Py4JJavaError) as cm: + self.spark.range(0, 1000).withColumn('v', udf(foo)).show() + self.assertIn('StopIteration in client code', + cm.exception.java_exception.toString()) + def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 498d6b57e4353..14af8e1fef4bd 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -161,6 +161,46 @@ def gen_gs(N, step=1): self.assertEqual(k, len(vs)) self.assertEqual(list(range(k)), list(vs)) + def test_stopiteration_is_raised(self): + + def validate_exception(exc): + if isinstance(exc, RuntimeError): + self.assertEquals('StopIteration in client code', exc.args[0]) + else: + self.assertIn('StopIteration in client code', exc.java_exception.toString()) + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + validate_exception(cm.exception) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + validate_exception(cm.exception) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + validate_exception(cm.exception) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -1246,6 +1286,37 @@ def test_pipe_unicode(self): result = rdd.pipe('cat').collect() self.assertEqual(data, result) + def test_stopiteration_in_client_code(self): + + def a_rdd(keyed=False): + return self.sc.parallelize( + ((x % 2, x) if keyed else x) + for x in range(10) + ) + + def stopit(*x): + raise StopIteration() + + def do_test(action, *args, **kwargs): + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + action(*args, **kwargs) + if isinstance(cm.exception, RuntimeError): + self.assertEquals('StopIteration in client code', + cm.exception.args[0]) + else: + self.assertIn('StopIteration in client code', + cm.exception.java_exception.toString()) + + do_test(a_rdd().map(stopit).collect) + do_test(a_rdd().filter(stopit).collect) + do_test(a_rdd().cartesian(a_rdd()).flatMap(stopit).collect) + do_test(a_rdd().foreach, stopit) + do_test(a_rdd(keyed=True).reduceByKeyLocally, stopit) + do_test(a_rdd().reduce, stopit) + do_test(a_rdd().fold, 0, stopit) + do_test(a_rdd().aggregate, 0, stopit, lambda *x: 1) + do_test(a_rdd().aggregate, 0, lambda *x: 1, stopit) + class ProfilerTests(PySparkTestCase): From fddd031bbe4dda108739169f0a27eacae8f33099 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Tue, 22 May 2018 16:15:49 +0200 Subject: [PATCH 762/774] moved safe_iter to util module and more descriptive name --- python/pyspark/rdd.py | 22 +++++++++++----------- python/pyspark/shuffle.py | 20 ++++---------------- python/pyspark/util.py | 13 +++++++++++++ 3 files changed, 28 insertions(+), 27 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 257b435ea7e77..184fd1b60cd97 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -49,8 +49,9 @@ from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, ExternalMerger, \ - get_used_memory, ExternalSorter, ExternalGroupBy, safe_iter + get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync +from pyspark.util import fail_on_StopIteration __all__ = ["RDD"] @@ -173,7 +174,6 @@ def ignore_unicode_prefix(f): return f - class Partitioner(object): def __init__(self, numPartitions, partitionFunc): self.numPartitions = numPartitions @@ -333,7 +333,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(safe_iter(f), iterator) + return map(fail_on_StopIteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -348,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(safe_iter(f), iterator)) + return chain.from_iterable(map(fail_on_StopIteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -411,7 +411,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(safe_iter(f), iterator) + return filter(fail_on_StopIteration(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -792,7 +792,7 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ - safe_f = safe_iter(f) + safe_f = fail_on_StopIteration(f) def processPartition(iterator): for x in iterator: @@ -843,7 +843,7 @@ def reduce(self, f): ... ValueError: Can not reduce() empty RDD """ - safe_f = safe_iter(f) + safe_f = fail_on_StopIteration(f) def func(iterator): iterator = iter(iterator) @@ -916,7 +916,7 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ - safe_op = safe_iter(op) + safe_op = fail_on_StopIteration(op) def func(iterator): acc = zeroValue @@ -950,8 +950,8 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ - safe_seqOp = safe_iter(seqOp) - safe_combOp = safe_iter(combOp) + safe_seqOp = fail_on_StopIteration(seqOp) + safe_combOp = fail_on_StopIteration(combOp) def func(iterator): acc = zeroValue @@ -1646,7 +1646,7 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ - safe_func = safe_iter(func) + safe_func = fail_on_StopIteration(func) def reducePartition(iterator): m = {} diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 7445f66714f03..250c3233c976f 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,6 +28,7 @@ import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ CompressedSerializer, AutoBatchedSerializer +from pyspark.util import fail_on_StopIteration try: @@ -67,19 +68,6 @@ def get_used_memory(): return 0 -def safe_iter(f): - """ wraps f to make it safe (= does not lead to data loss) to use inside a for loop - make StopIteration's raised inside f explicit - """ - def wrapper(*args, **kwargs): - try: - return f(*args, **kwargs) - except StopIteration as exc: - raise RuntimeError('StopIteration in client code', exc) - - return wrapper - - def _get_local_dirs(sub): """ Get all the directories """ path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp") @@ -107,9 +95,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = safe_iter(createCombiner) - self.mergeValue = safe_iter(mergeValue) - self.mergeCombiners = safe_iter(mergeCombiners) + self.createCombiner = fail_on_StopIteration(createCombiner) + self.mergeValue = fail_on_StopIteration(mergeValue) + self.mergeCombiners = fail_on_StopIteration(mergeCombiners) class SimpleAggregator(Aggregator): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 59cc2a6329350..5807fde0812f8 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -89,6 +89,19 @@ def majorMinorVersion(sparkVersion): " version numbers.") +def fail_on_StopIteration(f): + """ wraps f to make it safe (= does not lead to data loss) to use inside a for loop + make StopIteration's raised inside f explicit + """ + def wrapper(*args, **kwargs): + try: + return f(*args, **kwargs) + except StopIteration as exc: + raise RuntimeError('StopIteration in client code', exc) + + return wrapper + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() From ee54924b9d23e616d432497c77e46671ad15ef88 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Tue, 22 May 2018 16:16:11 +0200 Subject: [PATCH 763/774] removed redundancy from tests --- python/pyspark/sql/tests.py | 2 -- python/pyspark/tests.py | 15 --------------- 2 files changed, 17 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f651f2b486ca4..f66423e05e9b6 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -910,8 +910,6 @@ def foo(x): with self.assertRaises(Py4JJavaError) as cm: self.spark.range(0, 1000).withColumn('v', udf(foo)).show() - self.assertIn('StopIteration in client code', - cm.exception.java_exception.toString()) def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 14af8e1fef4bd..383fdde59aad0 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -163,12 +163,6 @@ def gen_gs(N, step=1): def test_stopiteration_is_raised(self): - def validate_exception(exc): - if isinstance(exc, RuntimeError): - self.assertEquals('StopIteration in client code', exc.args[0]) - else: - self.assertIn('StopIteration in client code', exc.java_exception.toString()) - def stopit(*args, **kwargs): raise StopIteration() @@ -187,19 +181,16 @@ def legit_merge_combiners(x, y): m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: m.mergeValues(data) - validate_exception(cm.exception) # wrong merge value m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: m.mergeValues(data) - validate_exception(cm.exception) # wrong merge combiners m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) - validate_exception(cm.exception) class SorterTests(unittest.TestCase): @@ -1300,12 +1291,6 @@ def stopit(*x): def do_test(action, *args, **kwargs): with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: action(*args, **kwargs) - if isinstance(cm.exception, RuntimeError): - self.assertEquals('StopIteration in client code', - cm.exception.args[0]) - else: - self.assertIn('StopIteration in client code', - cm.exception.java_exception.toString()) do_test(a_rdd().map(stopit).collect) do_test(a_rdd().filter(stopit).collect) From d739eea9e8ed07dad9dd9b1a795ff21e8f915694 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Thu, 24 May 2018 15:16:58 +0200 Subject: [PATCH 764/774] improved doc, error message and code style --- python/pyspark/rdd.py | 35 +++++++++++++++++------------------ python/pyspark/shuffle.py | 8 ++++---- python/pyspark/util.py | 9 +++++---- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 184fd1b60cd97..ac127ac5d61c1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -51,7 +51,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync -from pyspark.util import fail_on_StopIteration +from pyspark.util import fail_on_stopiteration __all__ = ["RDD"] @@ -333,7 +333,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(fail_on_StopIteration(f), iterator) + return map(fail_on_stopiteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -348,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(fail_on_StopIteration(f), iterator)) + return chain.from_iterable(map(fail_on_stopiteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -411,7 +411,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(fail_on_StopIteration(f), iterator) + return filter(fail_on_stopiteration(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -792,11 +792,11 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ - safe_f = fail_on_StopIteration(f) + f = fail_on_stopiteration(f) def processPartition(iterator): for x in iterator: - safe_f(x) + f(x) return iter([]) self.mapPartitions(processPartition).count() # Force evaluation @@ -843,7 +843,7 @@ def reduce(self, f): ... ValueError: Can not reduce() empty RDD """ - safe_f = fail_on_StopIteration(f) + f = fail_on_stopiteration(f) def func(iterator): iterator = iter(iterator) @@ -851,7 +851,7 @@ def func(iterator): initial = next(iterator) except StopIteration: return - yield reduce(safe_f, iterator, initial) + yield reduce(f, iterator, initial) vals = self.mapPartitions(func).collect() if vals: @@ -916,12 +916,12 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ - safe_op = fail_on_StopIteration(op) + op = fail_on_stopiteration(op) def func(iterator): acc = zeroValue for obj in iterator: - acc = safe_op(acc, obj) + acc = op(acc, obj) yield acc # collecting result of mapPartitions here ensures that the copy of # zeroValue provided to each partition is unique from the one provided @@ -950,19 +950,19 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ - safe_seqOp = fail_on_StopIteration(seqOp) - safe_combOp = fail_on_StopIteration(combOp) + seqOp = fail_on_stopiteration(seqOp) + combOp = fail_on_stopiteration(combOp) def func(iterator): acc = zeroValue for obj in iterator: - acc = safe_seqOp(acc, obj) + acc = seqOp(acc, obj) yield acc # collecting result of mapPartitions here ensures that the copy of # zeroValue provided to each partition is unique from the one provided # to the final reduce call vals = self.mapPartitions(func).collect() - return reduce(safe_combOp, vals, zeroValue) + return reduce(combOp, vals, zeroValue) def treeAggregate(self, zeroValue, seqOp, combOp, depth=2): """ @@ -1646,17 +1646,17 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ - safe_func = fail_on_StopIteration(func) + func = fail_on_stopiteration(func) def reducePartition(iterator): m = {} for k, v in iterator: - m[k] = safe_func(m[k], v) if k in m else v + m[k] = func(m[k], v) if k in m else v yield m def mergeMaps(m1, m2): for k, v in m2.items(): - m1[k] = safe_func(m1[k], v) if k in m1 else v + m1[k] = func(m1[k], v) if k in m1 else v return m1 return self.mapPartitions(reducePartition).reduce(mergeMaps) @@ -1858,7 +1858,6 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, >>> sorted(x.combineByKey(to_list, append, extend).collect()) [('a', [1, 2]), ('b', [1])] """ - if numPartitions is None: numPartitions = self._defaultReducePartitions() diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 250c3233c976f..bd0ac0039ffe1 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,7 +28,7 @@ import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ CompressedSerializer, AutoBatchedSerializer -from pyspark.util import fail_on_StopIteration +from pyspark.util import fail_on_stopiteration try: @@ -95,9 +95,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = fail_on_StopIteration(createCombiner) - self.mergeValue = fail_on_StopIteration(mergeValue) - self.mergeCombiners = fail_on_StopIteration(mergeCombiners) + self.createCombiner = fail_on_stopiteration(createCombiner) + self.mergeValue = fail_on_stopiteration(mergeValue) + self.mergeCombiners = fail_on_stopiteration(mergeCombiners) class SimpleAggregator(Aggregator): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 5807fde0812f8..e77a40e8b808f 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -89,15 +89,16 @@ def majorMinorVersion(sparkVersion): " version numbers.") -def fail_on_StopIteration(f): - """ wraps f to make it safe (= does not lead to data loss) to use inside a for loop - make StopIteration's raised inside f explicit +def fail_on_stopiteration(f): + """ + Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' + prevents silent loss of data when 'f' is used in a for loop """ def wrapper(*args, **kwargs): try: return f(*args, **kwargs) except StopIteration as exc: - raise RuntimeError('StopIteration in client code', exc) + raise RuntimeError("Caught StopIteration thrown from user's code; failing the task", exc) return wrapper From f0f80ed1b8333bbab841a59f151deff18bc73447 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Thu, 24 May 2018 15:17:46 +0200 Subject: [PATCH 765/774] improved tests --- python/pyspark/sql/tests.py | 2 +- python/pyspark/tests.py | 32 +++++++++++++------------------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f66423e05e9b6..fb592989ce4b0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -908,7 +908,7 @@ def test_stopiteration_in_udf(self): def foo(x): raise StopIteration() - with self.assertRaises(Py4JJavaError) as cm: + with self.assertRaises(Py4JJavaError): self.spark.range(0, 1000).withColumn('v', udf(foo)).show() def test_validate_column_types(self): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 383fdde59aad0..18f88cee89e1f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1279,28 +1279,22 @@ def test_pipe_unicode(self): def test_stopiteration_in_client_code(self): - def a_rdd(keyed=False): - return self.sc.parallelize( - ((x % 2, x) if keyed else x) - for x in range(10) - ) - def stopit(*x): raise StopIteration() - def do_test(action, *args, **kwargs): - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - action(*args, **kwargs) - - do_test(a_rdd().map(stopit).collect) - do_test(a_rdd().filter(stopit).collect) - do_test(a_rdd().cartesian(a_rdd()).flatMap(stopit).collect) - do_test(a_rdd().foreach, stopit) - do_test(a_rdd(keyed=True).reduceByKeyLocally, stopit) - do_test(a_rdd().reduce, stopit) - do_test(a_rdd().fold, 0, stopit) - do_test(a_rdd().aggregate, 0, stopit, lambda *x: 1) - do_test(a_rdd().aggregate, 0, lambda *x: 1, stopit) + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + exc = Py4JJavaError, RuntimeError + + self.assertRaises(exc, seq_rdd.map(stopit).collect) + self.assertRaises(exc, seq_rdd.filter(stopit).collect) + self.assertRaises(exc, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + self.assertRaises(exc, seq_rdd.foreach, stopit) + self.assertRaises(exc, keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaises(exc, seq_rdd.reduce, stopit) + self.assertRaises(exc, seq_rdd.fold, 0, stopit) + self.assertRaises(exc, seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaises(exc, seq_rdd.aggregate, 0, lambda *x: 1, stopit) class ProfilerTests(PySparkTestCase): From d59f0d5a2735713bb7e218cfcda2b494edfcf522 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Thu, 24 May 2018 15:37:53 +0200 Subject: [PATCH 766/774] fixed style --- python/pyspark/util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index e77a40e8b808f..938e729260bba 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -98,7 +98,10 @@ def wrapper(*args, **kwargs): try: return f(*args, **kwargs) except StopIteration as exc: - raise RuntimeError("Caught StopIteration thrown from user's code; failing the task", exc) + raise RuntimeError( + "Caught StopIteration thrown from user's code; failing the task", + exc + ) return wrapper From b0af18e400c01095dd87589260ce80e9712a9f07 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Thu, 24 May 2018 18:44:38 +0200 Subject: [PATCH 767/774] fixed udf and its test --- python/pyspark/sql/tests.py | 2 +- python/pyspark/sql/udf.py | 4 ++-- python/pyspark/util.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fb592989ce4b0..53d6dff9eb1c4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -909,7 +909,7 @@ def foo(x): raise StopIteration() with self.assertRaises(Py4JJavaError): - self.spark.range(0, 1000).withColumn('v', udf(foo)).show() + self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show() def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 9dbe49b831cef..f41e307d6992e 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -25,7 +25,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ to_arrow_type, to_arrow_schema -from pyspark.util import _get_argspec +from pyspark.util import _get_argspec, fail_on_stopiteration __all__ = ["UDFRegistration"] @@ -92,7 +92,7 @@ def __init__(self, func, raise TypeError( "Invalid evalType: evalType should be an int but is {}".format(evalType)) - self.func = func + self.func = fail_on_stopiteration(func) self._returnType = returnType # Stores UserDefinedPythonFunctions jobj, once initialized self._returnType_placeholder = None diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 938e729260bba..fa1b1c2da0b21 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -91,8 +91,8 @@ def majorMinorVersion(sparkVersion): def fail_on_stopiteration(f): """ - Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' - prevents silent loss of data when 'f' is used in a for loop + Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' + prevents silent loss of data when 'f' is used in a for loop """ def wrapper(*args, **kwargs): try: From 167a75b81599e176f851daa2566b359f72264f61 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Thu, 24 May 2018 20:18:40 +0200 Subject: [PATCH 768/774] preserving metadata of wrapped function --- python/pyspark/sql/udf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index f41e307d6992e..b35b4bbdb5812 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -92,7 +92,7 @@ def __init__(self, func, raise TypeError( "Invalid evalType: evalType should be an int but is {}".format(evalType)) - self.func = fail_on_stopiteration(func) + self.func = func self._returnType = returnType # Stores UserDefinedPythonFunctions jobj, once initialized self._returnType_placeholder = None @@ -157,7 +157,7 @@ def _create_judf(self): spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, self.func, self.returnType) + wrapped_func = _wrap_function(sc, fail_on_stopiteration(self.func), self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( self._name, wrapped_func, jdt, self.evalType, self.deterministic) From 90b064ddd2db562e90dd55846f1331779e795460 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Thu, 24 May 2018 20:19:16 +0200 Subject: [PATCH 769/774] catching relevant exceptions only --- python/pyspark/tests.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 18f88cee89e1f..3b37cc028c1b7 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1284,17 +1284,20 @@ def stopit(*x): seq_rdd = self.sc.parallelize(range(10)) keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) - exc = Py4JJavaError, RuntimeError - - self.assertRaises(exc, seq_rdd.map(stopit).collect) - self.assertRaises(exc, seq_rdd.filter(stopit).collect) - self.assertRaises(exc, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) - self.assertRaises(exc, seq_rdd.foreach, stopit) - self.assertRaises(exc, keyed_rdd.reduceByKeyLocally, stopit) - self.assertRaises(exc, seq_rdd.reduce, stopit) - self.assertRaises(exc, seq_rdd.fold, 0, stopit) - self.assertRaises(exc, seq_rdd.aggregate, 0, stopit, lambda *x: 1) - self.assertRaises(exc, seq_rdd.aggregate, 0, lambda *x: 1, stopit) + + self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) + self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) + + # the exception raised is non-deterministic + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, lambda *x: 1, stopit) class ProfilerTests(PySparkTestCase): From 75316af5f366d9c0386a9396fc981a9294541cb0 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Sat, 26 May 2018 16:56:01 +0200 Subject: [PATCH 770/774] preserving argspecs of wrapped function --- python/pyspark/sql/tests.py | 8 +++++++- python/pyspark/util.py | 16 +++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 53d6dff9eb1c4..e5b8fde4cd94f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -908,9 +908,15 @@ def test_stopiteration_in_udf(self): def foo(x): raise StopIteration() - with self.assertRaises(Py4JJavaError): + with self.assertRaises(Py4JJavaError) as cm: self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show() + self.assertIn( + "Caught StopIteration thrown from user's code; failing the task", + cm.exception.java_exception.toString() + ) + + def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column diff --git a/python/pyspark/util.py b/python/pyspark/util.py index fa1b1c2da0b21..178dd94034414 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -23,6 +23,8 @@ __all__ = [] +WRAPPED_ARGSPEC_ATTR = '_wrapped_argspec' + def _exception_message(excp): """Return the message from an exception as either a str or unicode object. Supports both @@ -55,7 +57,9 @@ def _get_argspec(f): """ # `getargspec` is deprecated since python3.0 (incompatible with function annotations). # See SPARK-23569. - if sys.version_info[0] < 3: + if hasattr(f, WRAPPED_ARGSPEC_ATTR): + argspec = getattr(f, WRAPPED_ARGSPEC_ATTR) + elif sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: argspec = inspect.getfullargspec(f) @@ -103,6 +107,16 @@ def wrapper(*args, **kwargs): exc ) + # prevent inspect to fail + # e.g. inspect.getargspec(sum) raises + # TypeError: is not a Python function + try: + argspec = _get_argspec(f) + except TypeError: + pass + else: + setattr(wrapper, WRAPPED_ARGSPEC_ATTR, _get_argspec(f)) + return wrapper From 026ecddacb847d624cd53150e82c011b6befafc0 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Sat, 26 May 2018 17:23:32 +0200 Subject: [PATCH 771/774] style --- python/pyspark/sql/tests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e5b8fde4cd94f..1f91d2c181685 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -916,7 +916,6 @@ def foo(x): cm.exception.java_exception.toString() ) - def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column From f7b53c222e4341b59d3588017718d80ecb37a473 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Tue, 29 May 2018 09:43:08 +0200 Subject: [PATCH 772/774] saving argspec in udf --- python/pyspark/sql/udf.py | 12 +++++++++++- python/pyspark/util.py | 16 ++-------------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index b35b4bbdb5812..fb71fa1f60f14 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -157,7 +157,17 @@ def _create_judf(self): spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, fail_on_stopiteration(self.func), self.returnType) + func = fail_on_stopiteration(self.func) + + # prevent inspect to fail + # e.g. inspect.getargspec(sum) raises + # TypeError: is not a Python function + try: + func._argspec = _get_argspec(self.func) + except TypeError: + pass + + wrapped_func = _wrap_function(sc, func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( self._name, wrapped_func, jdt, self.evalType, self.deterministic) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 178dd94034414..8b646d573f207 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -23,8 +23,6 @@ __all__ = [] -WRAPPED_ARGSPEC_ATTR = '_wrapped_argspec' - def _exception_message(excp): """Return the message from an exception as either a str or unicode object. Supports both @@ -57,8 +55,8 @@ def _get_argspec(f): """ # `getargspec` is deprecated since python3.0 (incompatible with function annotations). # See SPARK-23569. - if hasattr(f, WRAPPED_ARGSPEC_ATTR): - argspec = getattr(f, WRAPPED_ARGSPEC_ATTR) + if hasattr(f, '_argspec'): + argspec = f._argspec elif sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: @@ -107,16 +105,6 @@ def wrapper(*args, **kwargs): exc ) - # prevent inspect to fail - # e.g. inspect.getargspec(sum) raises - # TypeError: is not a Python function - try: - argspec = _get_argspec(f) - except TypeError: - pass - else: - setattr(wrapper, WRAPPED_ARGSPEC_ATTR, _get_argspec(f)) - return wrapper From 8fac2a80deb79030dee161e0d86b7b090bc892a7 Mon Sep 17 00:00:00 2001 From: edorigatti Date: Tue, 29 May 2018 17:02:30 +0200 Subject: [PATCH 773/774] saving signature only for pandas udf, removed useless try/except --- python/pyspark/sql/udf.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index fb71fa1f60f14..c8fb49d7c2b65 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -159,13 +159,13 @@ def _create_judf(self): func = fail_on_stopiteration(self.func) - # prevent inspect to fail - # e.g. inspect.getargspec(sum) raises - # TypeError: is not a Python function - try: + # for pandas UDFs the worker needs to know if the function takes + # one or two arguments, but the signature is lost when wrapping with + # fail_on_stopiteration, so we store it here + if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): func._argspec = _get_argspec(self.func) - except TypeError: - pass wrapped_func = _wrap_function(sc, func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) From 5b5570b7d4a4e71d470dbb9e763b50a948d4195c Mon Sep 17 00:00:00 2001 From: edorigatti Date: Wed, 30 May 2018 10:30:09 +0200 Subject: [PATCH 774/774] comment explaining hack --- python/pyspark/util.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 8b646d573f207..e95a9b523393f 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -53,13 +53,16 @@ def _get_argspec(f): """ Get argspec of a function. Supports both Python 2 and Python 3. """ - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. + if hasattr(f, '_argspec'): + # only used for pandas UDF: they wrap the user function, losing its signature + # workers need this signature, so UDF saves it here argspec = f._argspec elif sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. argspec = inspect.getfullargspec(f) return argspec
    `spark.mllib` modelPMML model
    spark.mllib modelPMML model
    This configuration limits the number of remote requests to fetch blocks at any given point. When the number of hosts in the cluster increase, it might lead to very large number - of in-bound connections to one or more nodes, causing the workers to fail under load. + of inbound connections to one or more nodes, causing the workers to fail under load. By allowing it to limit the number of fetch requests, this scenario can be mitigated.
    4194304 (4 MB) The estimated cost to open a file, measured by the number of bytes could be scanned at the same - time. This is used when putting multiple files into a partition. It is better to over estimate, + time. This is used when putting multiple files into a partition. It is better to overestimate, then the partitions with small files will be faster than partitions with bigger files.
    0.8 for KUBERNETES mode; 0.8 for YARN mode; 0.0 for standalone mode and Mesos coarse-grained mode The minimum ratio of registered resources (registered resources / total expected resources) - (resources are executors in yarn mode and Kubernetes mode, CPU cores in standalone mode and Mesos coarsed-grained + (resources are executors in yarn mode and Kubernetes mode, CPU cores in standalone mode and Mesos coarse-grained mode ['spark.cores.max' value is total expected resources for Mesos coarse-grained mode] ) to wait for before scheduling begins. Specified as a double between 0.0 and 1.0. Regardless of whether the minimum ratio of resources has been reached, @@ -1634,7 +1634,7 @@ Apart from these, the following properties are also available, and may be useful false (Experimental) If set to "true", Spark will blacklist the executor immediately when a fetch - failure happenes. If external shuffle service is enabled, then the whole node will be + failure happens. If external shuffle service is enabled, then the whole node will be blacklisted.
    spark.streaming.receiver.writeAheadLog.enable false - Enable write ahead logs for receivers. All the input data received through receivers - will be saved to write ahead logs that will allow it to be recovered after driver failures. + Enable write-ahead logs for receivers. All the input data received through receivers + will be saved to write-ahead logs that will allow it to be recovered after driver failures. See the deployment guide in the Spark Streaming programing guide for more details. spark.streaming.driver.writeAheadLog.closeFileAfterWrite false - Whether to close the file after writing a write ahead log record on the driver. Set this to 'true' + Whether to close the file after writing a write-ahead log record on the driver. Set this to 'true' when you want to use S3 (or any file system that does not support flushing) for the metadata WAL on the driver. spark.streaming.receiver.writeAheadLog.closeFileAfterWrite false - Whether to close the file after writing a write ahead log record on the receivers. Set this to 'true' + Whether to close the file after writing a write-ahead log record on the receivers. Set this to 'true' when you want to use S3 (or any file system that does not support flushing) for the data WAL on the receivers. spark.mesos.constraints (none) - Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. This setting + Attribute-based constraints on mesos resource offers. By default, all resource offers will be accepted. This setting applies only to executors. Refer to Mesos Attributes & Resources for more information on attributes.
      diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e07759a4dba87..ceda8a3ae2403 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -418,7 +418,7 @@ To use a custom metrics.properties for the application master and executors, upd - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. - In `cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `client` mode, only the Spark executors do. -- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. +- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example, you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. # Kerberos diff --git a/docs/security.md b/docs/security.md index 3e5607a9a0d67..8c0c66fb5a285 100644 --- a/docs/security.md +++ b/docs/security.md @@ -374,7 +374,7 @@ replaced with one of the above namespaces.
    ${ns}.enabledAlgorithms None - A comma separated list of ciphers. The specified ciphers must be supported by JVM. + A comma-separated list of ciphers. The specified ciphers must be supported by JVM.
    The reference list of protocols can be found in the "JSSE Cipher Suite Names" section of the Java security guide. The list for Java 8 can be found at diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 8fa643abf1373..f06e72a387df1 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -338,7 +338,7 @@ worker during one single schedule iteration. # Monitoring and Logging -Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. +Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default, you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. In addition, detailed log output for each job is also written to the work directory of each slave node (`SPARK_HOME/work` by default). You will see two files for each job, `stdout` and `stderr`, with all output it wrote to its console. diff --git a/docs/sparkr.md b/docs/sparkr.md index 2909247e79e95..7fabab5d38f16 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -107,7 +107,7 @@ The following Spark driver properties can be set in `sparkConfig` with `sparkR.s With a `SparkSession`, applications can create `SparkDataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources). ### From local data frames -The simplest way to create a data frame is to convert a local R data frame into a SparkDataFrame. Specifically we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a SparkDataFrame. As an example, the following creates a `SparkDataFrame` based using the `faithful` dataset from R. +The simplest way to create a data frame is to convert a local R data frame into a SparkDataFrame. Specifically, we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a SparkDataFrame. As an example, the following creates a `SparkDataFrame` based using the `faithful` dataset from R.
    {% highlight r %} @@ -169,7 +169,7 @@ df <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings {% endhighlight %}
    -The data sources API can also be used to save out SparkDataFrames into multiple file formats. For example we can save the SparkDataFrame from the previous example +The data sources API can also be used to save out SparkDataFrames into multiple file formats. For example, we can save the SparkDataFrame from the previous example to a Parquet file using `write.df`.
    @@ -241,7 +241,7 @@ head(filter(df, df$waiting < 50)) ### Grouping, Aggregation -SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below +SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example, we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below
    {% highlight r %} diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 9822d669050d5..55d35b9dd31db 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -165,7 +165,7 @@ In addition to simple column references and expressions, Datasets also have a ri
    -In Python it's possible to access a DataFrame's columns either by attribute +In Python, it's possible to access a DataFrame's columns either by attribute (`df.age`) or by indexing (`df['age']`). While the former is convenient for interactive data exploration, users are highly encouraged to use the latter form, which is future proof and won't break with column names that @@ -278,7 +278,7 @@ the bytes back into an object. Spark SQL supports two different methods for converting existing RDDs into Datasets. The first method uses reflection to infer the schema of an RDD that contains specific types of objects. This -reflection based approach leads to more concise code and works well when you already know the schema +reflection-based approach leads to more concise code and works well when you already know the schema while writing your Spark application. The second method for creating Datasets is through a programmatic interface that allows you to @@ -1243,7 +1243,7 @@ The following options can be used to configure the version of Hive that is used
    com.mysql.jdbc,
    org.postgresql,
    com.microsoft.sqlserver,
    oracle.jdbc

    - A comma separated list of class prefixes that should be loaded using the classloader that is + A comma-separated list of class prefixes that should be loaded using the classloader that is shared between Spark SQL and a specific version of Hive. An example of classes that should be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need to be shared are those that interact with classes that are already shared. For example, @@ -1441,7 +1441,7 @@ SELECT * FROM resultTable # Performance Tuning -For some workloads it is possible to improve performance by either caching data in memory, or by +For some workloads, it is possible to improve performance by either caching data in memory, or by turning on some experimental options. ## Caching Data In Memory @@ -1804,7 +1804,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. + - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. @@ -1966,11 +1966,11 @@ working with timestamps in `pandas_udf`s to get the best performance, see - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. - - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error prone. + - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error-prone. ## Upgrading From Spark SQL 2.1 to 2.2 - - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. + - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time-consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty). @@ -2013,7 +2013,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 1.5 to 1.6 - - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC + - From Spark 1.6, by default, the Thrift server runs in multi-session mode. Which means each JDBC/ODBC connection owns a copy of their own SQL configuration and temporary function registry. Cached tables are still shared though. If you prefer to run the Thrift server in the old single-session mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add @@ -2161,7 +2161,7 @@ been renamed to `DataFrame`. This is primarily because DataFrames no longer inhe directly, but instead provide most of the functionality that RDDs provide though their own implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. -In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for +In Scala, there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for some use cases. It is still recommended that users update their code to use `DataFrame` instead. Java and Python users will need to update their code. @@ -2170,11 +2170,11 @@ Java and Python users will need to update their code. Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users of either language should use `SQLContext` and `DataFrame`. In general these classes try to -use types that are usable from both languages (i.e. `Array` instead of language specific collections). +use types that are usable from both languages (i.e. `Array` instead of language-specific collections). In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading is used instead. -Additionally the Java specific types API has been removed. Users of both Scala and Java should +Additionally, the Java specific types API has been removed. Users of both Scala and Java should use the classes present in `org.apache.spark.sql.types` to describe schema programmatically. @@ -2231,7 +2231,7 @@ referencing a singleton. ## Compatibility with Apache Hive Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. -Currently Hive SerDes and UDFs are based on Hive 1.2.1, +Currently, Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore (from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). @@ -2323,10 +2323,10 @@ A handful of Hive optimizations are not yet included in Spark. Some of these (su less important due to Spark SQL's in-memory computational model. Others are slotted for future releases of Spark SQL. -* Block level bitmap indexes and virtual columns (used to build indexes) -* Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you +* Block-level bitmap indexes and virtual columns (used to build indexes) +* Automatically determine the number of reducers for joins and groupbys: Currently, in Spark SQL, you need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`". -* Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still +* Meta-data only query: For queries that can be answered by using only metadata, Spark SQL still launches tasks to compute the result. * Skew data flag: Spark SQL does not follow the skew data flags in Hive. * `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint. @@ -2983,6 +2983,6 @@ does not exactly match standard floating point semantics. Specifically: - NaN = NaN returns true. - - In aggregations all NaN values are grouped together. + - In aggregations, all NaN values are grouped together. - NaN is treated as a normal value in join keys. - NaN values go last when in ascending order, larger than any other numeric value. diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md index 1dd54719b21aa..dacaa3438d489 100644 --- a/docs/storage-openstack-swift.md +++ b/docs/storage-openstack-swift.md @@ -39,7 +39,7 @@ For example, for Maven support, add the following to the pom.xml fi # Configuration Parameters Create core-site.xml and place it inside Spark's conf directory. -The main category of parameters that should be configured are the authentication parameters +The main category of parameters that should be configured is the authentication parameters required by Keystone. The following table contains a list of Keystone mandatory parameters. PROVIDER can be diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 257a4f7d4f3ca..a1b6942ffe0a4 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -17,7 +17,7 @@ Choose a machine in your cluster such that - Flume can be configured to push data to a port on that machine. -Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able push data. +Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able to push data. #### Configuring Flume Configure Flume agent to send data to an Avro sink by having the following in the configuration file. @@ -100,7 +100,7 @@ Choose a machine that will run the custom sink in a Flume agent. The rest of the #### Configuring Flume Configuring Flume on the chosen machine requires the following two steps. -1. **Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink . +1. **Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink. (i) *Custom sink JAR*: Download the JAR corresponding to the following artifact (or [direct link](http://search.maven.org/remotecontent?filepath=org/apache/spark/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}/{{site.SPARK_VERSION_SHORT}}/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}-{{site.SPARK_VERSION_SHORT}}.jar)). @@ -128,7 +128,7 @@ Configuring Flume on the chosen machine requires the following two steps. agent.sinks.spark.port = agent.sinks.spark.channel = memoryChannel - Also make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink. + Also, make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink. See the [Flume's documentation](https://flume.apache.org/documentation.html) for more information about configuring Flume agents. diff --git a/docs/streaming-kafka-0-8-integration.md b/docs/streaming-kafka-0-8-integration.md index 9f0671da2ee31..becf217738d26 100644 --- a/docs/streaming-kafka-0-8-integration.md +++ b/docs/streaming-kafka-0-8-integration.md @@ -10,7 +10,7 @@ Here we explain how to configure Spark Streaming to receive data from Kafka. The ## Approach 1: Receiver-based Approach This approach uses a Receiver to receive the data. The Receiver is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. -However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write-Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write-ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write-Ahead Logs. Next, we discuss how to use this approach in your streaming application. @@ -55,11 +55,11 @@ Next, we discuss how to use this approach in your streaming application. **Points to remember:** - - Topic partitions in Kafka does not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. + - Topic partitions in Kafka do not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. - Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers. - - If you have enabled Write Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use + - If you have enabled Write-Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use `KafkaUtils.createStream(..., StorageLevel.MEMORY_AND_DISK_SER)`). 3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. @@ -80,9 +80,9 @@ This approach has the following advantages over the receiver-based approach (i.e - *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. -- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. +- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write-Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write-Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write-Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). +- *Exactly-once semantics:* The first approach uses Kafka's high-level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with-write-ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index ffda36d64a770..c30959263cdfa 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1461,7 +1461,7 @@ Note that the connections in the pool should be lazily created on demand and tim *** ## DataFrame and SQL Operations -You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL. +You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore, this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL.

    @@ -2010,10 +2010,10 @@ To run a Spark Streaming applications, you need to have the following. + *Mesos* - [Marathon](https://github.com/mesosphere/marathon) has been used to achieve this with Mesos. -- *Configuring write ahead logs* - Since Spark 1.2, - we have introduced _write ahead logs_ for achieving strong +- *Configuring write-ahead logs* - Since Spark 1.2, + we have introduced _write-ahead logs_ for achieving strong fault-tolerance guarantees. If enabled, all the data received from a receiver gets written into - a write ahead log in the configuration checkpoint directory. This prevents data loss on driver + a write-ahead log in the configuration checkpoint directory. This prevents data loss on driver recovery, thus ensuring zero data loss (discussed in detail in the [Fault-tolerance Semantics](#fault-tolerance-semantics) section). This can be enabled by setting the [configuration parameter](configuration.html#spark-streaming) @@ -2021,15 +2021,15 @@ To run a Spark Streaming applications, you need to have the following. come at the cost of the receiving throughput of individual receivers. This can be corrected by running [more receivers in parallel](#level-of-parallelism-in-data-receiving) to increase aggregate throughput. Additionally, it is recommended that the replication of the - received data within Spark be disabled when the write ahead log is enabled as the log is already + received data within Spark be disabled when the write-ahead log is enabled as the log is already stored in a replicated storage system. This can be done by setting the storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER`. While using S3 (or any file system that - does not support flushing) for _write ahead logs_, please remember to enable + does not support flushing) for _write-ahead logs_, please remember to enable `spark.streaming.driver.writeAheadLog.closeFileAfterWrite` and `spark.streaming.receiver.writeAheadLog.closeFileAfterWrite`. See [Spark Streaming Configuration](configuration.html#spark-streaming) for more details. - Note that Spark will not encrypt data written to the write ahead log when I/O encryption is - enabled. If encryption of the write ahead log data is desired, it should be stored in a file + Note that Spark will not encrypt data written to the write-ahead log when I/O encryption is + enabled. If encryption of the write-ahead log data is desired, it should be stored in a file system that supports encryption natively. - *Setting the max receiving rate* - If the cluster resources is not large enough for the streaming @@ -2284,9 +2284,9 @@ Having bigger blockinterval means bigger blocks. A high value of `spark.locality - Instead of relying on batchInterval and blockInterval, you can define the number of partitions by calling `inputDstream.repartition(n)`. This reshuffles the data in RDD randomly to create n number of partitions. Yes, for greater parallelism. Though comes at the cost of a shuffle. An RDD's processing is scheduled by driver's jobscheduler as a job. At a given point of time only one job is active. So, if one job is executing the other jobs are queued. -- If you have two dstreams there will be two RDDs formed and there will be two jobs created which will be scheduled one after the another. To avoid this, you can union two dstreams. This will ensure that a single unionRDD is formed for the two RDDs of the dstreams. This unionRDD is then considered as a single job. However the partitioning of the RDDs is not impacted. +- If you have two dstreams there will be two RDDs formed and there will be two jobs created which will be scheduled one after the another. To avoid this, you can union two dstreams. This will ensure that a single unionRDD is formed for the two RDDs of the dstreams. This unionRDD is then considered as a single job. However, the partitioning of the RDDs is not impacted. -- If the batch processing time is more than batchinterval then obviously the receiver's memory will start filling up and will end up in throwing exceptions (most probably BlockNotFoundException). Currently there is no way to pause the receiver. Using SparkConf configuration `spark.streaming.receiver.maxRate`, rate of receiver can be limited. +- If the batch processing time is more than batchinterval then obviously the receiver's memory will start filling up and will end up in throwing exceptions (most probably BlockNotFoundException). Currently, there is no way to pause the receiver. Using SparkConf configuration `spark.streaming.receiver.maxRate`, rate of receiver can be limited. *************************************************************************************************** @@ -2388,7 +2388,7 @@ then besides these losses, all of the past data that was received and replicated lost. This will affect the results of the stateful transformations. To avoid this loss of past received data, Spark 1.2 introduced _write -ahead logs_ which save the received data to fault-tolerant storage. With the [write ahead logs +ahead logs_ which save the received data to fault-tolerant storage. With the [write-ahead logs enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides an at-least once guarantee. The following table summarizes the semantics under failures: @@ -2402,7 +2402,7 @@ The following table summarizes the semantics under failures:
    Spark 1.1 or earlier, OR
    - Spark 1.2 or later without write ahead logs + Spark 1.2 or later without write-ahead logs
    Buffered data lost with unreliable receivers
    @@ -2416,7 +2416,7 @@ The following table summarizes the semantics under failures:
    Spark 1.2 or later with write ahead logsSpark 1.2 or later with write-ahead logs Zero data loss with reliable receivers
    At-least once semantics diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 5647ec6bc5797..71fd5b10cc407 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -15,7 +15,7 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli For Python applications, you need to add this above library and its dependencies when deploying your application. See the [Deploying](#deploying) subsection below. -For experimenting on `spark-shell`, you need to add this above library and its dependencies too when invoking `spark-shell`. Also see the [Deploying](#deploying) subsection below. +For experimenting on `spark-shell`, you need to add this above library and its dependencies too when invoking `spark-shell`. Also, see the [Deploying](#deploying) subsection below. ## Reading Data from Kafka diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 9a83f157452ad..602a4c70848e7 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -8,7 +8,7 @@ title: Structured Streaming Programming Guide {:toc} # Overview -Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* +Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write-Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able to choose the mode based on your application requirements. @@ -479,7 +479,7 @@ detail in the [Window Operations](#window-operations-on-event-time) section. ## Fault Tolerance Semantics Delivering end-to-end exactly-once semantics was one of key goals behind the design of Structured Streaming. To achieve that, we have designed the Structured Streaming sources, the sinks and the execution engine to reliably track the exact progress of the processing so that it can handle any kind of failure by restarting and/or reprocessing. Every streaming source is assumed to have offsets (similar to Kafka offsets, or Kinesis sequence numbers) -to track the read position in the stream. The engine uses checkpointing and write ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure. +to track the read position in the stream. The engine uses checkpointing and write-ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure. # API using Datasets and DataFrames Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` @@ -690,7 +690,7 @@ These examples generate streaming DataFrames that are untyped, meaning that the By default, Structured Streaming from file based sources requires you to specify the schema, rather than rely on Spark to infer it automatically. This restriction ensures a consistent schema will be used for the streaming query, even in the case of failures. For ad-hoc use cases, you can reenable schema inference by setting `spark.sql.streaming.schemaInference` to `true`. -Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user provided schema, they will be filled in by Spark based on the path of the file being read. The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`). +Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user-provided schema, they will be filled in by Spark based on the path of the file being read. The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`). ## Operations on streaming DataFrames/Datasets You can apply all kinds of operations on streaming DataFrames/Datasets – ranging from untyped, SQL-like operations (e.g. `select`, `where`, `groupBy`), to typed RDD-like operations (e.g. `map`, `filter`, `flatMap`). See the [SQL programming guide](sql-programming-guide.html) for more details. Let’s take a look at a few example operations that you can use. @@ -2661,7 +2661,7 @@ sql("SET spark.sql.streaming.metricsEnabled=true") All queries started in the SparkSession after this configuration has been enabled will report metrics through Dropwizard to whatever [sinks](monitoring.html#metrics) have been configured (e.g. Ganglia, Graphite, JMX, etc.). ## Recovering from Failures with Checkpointing -In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). +In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write-ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries).
    diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index a3643bf0838a1..77aa083c4a584 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -177,7 +177,7 @@ The master URL passed to Spark can be in one of the following formats: # Loading Configuration from a File The `spark-submit` script can load default [Spark configuration values](configuration.html) from a -properties file and pass them on to your application. By default it will read options +properties file and pass them on to your application. By default, it will read options from `conf/spark-defaults.conf` in the Spark directory. For more detail, see the section on [loading default configurations](configuration.html#loading-default-configurations). diff --git a/docs/tuning.md b/docs/tuning.md index fc27713f28d46..912c39879be8f 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -196,7 +196,7 @@ To further tune garbage collection, we first need to understand some basic infor * A simplified description of the garbage collection procedure: When Eden is full, a minor GC is run on Eden and objects that are alive from Eden and Survivor1 are copied to Survivor2. The Survivor regions are swapped. If an object is old - enough or Survivor2 is full, it is moved to Old. Finally when Old is close to full, a full GC is invoked. + enough or Survivor2 is full, it is moved to Old. Finally, when Old is close to full, a full GC is invoked. The goal of GC tuning in Spark is to ensure that only long-lived RDDs are stored in the Old generation and that the Young generation is sufficiently sized to store short-lived objects. This will help avoid full GCs to collect diff --git a/python/README.md b/python/README.md index 3f17fdb98a081..2e0112da58b94 100644 --- a/python/README.md +++ b/python/README.md @@ -22,7 +22,7 @@ This packaging is currently experimental and may change in future versions (alth Using PySpark requires the Spark JARs, and if you are building this from source please see the builder instructions at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). -The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to setup your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html). +The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to set up your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html). **NOTE:** If you are using this with a Spark standalone cluster you must ensure that the version (including minor version) matches or you may experience odd errors. diff --git a/sql/README.md b/sql/README.md index fe1d352050c09..70cc7c637b58d 100644 --- a/sql/README.md +++ b/sql/README.md @@ -6,7 +6,7 @@ This module provides support for executing relational queries expressed in eithe Spark SQL is broken up into four subprojects: - Catalyst (sql/catalyst) - An implementation-agnostic framework for manipulating trees of relational operators and expressions. - Execution (sql/core) - A query planner / execution engine for translating Catalyst's logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. - - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allows users to run queries that include Hive UDFs, UDAFs, and UDTFs. + - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allow users to run queries that include Hive UDFs, UDAFs, and UDTFs. - HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server. Running `sql/create-docs.sh` generates SQL documentation for built-in functions under `sql/site`. From 94524019315ad463f9bc13c107131091d17c6af9 Mon Sep 17 00:00:00 2001 From: Yuchen Huo Date: Fri, 6 Apr 2018 08:35:20 -0700 Subject: [PATCH 573/774] [SPARK-23822][SQL] Improve error message for Parquet schema mismatches ## What changes were proposed in this pull request? This pull request tries to improve the error message for spark while reading parquet files with different schemas, e.g. One with a STRING column and the other with a INT column. A new ParquetSchemaColumnConvertNotSupportedException is added to replace the old UnsupportedOperationException. The Exception is again wrapped in FileScanRdd.scala to throw a more a general QueryExecutionException with the actual parquet file name which trigger the exception. ## How was this patch tested? Unit tests added to check the new exception and verify the error messages. Also manually tested with two parquet with different schema to check the error message. screen shot 2018-03-30 at 4 03 04 pm Author: Yuchen Huo Closes #20953 from yuchenhuo/SPARK-23822. --- ...emaColumnConvertNotSupportedException.java | 62 +++++++++++++++++++ .../parquet/VectorizedColumnReader.java | 38 ++++++++---- .../execution/QueryExecutionException.scala | 3 +- .../execution/datasources/FileScanRDD.scala | 21 ++++++- .../parquet/ParquetSchemaSuite.scala | 55 ++++++++++++++++ 5 files changed, 166 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java new file mode 100644 index 0000000000000..82a1169cbe7ae --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * Exception thrown when the parquet reader find column type mismatches. + */ +@InterfaceStability.Unstable +public class SchemaColumnConvertNotSupportedException extends RuntimeException { + + /** + * Name of the column which cannot be converted. + */ + private String column; + /** + * Physical column type in the actual parquet file. + */ + private String physicalType; + /** + * Logical column type in the parquet schema the parquet reader use to parse all files. + */ + private String logicalType; + + public String getColumn() { + return column; + } + + public String getPhysicalType() { + return physicalType; + } + + public String getLogicalType() { + return logicalType; + } + + public SchemaColumnConvertNotSupportedException( + String column, + String physicalType, + String logicalType) { + super(); + this.column = column; + this.physicalType = physicalType; + this.logicalType = logicalType; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 47dd625f4b154..72f1d024b08ce 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.IOException; +import java.util.Arrays; import java.util.TimeZone; import org.apache.parquet.bytes.BytesUtils; @@ -31,6 +32,7 @@ import org.apache.parquet.schema.PrimitiveType; import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -231,6 +233,18 @@ private boolean shouldConvertTimestamps() { return convertTz != null && !convertTz.equals(UTC); } + /** + * Helper function to construct exception for parquet schema mismatch. + */ + private SchemaColumnConvertNotSupportedException constructConvertNotSupportedException( + ColumnDescriptor descriptor, + WritableColumnVector column) { + return new SchemaColumnConvertNotSupportedException( + Arrays.toString(descriptor.getPath()), + descriptor.getType().toString(), + column.dataType().toString()); + } + /** * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. */ @@ -261,7 +275,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -282,7 +296,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -321,7 +335,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException(); + throw constructConvertNotSupportedException(descriptor, column); } break; case BINARY: @@ -360,7 +374,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException(); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -375,7 +389,9 @@ private void decodeDictionaryIds( */ private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { - assert(column.dataType() == DataTypes.BooleanType); + if (column.dataType() != DataTypes.BooleanType) { + throw constructConvertNotSupportedException(descriptor, column); + } defColumn.readBooleans( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } @@ -394,7 +410,7 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) { defColumn.readShorts( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -414,7 +430,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) { } } } else { - throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -425,7 +441,7 @@ private void readFloatBatch(int rowId, int num, WritableColumnVector column) { defColumn.readFloats( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -436,7 +452,7 @@ private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { defColumn.readDoubles( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -471,7 +487,7 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -510,7 +526,7 @@ private void readFixedLenByteArrayBatch( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala index 16806c620635f..cffd97baea6a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala @@ -17,4 +17,5 @@ package org.apache.spark.sql.execution -class QueryExecutionException(message: String) extends Exception(message) +class QueryExecutionException(message: String, cause: Throwable = null) + extends Exception(message, cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 835ce98462477..28c36b6020d33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -21,11 +21,14 @@ import java.io.{FileNotFoundException, IOException} import scala.collection.mutable +import org.apache.parquet.io.ParquetDecodingException + import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.NextIterator @@ -179,7 +182,23 @@ class FileScanRDD( currentIterator = readCurrentFile() } - hasNext + try { + hasNext + } catch { + case e: SchemaColumnConvertNotSupportedException => + val message = "Parquet column cannot be converted in " + + s"file ${currentFile.filePath}. Column: ${e.getColumn}, " + + s"Expected: ${e.getLogicalType}, Found: ${e.getPhysicalType}" + throw new QueryExecutionException(message, e) + case e: ParquetDecodingException => + if (e.getMessage.contains("Can not read value at")) { + val message = "Encounter error while reading parquet files. " + + "One possible cause: Parquet column cannot be converted in the " + + "corresponding files. Details: " + throw new QueryExecutionException(message, e) + } + throw e + } } else { currentFile = null InputFileBlockHolder.unset() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 2cd2a600f2b97..9d3dfae348beb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -20,10 +20,13 @@ package org.apache.spark.sql.execution.datasources.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.parquet.io.ParquetDecodingException import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -382,6 +385,58 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } } + // ======================================= + // Tests for parquet schema mismatch error + // ======================================= + def testSchemaMismatch(path: String, vectorizedReaderEnabled: Boolean): SparkException = { + import testImplicits._ + + var e: SparkException = null + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedReaderEnabled.toString) { + // Create two parquet files with different schemas in the same folder + Seq(("bcd", 2)).toDF("a", "b").coalesce(1).write.mode("overwrite").parquet(s"$path/parquet") + Seq((1, "abc")).toDF("a", "b").coalesce(1).write.mode("append").parquet(s"$path/parquet") + + e = intercept[SparkException] { + spark.read.parquet(s"$path/parquet").collect() + } + } + e + } + + test("schema mismatch failure error message for parquet reader") { + withTempPath { dir => + val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = false) + val expectedMessage = "Encounter error while reading parquet files. " + + "One possible cause: Parquet column cannot be converted in the corresponding " + + "files. Details:" + assert(e.getCause.isInstanceOf[QueryExecutionException]) + assert(e.getCause.getCause.isInstanceOf[ParquetDecodingException]) + assert(e.getCause.getMessage.startsWith(expectedMessage)) + } + } + + test("schema mismatch failure error message for parquet vectorized reader") { + withTempPath { dir => + val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = true) + assert(e.getCause.isInstanceOf[QueryExecutionException]) + assert(e.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException]) + + // Check if the physical type is reporting correctly + val errMsg = e.getCause.getMessage + assert(errMsg.startsWith("Parquet column cannot be converted in file")) + val file = errMsg.substring("Parquet column cannot be converted in file ".length, + errMsg.indexOf(". ")) + val col = spark.read.parquet(file).schema.fields.filter(_.name.equals("a")) + assert(col.length == 1) + if (col(0).dataType == StringType) { + assert(errMsg.contains("Column: [a], Expected: IntegerType, Found: BINARY")) + } else { + assert(errMsg.endsWith("Column: [a], Expected: StringType, Found: INT32")) + } + } + } + // ======================================================= // Tests for converting Parquet LIST to Catalyst ArrayType // ======================================================= From d766ea2ff2bf59afbd631d3cc2e43bebfccdebed Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sat, 7 Apr 2018 00:15:54 +0800 Subject: [PATCH 574/774] [SPARK-23861][SQL][DOC] Clarify default window frame with and without orderBy clause ## What changes were proposed in this pull request? Add docstring to clarify default window frame boundaries with and without orderBy clause ## How was this patch tested? Manually generate doc and check. Author: Li Jin Closes #20978 from icexelloss/SPARK-23861-window-doc. --- python/pyspark/sql/window.py | 4 ++++ .../main/scala/org/apache/spark/sql/expressions/Window.scala | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index e667fba099fb9..d19ced954f04e 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -44,6 +44,10 @@ class Window(object): >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING >>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3) + .. note:: When ordering is not defined, an unbounded window frame (rowFrame, + unboundedPreceding, unboundedFollowing) is used by default. When ordering is defined, + a growing window frame (rangeFrame, unboundedPreceding, currentRow) is used by default. + .. note:: Experimental .. versionadded:: 1.4 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 1caa243f8d118..cd819bab1b14c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -33,6 +33,10 @@ import org.apache.spark.sql.catalyst.expressions._ * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) * }}} * + * @note When ordering is not defined, an unbounded window frame (rowFrame, unboundedPreceding, + * unboundedFollowing) is used by default. When ordering is defined, a growing window frame + * (rangeFrame, unboundedPreceding, currentRow) is used by default. + * * @since 1.4.0 */ @InterfaceStability.Stable From c926acf719a6deb9d884a0f19bde075c312bfe5a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 6 Apr 2018 18:42:14 +0200 Subject: [PATCH 575/774] [SPARK-23882][CORE] UTF8StringSuite.writeToOutputStreamUnderflow() is not expected to be supported ## What changes were proposed in this pull request? This PR excludes an existing UT [`writeToOutputStreamUnderflow()`](https://github.com/apache/spark/blob/master/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java#L519-L532) in `UTF8StringSuite`. As discussed [here](https://github.com/apache/spark/pull/19222#discussion_r177692142), the behavior of this test looks surprising. This test seems to access metadata area of the JVM object where is reserved by `Platform.BYTE_ARRAY_OFFSET`. This test is introduced thru #16089 by NathanHowell. More specifically, [the commit](https://github.com/apache/spark/pull/16089/commits/27c102deb1701fe62f776fe4da61dac959270b73) `Improve test coverage of UTFString.write` introduced this UT. However, I cannot find any discussion about this UT. I think that it would be good to exclude this UT. ```java public void writeToOutputStreamUnderflow() throws IOException { // offset underflow is apparently supported? final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { new UTF8String( new ByteArrayMemoryBlock(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)) .writeTo(outputStream); final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); outputStream.reset(); } } ``` ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #20995 from kiszk/SPARK-23882. --- .../spark/unsafe/types/UTF8StringSuite.java | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index bad908fcaf136..652c40a35527f 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -515,22 +515,6 @@ public void soundex() { assertEquals(fromString("世界千世").soundex(), fromString("世界千世")); } - @Test - public void writeToOutputStreamUnderflow() throws IOException { - // offset underflow is apparently supported? - final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); - - for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { - new UTF8String( - new ByteArrayMemoryBlock(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)) - .writeTo(outputStream); - final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); - assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); - outputStream.reset(); - } - } - @Test public void writeToOutputStreamSlice() throws IOException { final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); From d23a805f975f209f273db2b52de3f336be17d873 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Fri, 6 Apr 2018 10:09:55 -0700 Subject: [PATCH 576/774] [SPARK-23859][ML] Initial PR for Instrumentation improvements: UUID and logging levels ## What changes were proposed in this pull request? Initial PR for Instrumentation improvements: UUID and logging levels. This PR takes over #20837 Closes #20837 ## How was this patch tested? Manual. Author: Bago Amirbekian Author: WeichenXu Closes #20982 from WeichenXu123/better-instrumentation. --- .../classification/LogisticRegression.scala | 15 ++++--- .../spark/ml/util/Instrumentation.scala | 40 +++++++++++++++---- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 3ae4db3f3f965..ee4b01058c75c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -517,6 +517,9 @@ class LogisticRegression @Since("1.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } + instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count) + instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) + instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -560,15 +563,15 @@ class LogisticRegression @Since("1.2.0") ( if (numInvalid != 0) { val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " + s"Found $numInvalid invalid labels." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } val isConstantLabel = histogram.count(_ != 0.0) == 1 if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) { - logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " + - s"will be zeros. Training is not needed.") + instr.logWarning(s"All labels are the same value and fitIntercept=true, so the " + + s"coefficients will be zeros. Training is not needed.") val constantLabelIndex = Vectors.dense(histogram).argmax val coefMatrix = new SparseMatrix(numCoefficientSets, numFeatures, new Array[Int](numCoefficientSets + 1), Array.empty[Int], Array.empty[Double], @@ -581,7 +584,7 @@ class LogisticRegression @Since("1.2.0") ( (coefMatrix, interceptVec, Array.empty[Double]) } else { if (!$(fitIntercept) && isConstantLabel) { - logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " + + instr.logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " + s"dangerous ground, so the algorithm may not converge.") } @@ -590,7 +593,7 @@ class LogisticRegression @Since("1.2.0") ( if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { - logWarning("Fitting LogisticRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting LogisticRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant " + "nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.") } @@ -708,7 +711,7 @@ class LogisticRegression @Since("1.2.0") ( (_initialModel.interceptVector.size == numCoefficientSets) && (_initialModel.getFitIntercept == $(fitIntercept)) if (!modelIsValid) { - logWarning(s"Initial coefficients will be ignored! Its dimensions " + + instr.logWarning(s"Initial coefficients will be ignored! Its dimensions " + s"(${providedCoefs.numRows}, ${providedCoefs.numCols}) did not match the " + s"expected size ($numCoefficientSets, $numFeatures)") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 7c46f45c59717..e694bc27b2f1e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.util -import java.util.concurrent.atomic.AtomicLong +import java.util.UUID import org.json4s._ import org.json4s.JsonDSL._ @@ -42,7 +42,7 @@ import org.apache.spark.sql.Dataset private[spark] class Instrumentation[E <: Estimator[_]] private ( estimator: E, dataset: RDD[_]) extends Logging { - private val id = Instrumentation.counter.incrementAndGet() + private val id = UUID.randomUUID() private val prefix = { val className = estimator.getClass.getSimpleName s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " @@ -56,12 +56,31 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( } /** - * Logs a message with a prefix that uniquely identifies the training session. + * Logs a warning message with a prefix that uniquely identifies the training session. */ - def log(msg: String): Unit = { - logInfo(prefix + msg) + override def logWarning(msg: => String): Unit = { + super.logWarning(prefix + msg) } + /** + * Logs a error message with a prefix that uniquely identifies the training session. + */ + override def logError(msg: => String): Unit = { + super.logError(prefix + msg) + } + + /** + * Logs an info message with a prefix that uniquely identifies the training session. + */ + override def logInfo(msg: => String): Unit = { + super.logInfo(prefix + msg) + } + + /** + * Alias for logInfo, see above. + */ + def log(msg: String): Unit = logInfo(msg) + /** * Logs the value of the given parameters for the estimator being used in this session. */ @@ -77,11 +96,11 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( } def logNumFeatures(num: Long): Unit = { - log(compact(render("numFeatures" -> num))) + logNamedValue(Instrumentation.loggerTags.numFeatures, num) } def logNumClasses(num: Long): Unit = { - log(compact(render("numClasses" -> num))) + logNamedValue(Instrumentation.loggerTags.numClasses, num) } /** @@ -107,7 +126,12 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( * Some common methods for logging information about a training session. */ private[spark] object Instrumentation { - private val counter = new AtomicLong(0) + + object loggerTags { + val numFeatures = "numFeatures" + val numClasses = "numClasses" + val numExamples = "numExamples" + } /** * Creates an instrumentation object for a training session. From b6935ffb4dfb1d9fdf36ba402ac07bd02978c012 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 6 Apr 2018 10:23:26 -0700 Subject: [PATCH 577/774] [SPARK-10399][SPARK-23879][HOTFIX] Fix Java lint errors ## What changes were proposed in this pull request? This PR fixes the following errors in [Java lint](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Compile/job/spark-master-lint/7717/console) after #19222 has been merged. These errors were pointed by ueshin . ``` [ERROR] src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java:[57] (sizes) LineLength: Line is longer than 100 characters (found 106). [ERROR] src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java:[26,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java:[23,10] (modifier) ModifierOrder: 'public' modifier out of order with the JLS suggestions. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[64,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[69,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[74,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[79,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[84,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[89,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[94,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[99,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[104,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[109,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[114,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[119,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[124,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[129,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[60,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[65,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[70,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[75,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[80,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[85,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[90,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[95,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[100,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[105,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[110,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[115,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[120,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[125,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java:[114,16] (modifier) ModifierOrder: 'static' modifier out of order with the JLS suggestions. [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java:[20,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java:[30,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.memory.MemoryBlock. [ERROR] src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java:[126,15] (naming) MethodName: Method name 'ByteArrayMemoryBlockTest' must match pattern '^[a-z][a-z0-9][a-zA-Z0-9_]*$'. [ERROR] src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java:[143,15] (naming) MethodName: Method name 'OnHeapMemoryBlockTest' must match pattern '^[a-z][a-z0-9][a-zA-Z0-9_]*$'. [ERROR] src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java:[160,15] (naming) MethodName: Method name 'OffHeapArrayMemoryBlockTest' must match pattern '^[a-z][a-z0-9][a-zA-Z0-9_]*$'. [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java:[19,8] (imports) UnusedImports: Unused import - com.google.common.primitives.Ints. [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java:[21,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java:[20,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. ``` ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #20991 from kiszk/SPARK-10399-jlint. --- .../sql/catalyst/expressions/HiveHasher.java | 1 - .../spark/unsafe/array/ByteArrayMethods.java | 4 +-- .../unsafe/memory/ByteArrayMemoryBlock.java | 28 +++++++++---------- .../unsafe/memory/HeapMemoryAllocator.java | 2 -- .../spark/unsafe/memory/MemoryBlock.java | 2 +- .../unsafe/memory/OffHeapMemoryBlock.java | 2 +- .../unsafe/memory/OnHeapMemoryBlock.java | 28 +++++++++---------- .../spark/unsafe/memory/MemoryBlockSuite.java | 6 ++-- .../spark/unsafe/types/UTF8StringSuite.java | 1 - .../spark/sql/catalyst/expressions/XXH64.java | 3 -- .../catalyst/expressions/HiveHasherSuite.java | 1 - 11 files changed, 35 insertions(+), 43 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index 5d905943a3aa7..c34e36903a93e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index c334c9651cf6b..4bc9955090fd7 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -54,7 +54,7 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { * @return true if the arrays are equal, false otherwise */ public static boolean arrayEqualsBlock( - MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, final long length) { + MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, long length) { return arrayEquals(leftBase.getBaseObject(), leftBase.getBaseOffset() + leftOffset, rightBase.getBaseObject(), rightBase.getBaseOffset() + rightOffset, length); } @@ -64,7 +64,7 @@ public static boolean arrayEqualsBlock( * @return true if the arrays are equal, false otherwise */ public static boolean arrayEquals( - Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { + Object leftBase, long leftOffset, Object rightBase, long rightOffset, long length) { int i = 0; // check if starts align and we can get both offsets to be aligned diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java index 99a9868a49a79..9f238632bc87a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java @@ -57,72 +57,72 @@ public static ByteArrayMemoryBlock fromArray(final byte[] array) { } @Override - public final int getInt(long offset) { + public int getInt(long offset) { return Platform.getInt(array, this.offset + offset); } @Override - public final void putInt(long offset, int value) { + public void putInt(long offset, int value) { Platform.putInt(array, this.offset + offset, value); } @Override - public final boolean getBoolean(long offset) { + public boolean getBoolean(long offset) { return Platform.getBoolean(array, this.offset + offset); } @Override - public final void putBoolean(long offset, boolean value) { + public void putBoolean(long offset, boolean value) { Platform.putBoolean(array, this.offset + offset, value); } @Override - public final byte getByte(long offset) { + public byte getByte(long offset) { return array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)]; } @Override - public final void putByte(long offset, byte value) { + public void putByte(long offset, byte value) { array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)] = value; } @Override - public final short getShort(long offset) { + public short getShort(long offset) { return Platform.getShort(array, this.offset + offset); } @Override - public final void putShort(long offset, short value) { + public void putShort(long offset, short value) { Platform.putShort(array, this.offset + offset, value); } @Override - public final long getLong(long offset) { + public long getLong(long offset) { return Platform.getLong(array, this.offset + offset); } @Override - public final void putLong(long offset, long value) { + public void putLong(long offset, long value) { Platform.putLong(array, this.offset + offset, value); } @Override - public final float getFloat(long offset) { + public float getFloat(long offset) { return Platform.getFloat(array, this.offset + offset); } @Override - public final void putFloat(long offset, float value) { + public void putFloat(long offset, float value) { Platform.putFloat(array, this.offset + offset, value); } @Override - public final double getDouble(long offset) { + public double getDouble(long offset) { return Platform.getDouble(array, this.offset + offset); } @Override - public final void putDouble(long offset, double value) { + public void putDouble(long offset, double value) { Platform.putDouble(array, this.offset + offset, value); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index acf28fd7ee59b..36caf80888cda 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -23,8 +23,6 @@ import java.util.LinkedList; import java.util.Map; -import org.apache.spark.unsafe.Platform; - /** * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. */ diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index b086941108522..ca7213bbf92da 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -111,7 +111,7 @@ public final void fill(byte value) { /** * Instantiate MemoryBlock for given object type with new offset */ - public final static MemoryBlock allocateFromObject(Object obj, long offset, long length) { + public static final MemoryBlock allocateFromObject(Object obj, long offset, long length) { MemoryBlock mb = null; if (obj instanceof byte[]) { byte[] array = (byte[])obj; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java index f90f62bf21dcb..3431b08980eb8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java @@ -20,7 +20,7 @@ import org.apache.spark.unsafe.Platform; public class OffHeapMemoryBlock extends MemoryBlock { - static public final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); + public static final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); public OffHeapMemoryBlock(long address, long size) { super(null, address, size); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java index 12f67c7bd593e..ee42bc27c9c5f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java @@ -61,72 +61,72 @@ public static OnHeapMemoryBlock fromArray(final long[] array, long size) { } @Override - public final int getInt(long offset) { + public int getInt(long offset) { return Platform.getInt(array, this.offset + offset); } @Override - public final void putInt(long offset, int value) { + public void putInt(long offset, int value) { Platform.putInt(array, this.offset + offset, value); } @Override - public final boolean getBoolean(long offset) { + public boolean getBoolean(long offset) { return Platform.getBoolean(array, this.offset + offset); } @Override - public final void putBoolean(long offset, boolean value) { + public void putBoolean(long offset, boolean value) { Platform.putBoolean(array, this.offset + offset, value); } @Override - public final byte getByte(long offset) { + public byte getByte(long offset) { return Platform.getByte(array, this.offset + offset); } @Override - public final void putByte(long offset, byte value) { + public void putByte(long offset, byte value) { Platform.putByte(array, this.offset + offset, value); } @Override - public final short getShort(long offset) { + public short getShort(long offset) { return Platform.getShort(array, this.offset + offset); } @Override - public final void putShort(long offset, short value) { + public void putShort(long offset, short value) { Platform.putShort(array, this.offset + offset, value); } @Override - public final long getLong(long offset) { + public long getLong(long offset) { return Platform.getLong(array, this.offset + offset); } @Override - public final void putLong(long offset, long value) { + public void putLong(long offset, long value) { Platform.putLong(array, this.offset + offset, value); } @Override - public final float getFloat(long offset) { + public float getFloat(long offset) { return Platform.getFloat(array, this.offset + offset); } @Override - public final void putFloat(long offset, float value) { + public void putFloat(long offset, float value) { Platform.putFloat(array, this.offset + offset, value); } @Override - public final double getDouble(long offset) { + public double getDouble(long offset) { return Platform.getDouble(array, this.offset + offset); } @Override - public final void putDouble(long offset, double value) { + public void putDouble(long offset, double value) { Platform.putDouble(array, this.offset + offset, value); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java index 47f05c928f2e5..5d5fdc1c55a75 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -123,7 +123,7 @@ private void check(MemoryBlock memory, Object obj, long offset, int length) { } @Test - public void ByteArrayMemoryBlockTest() { + public void testByteArrayMemoryBlock() { byte[] obj = new byte[56]; long offset = Platform.BYTE_ARRAY_OFFSET; int length = obj.length; @@ -140,7 +140,7 @@ public void ByteArrayMemoryBlockTest() { } @Test - public void OnHeapMemoryBlockTest() { + public void testOnHeapMemoryBlock() { long[] obj = new long[7]; long offset = Platform.LONG_ARRAY_OFFSET; int length = obj.length * 8; @@ -157,7 +157,7 @@ public void OnHeapMemoryBlockTest() { } @Test - public void OffHeapArrayMemoryBlockTest() { + public void testOffHeapArrayMemoryBlock() { MemoryAllocator memoryAllocator = new UnsafeMemoryAllocator(); MemoryBlock memory = memoryAllocator.allocate(56); Object obj = memory.getBaseObject(); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 652c40a35527f..2c08535a16465 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -27,7 +27,6 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; import org.junit.Test; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index 883748932ad33..fe727f6011cbf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -16,9 +16,6 @@ */ package org.apache.spark.sql.catalyst.expressions; -import com.google.common.primitives.Ints; - -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; // scalastyle: off diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java index 8ffc1d7c24d61..76930f9368514 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; From e998250588de0df250e2800278da4d3e3705c259 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 6 Apr 2018 11:51:36 -0700 Subject: [PATCH 578/774] [SPARK-23828][ML][PYTHON] PySpark StringIndexerModel should have constructor from labels ## What changes were proposed in this pull request? The Scala StringIndexerModel has an alternate constructor that will create the model from an array of label strings. Add the corresponding Python API: model = StringIndexerModel.from_labels(["a", "b", "c"]) ## How was this patch tested? Add doctest and unit test. Author: Huaxin Gao Closes #20968 from huaxingao/spark-23828. --- python/pyspark/ml/feature.py | 88 ++++++++++++++++++++++++++---------- python/pyspark/ml/tests.py | 41 ++++++++++++++++- 2 files changed, 104 insertions(+), 25 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index fcb0dfc563720..5a3e0dd655150 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2342,9 +2342,38 @@ def mean(self): return self._call_java("mean") +class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol): + """ + Params for :py:attr:`StringIndexer` and :py:attr:`StringIndexerModel`. + """ + + stringOrderType = Param(Params._dummy(), "stringOrderType", + "How to order labels of string column. The first label after " + + "ordering is assigned an index of 0. Supported options: " + + "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", + typeConverter=TypeConverters.toString) + + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + + "or NULL values) in features and label column of string type. " + + "Options are 'skip' (filter out rows with invalid data), " + + "error (throw an error), or 'keep' (put invalid data " + + "in a special additional bucket, at index numLabels).", + typeConverter=TypeConverters.toString) + + def __init__(self, *args): + super(_StringIndexerParams, self).__init__(*args) + self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc") + + @since("2.3.0") + def getStringOrderType(self): + """ + Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'. + """ + return self.getOrDefault(self.stringOrderType) + + @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, - JavaMLWritable): +class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLWritable): """ A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. @@ -2388,23 +2417,16 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)] + >>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"], + ... inputCol="label", outputCol="indexed", handleInvalid="error") + >>> result = fromlabelsModel.transform(stringIndDf) + >>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]), + ... key=lambda x: x[0]) + [(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)] .. versionadded:: 1.4.0 """ - stringOrderType = Param(Params._dummy(), "stringOrderType", - "How to order labels of string column. The first label after " + - "ordering is assigned an index of 0. Supported options: " + - "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", - typeConverter=TypeConverters.toString) - - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + - "or NULL values) in features and label column of string type. " + - "Options are 'skip' (filter out rows with invalid data), " + - "error (throw an error), or 'keep' (put invalid data " + - "in a special additional bucket, at index numLabels).", - typeConverter=TypeConverters.toString) - @keyword_only def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", stringOrderType="frequencyDesc"): @@ -2414,7 +2436,6 @@ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", """ super(StringIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) - self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -2440,21 +2461,33 @@ def setStringOrderType(self, value): """ return self._set(stringOrderType=value) - @since("2.3.0") - def getStringOrderType(self): - """ - Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'. - """ - return self.getOrDefault(self.stringOrderType) - -class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): +class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaMLWritable): """ Model fitted by :py:class:`StringIndexer`. .. versionadded:: 1.4.0 """ + @classmethod + @since("2.4.0") + def from_labels(cls, labels, inputCol, outputCol=None, handleInvalid=None): + """ + Construct the model directly from an array of label strings, + requires an active SparkContext. + """ + sc = SparkContext._active_spark_context + java_class = sc._gateway.jvm.java.lang.String + jlabels = StringIndexerModel._new_java_array(labels, java_class) + model = StringIndexerModel._create_from_java_class( + "org.apache.spark.ml.feature.StringIndexerModel", jlabels) + model.setInputCol(inputCol) + if outputCol is not None: + model.setOutputCol(outputCol) + if handleInvalid is not None: + model.setHandleInvalid(handleInvalid) + return model + @property @since("1.5.0") def labels(self): @@ -2463,6 +2496,13 @@ def labels(self): """ return self._call_java("labels") + @since("2.4.0") + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + @inherit_doc class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index c2c4861e2aff4..4ce54547eab09 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -800,6 +800,43 @@ def test_string_indexer_handle_invalid(self): expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] self.assertEqual(actual2, expected2) + def test_string_indexer_from_labels(self): + model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label", + outputCol="indexed", handleInvalid="keep") + self.assertEqual(model.labels, ["a", "b", "c"]) + + df1 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, None), + (3, "b"), + (4, "b")], ["id", "label"]) + + result1 = model.transform(df1) + actual1 = result1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0), + Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)] + self.assertEqual(actual1, expected1) + + model_empty_labels = StringIndexerModel.from_labels( + [], inputCol="label", outputCol="indexed", handleInvalid="keep") + actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0), + Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)] + self.assertEqual(actual2, expected2) + + # Test model with default settings can transform + model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label") + df2 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, "b"), + (3, "b"), + (4, "b")], ["id", "label"]) + transformed_list = model_default.transform(df2)\ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 5) + class HasInducedError(Params): @@ -2097,9 +2134,11 @@ def test_java_params(self): ParamTests.check_params(self, cls(), check_params_exist=False) # Additional classes that need explicit construction - from pyspark.ml.feature import CountVectorizerModel + from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), check_params_exist=False) + ParamTests.check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'), + check_params_exist=False) def _squared_distance(a, b): From 6ab134ca7d8f7802a6d196929513cc02b9b4d35d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 6 Apr 2018 15:00:13 -0700 Subject: [PATCH 579/774] [SPARK-21898][ML][FOLLOWUP] Fix Scala 2.12 build. ## What changes were proposed in this pull request? This is a follow-up pr of #19108 which broke Scala 2.12 build. ``` [error] /Users/ueshin/workspace/apache-spark/spark/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala:86: overloaded method value test with alternatives: [error] (dataset: org.apache.spark.sql.DataFrame,sampleCol: String,cdf: org.apache.spark.api.java.function.Function[java.lang.Double,java.lang.Double])org.apache.spark.sql.DataFrame [error] (dataset: org.apache.spark.sql.DataFrame,sampleCol: String,cdf: scala.Double => scala.Double)org.apache.spark.sql.DataFrame [error] cannot be applied to (org.apache.spark.sql.DataFrame, String, scala.Double => java.lang.Double) [error] test(dataset, sampleCol, (x: Double) => cdf.call(x)) [error] ^ [error] one error found ``` ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #20994 from ueshin/issues/SPARK-21898/fix_scala-2.12. --- .../scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala index 8d80e7768cb6e..c62d7463288f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -83,7 +83,8 @@ object KolmogorovSmirnovTest { @Since("2.4.0") def test(dataset: DataFrame, sampleCol: String, cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { - test(dataset, sampleCol, (x: Double) => cdf.call(x)) + val f: Double => Double = x => cdf.call(x) + test(dataset, sampleCol, f) } /** From 2c1fe647575e97e28b2232478ca86847d113e185 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 8 Apr 2018 12:09:06 +0800 Subject: [PATCH 580/774] [SPARK-23847][PYTHON][SQL] Add asc_nulls_first, asc_nulls_last to PySpark ## What changes were proposed in this pull request? Column.scala and Functions.scala have asc_nulls_first, asc_nulls_last, desc_nulls_first and desc_nulls_last. Add the corresponding python APIs in column.py and functions.py ## How was this patch tested? Add doctest Author: Huaxin Gao Closes #20962 from huaxingao/spark-23847. --- python/pyspark/sql/column.py | 56 +++++++++++++++++-- python/pyspark/sql/functions.py | 13 +++++ python/pyspark/sql/tests.py | 14 +++++ .../scala/org/apache/spark/sql/Column.scala | 4 +- .../org/apache/spark/sql/functions.scala | 2 +- 5 files changed, 82 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 922c7cf288f8f..e7dec11c69b57 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -447,24 +447,72 @@ def isin(self, *cols): # order _asc_doc = """ - Returns a sort expression based on the ascending order of the given column name + Returns a sort expression based on ascending order of the column. >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.asc()).collect() [Row(name=u'Alice'), Row(name=u'Tom')] """ + _asc_nulls_first_doc = """ + Returns a sort expression based on ascending order of the column, and null values + return before non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.asc_nulls_first()).collect() + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')] + + .. versionadded:: 2.4 + """ + _asc_nulls_last_doc = """ + Returns a sort expression based on ascending order of the column, and null values + appear after non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.asc_nulls_last()).collect() + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)] + + .. versionadded:: 2.4 + """ _desc_doc = """ - Returns a sort expression based on the descending order of the given column name. + Returns a sort expression based on the descending order of the column. >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.desc()).collect() [Row(name=u'Tom'), Row(name=u'Alice')] """ + _desc_nulls_first_doc = """ + Returns a sort expression based on the descending order of the column, and null values + appear before non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.desc_nulls_first()).collect() + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')] + + .. versionadded:: 2.4 + """ + _desc_nulls_last_doc = """ + Returns a sort expression based on the descending order of the column, and null values + appear after non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.desc_nulls_last()).collect() + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)] + + .. versionadded:: 2.4 + """ asc = ignore_unicode_prefix(_unary_op("asc", _asc_doc)) + asc_nulls_first = ignore_unicode_prefix(_unary_op("asc_nulls_first", _asc_nulls_first_doc)) + asc_nulls_last = ignore_unicode_prefix(_unary_op("asc_nulls_last", _asc_nulls_last_doc)) desc = ignore_unicode_prefix(_unary_op("desc", _desc_doc)) + desc_nulls_first = ignore_unicode_prefix(_unary_op("desc_nulls_first", _desc_nulls_first_doc)) + desc_nulls_last = ignore_unicode_prefix(_unary_op("desc_nulls_last", _desc_nulls_last_doc)) _isNull_doc = """ True if the current expression is null. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ad3e37c872628..1b192680f0795 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -138,6 +138,17 @@ def _(): 'bitwiseNOT': 'Computes bitwise not.', } +_functions_2_4 = { + 'asc_nulls_first': 'Returns a sort expression based on the ascending order of the given' + + ' column name, and null values return before non-null values.', + 'asc_nulls_last': 'Returns a sort expression based on the ascending order of the given' + + ' column name, and null values appear after non-null values.', + 'desc_nulls_first': 'Returns a sort expression based on the descending order of the given' + + ' column name, and null values appear before non-null values.', + 'desc_nulls_last': 'Returns a sort expression based on the descending order of the given' + + ' column name, and null values appear after non-null values', +} + _collect_list_doc = """ Aggregate function: returns a list of objects with duplicates. @@ -250,6 +261,8 @@ def _(): globals()[_name] = since(2.1)(_create_function(_name, _doc)) for _name, _message in _functions_deprecated.items(): globals()[_name] = _wrap_deprecated_function(globals()[_name], _message) +for _name, _doc in _functions_2_4.items(): + globals()[_name] = since(2.4)(_create_function(_name, _doc)) del _name, _doc diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5181053a0d318..dd04ffb4ed393 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2991,6 +2991,20 @@ def test_create_dateframe_from_pandas_with_dst(self): os.environ['TZ'] = orig_env_tz time.tzset() + def test_2_4_functions(self): + from pyspark.sql import functions + + df = self.spark.createDataFrame( + [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"]) + df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect() + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')] + df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect() + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)] + df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect() + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')] + df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect() + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)] + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 92988680871a4..ad0efbae89830 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1083,10 +1083,10 @@ class Column(val expr: Expression) extends Logging { * and null values return before non-null values. * {{{ * // Scala: sort a DataFrame by age column in ascending order and null values appearing first. - * df.sort(df("age").asc_nulls_last) + * df.sort(df("age").asc_nulls_first) * * // Java - * df.sort(df.col("age").asc_nulls_last()); + * df.sort(df.col("age").asc_nulls_first()); * }}} * * @group expr_ops diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c9ca9a8996344..c658f25ced053 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -132,7 +132,7 @@ object functions { * Returns a sort expression based on ascending order of the column, * and null values return before non-null values. * {{{ - * df.sort(asc_nulls_last("dept"), desc("age")) + * df.sort(asc_nulls_first("dept"), desc("age")) * }}} * * @group sort_funcs From 6a734575a80e6b4ec4963206254451f05d64b742 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 7 Apr 2018 21:44:32 -0700 Subject: [PATCH 581/774] [SPARK-23849][SQL] Tests for the samplingRatio option of JSON datasource ## What changes were proposed in this pull request? Proposed tests checks that only subset of input dataset is touched during schema inferring. Author: Maxim Gekk Closes #20963 from MaxGekk/json-sampling-tests. --- .../datasources/json/JsonSuite.scala | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 10bac0554484a..70aee561ff0f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.nio.charset.StandardCharsets -import java.nio.file.Files +import java.nio.file.{Files, Paths, StandardOpenOption} import java.sql.{Date, Timestamp} import java.util.Locale @@ -2127,4 +2127,39 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(df.schema === expectedSchema) } } + + test("SPARK-23849: schema inferring touches less data if samplingRation < 1.0") { + val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + withTempPath { path => + val writer = Files.newBufferedWriter(Paths.get(path.getAbsolutePath), + StandardCharsets.UTF_8, StandardOpenOption.CREATE_NEW) + for (i <- 0 until 100) { + if (predefinedSample.contains(i)) { + writer.write(s"""{"f1":${i.toString}}""" + "\n") + } else { + writer.write(s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n") + } + } + writer.close() + + val ds = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) + assert(ds.schema == new StructType().add("f1", LongType)) + } + } + + test("SPARK-23849: usage of samplingRation while parsing of dataset of strings") { + val dstr = spark.sparkContext.parallelize(0 until 100, 1).map { i => + val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(i)) { + s"""{"f1":${i.toString}}""" + "\n" + } else { + s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n" + } + }.toDS() + val ds = spark.read.option("samplingRatio", 0.1).json(dstr) + + assert(ds.schema == new StructType().add("f1", LongType)) + } } From 710a68cec27a94c2df10d8b4022a755a94a5443b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 8 Apr 2018 20:26:31 +0200 Subject: [PATCH 582/774] [SPARK-23892][TEST] Improve converge and fix lint error in UTF8String-related tests ## What changes were proposed in this pull request? This PR improves test coverage in `UTF8StringSuite` and code efficiency in `UTF8StringPropertyCheckSuite`. This PR also fixes lint-java issue in `UTF8StringSuite` reported at [here](https://github.com/apache/spark/pull/20995#issuecomment-379325527) ```[ERROR] src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java:[28,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform.``` ## How was this patch tested? Existing UT Author: Kazuaki Ishizaki Closes #21000 from kiszk/SPARK-23892. --- .../java/org/apache/spark/unsafe/types/UTF8StringSuite.java | 5 ++--- .../spark/unsafe/types/UTF8StringPropertyCheckSuite.scala | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 2c08535a16465..42dda30480702 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -25,7 +25,6 @@ import java.util.*; import com.google.common.collect.ImmutableMap; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; import org.junit.Test; @@ -53,8 +52,8 @@ private static void checkBasic(String str, int len) { assertTrue(s1.contains(s2)); assertTrue(s2.contains(s1)); - assertTrue(s1.startsWith(s1)); - assertTrue(s1.endsWith(s1)); + assertTrue(s1.startsWith(s2)); + assertTrue(s1.endsWith(s2)); } @Test diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 62d4176d00f94..48004e812a8bf 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -164,7 +164,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty def padding(origin: String, pad: String, length: Int, isLPad: Boolean): String = { if (length <= 0) return "" if (length <= origin.length) { - if (length <= 0) "" else origin.substring(0, length) + origin.substring(0, length) } else { if (pad.length == 0) return origin val toPad = length - origin.length From 8d40a79a077a30024a8ef921781b68f6f7e542d1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 8 Apr 2018 20:40:27 +0200 Subject: [PATCH 583/774] [SPARK-23893][CORE][SQL] Avoid possible integer overflow in multiplication ## What changes were proposed in this pull request? This PR avoids possible overflow at an operation `long = (long)(int * int)`. The multiplication of large positive integer values may set one to MSB. This leads to a negative value in long while we expected a positive value (e.g. `0111_0000_0000_0000 * 0000_0000_0000_0010`). This PR performs long cast before the multiplication to avoid this situation. ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #21002 from kiszk/SPARK-23893. --- .../util/collection/unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../util/collection/unsafe/sort/UnsafeSortDataFormat.java | 2 +- .../src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala | 2 +- .../scala/org/apache/spark/InternalAccumulatorSuite.scala | 2 +- .../apache/spark/deploy/history/FsHistoryProviderSuite.scala | 4 ++-- .../test/scala/org/apache/spark/util/JsonProtocolSuite.scala | 4 ++-- .../src/test/scala/org/apache/spark/sql/HashBenchmark.scala | 2 +- .../scala/org/apache/spark/sql/HashByteArrayBenchmark.scala | 3 ++- .../org/apache/spark/sql/UnsafeProjectionBenchmark.scala | 2 +- .../columnar/compression/CompressionSchemeBenchmark.scala | 4 ++-- .../sql/execution/vectorized/ColumnarBatchBenchmark.scala | 2 +- 11 files changed, 15 insertions(+), 14 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 20a7a8b267438..717823ebbd320 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -124,7 +124,7 @@ public UnsafeInMemorySorter( int initialSize, boolean canUseRadixSort) { this(consumer, memoryManager, recordComparator, prefixComparator, - consumer.allocateArray(initialSize * 2), canUseRadixSort); + consumer.allocateArray(initialSize * 2L), canUseRadixSort); } public UnsafeInMemorySorter( diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java index d9f84d10e9051..37772f41caa87 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -84,7 +84,7 @@ public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int @Override public LongArray allocate(int length) { - assert (length * 2 <= buffer.size()) : + assert (length * 2L <= buffer.size()) : "the buffer is smaller than required: " + buffer.size() + " < " + (length * 2); return buffer; } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index c9ed12f4e1bd4..13db4985b0b80 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -90,7 +90,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // Otherwise, interpolate the number of partitions we need to try, but overestimate it // by 50%. We also cap the estimation in the end. if (results.size == 0) { - numPartsToTry = partsScanned * 4 + numPartsToTry = partsScanned * 4L } else { // the left side of max is >=1 whenever partsScanned >= 2 numPartsToTry = Math.max(1, diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 8d7be77f51fe9..62824a5bec9d1 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -135,7 +135,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { // This job runs 2 stages, and we're in the second stage. Therefore, any task attempt // ID that's < 2 * numPartitions belongs to the first attempt of this stage. val taskContext = TaskContext.get() - val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2 + val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2L if (isFirstStageAttempt) { throw new FetchFailedException( SparkEnv.get.blockManager.blockManagerId, diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index fde5f25bce456..0ba57bf4563c1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -382,8 +382,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val log = newLogFile("downloadApp1", Some(s"attempt$i"), inProgress = false) writeFile(log, true, None, SparkListenerApplicationStart( - "downloadApp1", Some("downloadApp1"), 5000 * i, "test", Some(s"attempt$i")), - SparkListenerApplicationEnd(5001 * i) + "downloadApp1", Some("downloadApp1"), 5000L * i, "test", Some(s"attempt$i")), + SparkListenerApplicationEnd(5001L * i) ) log } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 4abbb8e7894f5..74b72d940eeef 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -317,7 +317,7 @@ class JsonProtocolSuite extends SparkFunSuite { test("SparkListenerJobStart backward compatibility") { // Prior to Spark 1.2.0, SparkListenerJobStart did not have a "Stage Infos" property. val stageIds = Seq[Int](1, 2, 3, 4) - val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400, x * 500)) + val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L)) val dummyStageInfos = stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown")) val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) @@ -331,7 +331,7 @@ class JsonProtocolSuite extends SparkFunSuite { // Prior to Spark 1.3.0, SparkListenerJobStart did not have a "Submission Time" property. // Also, SparkListenerJobEnd did not have a "Completion Time" property. val stageIds = Seq[Int](1, 2, 3, 4) - val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40, x * 50)) + val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40L, x * 50L)) val jobStart = SparkListenerJobStart(11, jobSubmissionTime, stageInfos, properties) val oldStartEvent = JsonProtocol.jobStartToJson(jobStart) .removeField({ _._1 == "Submission Time"}) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index 2d94b66a1e122..9a89e6290e695 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -40,7 +40,7 @@ object HashBenchmark { safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy() ).toArray - val benchmark = new Benchmark("Hash For " + name, iters * numRows) + val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong) benchmark.addCase("interpreted version") { _: Int => var sum = 0 for (_ <- 0L until iters) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala index 2a753a0c84ed5..f6c8111f5bc57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala @@ -36,7 +36,8 @@ object HashByteArrayBenchmark { bytes } - val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays) + val benchmark = + new Benchmark("Hash byte arrays with length " + length, iters * numArrays.toLong) benchmark.addCase("Murmur3_x86_32") { _: Int => var sum = 0L for (_ <- 0L until iters) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index 769addf3b29e6..6c63769945312 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -38,7 +38,7 @@ object UnsafeProjectionBenchmark { val iters = 1024 * 16 val numRows = 1024 * 16 - val benchmark = new Benchmark("unsafe projection", iters * numRows) + val benchmark = new Benchmark("unsafe projection", iters * numRows.toLong) val schema1 = new StructType().add("l", LongType, false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index 9005ec93e786e..619b76fabdd5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -77,7 +77,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { count: Int, tpe: NativeColumnType[T], input: ByteBuffer): Unit = { - val benchmark = new Benchmark(name, iters * count) + val benchmark = new Benchmark(name, iters * count.toLong) schemes.filter(_.supports(tpe)).foreach { scheme => val (compressFunc, compressionRatio, buf) = prepareEncodeInternal(count, tpe, scheme, input) @@ -101,7 +101,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { count: Int, tpe: NativeColumnType[T], input: ByteBuffer): Unit = { - val benchmark = new Benchmark(name, iters * count) + val benchmark = new Benchmark(name, iters * count.toLong) schemes.filter(_.supports(tpe)).foreach { scheme => val (compressFunc, _, buf) = prepareEncodeInternal(count, tpe, scheme, input) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 1f31aa45a1220..8aeb06d428951 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -295,7 +295,7 @@ object ColumnarBatchBenchmark { def booleanAccess(iters: Int): Unit = { val count = 8 * 1024 - val benchmark = new Benchmark("Boolean Read/Write", iters * count) + val benchmark = new Benchmark("Boolean Read/Write", iters * count.toLong) benchmark.addCase("Bitset") { i: Int => { val b = new BitSet(count) var sum = 0L From 32471ba0af52b59141b44a8375025b6a7eafae70 Mon Sep 17 00:00:00 2001 From: Nolan Emirot Date: Mon, 9 Apr 2018 08:04:02 -0500 Subject: [PATCH 584/774] Fix typo in Python docstring kinesis example ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Nolan Emirot Closes #20990 from emirot/kinesis_stream_example_typo. --- .../src/main/python/examples/streaming/kinesis_wordcount_asl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py index 4d7fc9a549bfb..49794faab88c4 100644 --- a/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py +++ b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py @@ -34,7 +34,7 @@ $ export AWS_SECRET_KEY= # run the example - $ bin/spark-submit -jar external/kinesis-asl/target/scala-*/\ + $ bin/spark-submit -jars external/kinesis-asl/target/scala-*/\ spark-streaming-kinesis-asl-assembly_*.jar \ external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com From d81f29ecafe8fc9816e36087e3b8acdc93d6cc1b Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 9 Apr 2018 10:19:22 -0700 Subject: [PATCH 585/774] [SPARK-23881][CORE][TEST] Fix flaky test JobCancellationSuite."interruptible iterator of shuffle reader" ## What changes were proposed in this pull request? The test case JobCancellationSuite."interruptible iterator of shuffle reader" has been flaky because `KillTask` event is handled asynchronously, so it can happen that the semaphore is released but the task is still running. Actually we only have to check if the total number of processed elements is less than the input elements number, so we know the task get cancelled. ## How was this patch tested? The new test case still fails without the purposed patch, and succeeded in current master. Author: Xingbo Jiang Closes #20993 from jiangxb1987/JobCancellationSuite. --- .../apache/spark/JobCancellationSuite.scala | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 3b793bb231cf3..61da4138896cd 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -332,13 +332,15 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft import JobCancellationSuite._ sc = new SparkContext("local[2]", "test interruptible iterator") + // Increase the number of elements to be proceeded to avoid this test being flaky. + val numElements = 10000 val taskCompletedSem = new Semaphore(0) sc.addSparkListener(new SparkListener { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { // release taskCancelledSemaphore when cancelTasks event has been posted if (stageCompleted.stageInfo.stageId == 1) { - taskCancelledSemaphore.release(1000) + taskCancelledSemaphore.release(numElements) } } @@ -349,28 +351,31 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft } }) - val f = sc.parallelize(1 to 1000).map { i => (i, i) } + // Explicitly disable interrupt task thread on cancelling tasks, so the task thread can only be + // interrupted by `InterruptibleIterator`. + sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + + val f = sc.parallelize(1 to numElements).map { i => (i, i) } .repartitionAndSortWithinPartitions(new HashPartitioner(1)) .mapPartitions { iter => taskStartedSemaphore.release() iter }.foreachAsync { x => - if (x._1 >= 10) { - // This block of code is partially executed. It will be blocked when x._1 >= 10 and the - // next iteration will be cancelled if the source iterator is interruptible. Then in this - // case, the maximum num of increment would be 10(|1...10|) - taskCancelledSemaphore.acquire() - } + // Block this code from being executed, until the job get cancelled. In this case, if the + // source iterator is interruptible, the max number of increment should be under + // `numElements`. + taskCancelledSemaphore.acquire() executionOfInterruptibleCounter.getAndIncrement() } taskStartedSemaphore.acquire() // Job is cancelled when: // 1. task in reduce stage has been started, guaranteed by previous line. - // 2. task in reduce stage is blocked after processing at most 10 records as - // taskCancelledSemaphore is not released until cancelTasks event is posted - // After job being cancelled, task in reduce stage will be cancelled and no more iteration are - // executed. + // 2. task in reduce stage is blocked as taskCancelledSemaphore is not released until + // JobCancelled event is posted. + // After job being cancelled, task in reduce stage will be cancelled asynchronously, thus + // partial of the inputs should not get processed (It's very unlikely that Spark can process + // 10000 elements between JobCancelled is posted and task is really killed). f.cancel() val e = intercept[SparkException](f.get()).getCause @@ -378,7 +383,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft // Make sure tasks are indeed completed. taskCompletedSem.acquire() - assert(executionOfInterruptibleCounter.get() <= 10) + assert(executionOfInterruptibleCounter.get() < numElements) } def testCount() { From 10f45bb8233e6ac838dd4f053052c8556f5b54bd Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 9 Apr 2018 11:31:21 -0700 Subject: [PATCH 586/774] [SPARK-23816][CORE] Killed tasks should ignore FetchFailures. SPARK-19276 ensured that FetchFailures do not get swallowed by other layers of exception handling, but it also meant that a killed task could look like a fetch failure. This is particularly a problem with speculative execution, where we expect to kill tasks as they are reading shuffle data. The fix is to ensure that we always check for killed tasks first. Added a new unit test which fails before the fix, ran it 1k times to check for flakiness. Full suite of tests on jenkins. Author: Imran Rashid Closes #20987 from squito/SPARK-23816. --- .../org/apache/spark/executor/Executor.scala | 26 +++--- .../apache/spark/executor/ExecutorSuite.scala | 92 +++++++++++++++---- 2 files changed, 88 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index dcec3ec21b546..c325222b764b8 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -480,6 +480,19 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { + case t: TaskKilledException => + logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") + setTaskFinishedAndClearInterruptStatus() + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) + + case _: InterruptedException | NonFatal(_) if + task != null && task.reasonIfKilled.isDefined => + val killReason = task.reasonIfKilled.getOrElse("unknown reason") + logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") + setTaskFinishedAndClearInterruptStatus() + execBackend.statusUpdate( + taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) + case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => val reason = task.context.fetchFailed.get.toTaskFailedReason if (!t.isInstanceOf[FetchFailedException]) { @@ -494,19 +507,6 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - case t: TaskKilledException => - logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - - case _: InterruptedException | NonFatal(_) if - task != null && task.reasonIfKilled.isDefined => - val killReason = task.reasonIfKilled.getOrElse("unknown reason") - logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate( - taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) - case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskCommitDeniedReason setTaskFinishedAndClearInterruptStatus() diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 105a178f2d94e..1a7bebe2c53cd 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -22,6 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler import java.nio.ByteBuffer import java.util.Properties import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.Map import scala.concurrent.duration._ @@ -139,7 +140,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug // the fetch failure. The executor should still tell the driver that the task failed due to a // fetch failure, not a generic exception from user code. val inputRDD = new FetchFailureThrowingRDD(sc) - val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false) val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() val task = new ResultTask( @@ -173,17 +174,48 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } test("SPARK-19276: OOMs correctly handled with a FetchFailure") { + val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true) + assert(failReason.isInstanceOf[ExceptionFailure]) + val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) + verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) + assert(exceptionCaptor.getAllValues.size === 1) + assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError]) + } + + test("SPARK-23816: interrupts are not masked by a FetchFailure") { + // If killing the task causes a fetch failure, we still treat it as a task that was killed, + // as the fetch failure could easily be caused by interrupting the thread. + val (failReason, _) = testFetchFailureHandling(false) + assert(failReason.isInstanceOf[TaskKilled]) + } + + /** + * Helper for testing some cases where a FetchFailure should *not* get sent back, because its + * superceded by another error, either an OOM or intentionally killing a task. + * @param oom if true, throw an OOM after the FetchFailure; else, interrupt the task after the + * FetchFailure + */ + private def testFetchFailureHandling( + oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = { // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it // may be a false positive. And we should call the uncaught exception handler. + // SPARK-23816 also handle interrupts the same way, as killing an obsolete speculative task + // does not represent a real fetch failure. val conf = new SparkConf().setMaster("local").setAppName("executor suite test") sc = new SparkContext(conf) val serializer = SparkEnv.get.closureSerializer.newInstance() val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size - // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat - // the fetch failure as a false positive, and just do normal OOM handling. + // Submit a job where a fetch failure is thrown, but then there is an OOM or interrupt. We + // should treat the fetch failure as a false positive, and do normal OOM or interrupt handling. val inputRDD = new FetchFailureThrowingRDD(sc) - val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true) + if (!oom) { + // we are trying to setup a case where a task is killed after a fetch failure -- this + // is just a helper to coordinate between the task thread and this thread that will + // kill the task + ExecutorSuiteHelper.latches = new ExecutorSuiteHelper() + } + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom) val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() val task = new ResultTask( @@ -200,15 +232,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val serTask = serializer.serialize(task) val taskDescription = createFakeTaskDescription(serTask) - val (failReason, uncaughtExceptionHandler) = - runTaskGetFailReasonAndExceptionHandler(taskDescription) - // make sure the task failure just looks like a OOM, not a fetch failure - assert(failReason.isInstanceOf[ExceptionFailure]) - val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) - verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) - assert(exceptionCaptor.getAllValues.size === 1) - assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError]) - } + runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom) + } test("Gracefully handle error in task deserialization") { val conf = new SparkConf @@ -257,22 +282,39 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { - runTaskGetFailReasonAndExceptionHandler(taskDescription)._1 + runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1 } private def runTaskGetFailReasonAndExceptionHandler( - taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = { + taskDescription: TaskDescription, + killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = { val mockBackend = mock[ExecutorBackend] val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler] var executor: Executor = null + val timedOut = new AtomicBoolean(false) try { executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true, uncaughtExceptionHandler = mockUncaughtExceptionHandler) // the task will be launched in a dedicated worker thread executor.launchTask(mockBackend, taskDescription) + if (killTask) { + val killingThread = new Thread("kill-task") { + override def run(): Unit = { + // wait to kill the task until it has thrown a fetch failure + if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) { + // now we can kill the task + executor.killAllTasks(true, "Killed task, eg. because of speculative execution") + } else { + timedOut.set(true) + } + } + } + killingThread.start() + } eventually(timeout(5.seconds), interval(10.milliseconds)) { assert(executor.numRunningTasks === 0) } + assert(!timedOut.get(), "timed out waiting to be ready to kill tasks") } finally { if (executor != null) { executor.stop() @@ -282,8 +324,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) orderedMock.verify(mockBackend) .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture()) + val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED orderedMock.verify(mockBackend) - .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture()) + .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture()) // first statusUpdate for RUNNING has empty data assert(statusCaptor.getAllValues().get(0).remaining() === 0) // second update is more interesting @@ -321,7 +364,8 @@ class SimplePartition extends Partition { class FetchFailureHidingRDD( sc: SparkContext, val input: FetchFailureThrowingRDD, - throwOOM: Boolean) extends RDD[Int](input) { + throwOOM: Boolean, + interrupt: Boolean) extends RDD[Int](input) { override def compute(split: Partition, context: TaskContext): Iterator[Int] = { val inItr = input.compute(split, context) try { @@ -330,6 +374,15 @@ class FetchFailureHidingRDD( case t: Throwable => if (throwOOM) { throw new OutOfMemoryError("OOM while handling another exception") + } else if (interrupt) { + // make sure our test is setup correctly + assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined) + // signal our test is ready for the task to get killed + ExecutorSuiteHelper.latches.latch1.countDown() + // then wait for another thread in the test to kill the task -- this latch + // is never actually decremented, we just wait to get killed. + ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS) + throw new IllegalStateException("timed out waiting to be interrupted") } else { throw new RuntimeException("User Exception that hides the original exception", t) } @@ -352,6 +405,11 @@ private class ExecutorSuiteHelper { @volatile var testFailedReason: TaskFailedReason = _ } +// helper for coordinating killing tasks +private object ExecutorSuiteHelper { + var latches: ExecutorSuiteHelper = null +} + private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable { def writeExternal(out: ObjectOutput): Unit = {} def readExternal(in: ObjectInput): Unit = { From 7c1654e2159662e7e663ba141719d755002f770a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 9 Apr 2018 11:54:35 -0700 Subject: [PATCH 587/774] [SPARK-22856][SQL] Add wrappers for codegen output and nullability ## What changes were proposed in this pull request? The codegen output of `Expression`, aka `ExprCode`, now encapsulates only strings of output value (`value`) and nullability (`isNull`). It makes difficulty for us to know what the output really is. I think it is better if we can add wrappers for the value and nullability that let us to easily know that. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #20043 from viirya/SPARK-22856. --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../sql/catalyst/expressions/Expression.scala | 16 ++-- .../MonotonicallyIncreasingID.scala | 4 +- .../expressions/SparkPartitionID.scala | 4 +- .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 16 ++-- .../expressions/codegen/CodegenFallback.scala | 2 +- .../expressions/codegen/ExprValue.scala | 76 +++++++++++++++++++ .../codegen/GenerateMutableProjection.scala | 6 +- .../codegen/GenerateSafeProjection.scala | 19 +++-- .../codegen/GenerateUnsafeProjection.scala | 8 +- .../expressions/collectionOperations.scala | 4 +- .../expressions/complexTypeCreator.scala | 8 +- .../expressions/conditionalExpressions.scala | 3 +- .../expressions/datetimeExpressions.scala | 3 +- .../sql/catalyst/expressions/generators.scala | 4 +- .../spark/sql/catalyst/expressions/hash.scala | 4 +- .../catalyst/expressions/inputFileBlock.scala | 8 +- .../sql/catalyst/expressions/literals.scala | 24 +++--- .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../expressions/nullExpressions.scala | 25 ++++-- .../expressions/objects/objects.scala | 28 ++++--- .../sql/catalyst/expressions/predicates.scala | 10 +-- .../expressions/randomExpressions.scala | 6 +- .../expressions/CodeGenerationSuite.scala | 6 +- .../expressions/codegen/ExprValueSuite.scala | 46 +++++++++++ .../sql/execution/ColumnarBatchScan.scala | 10 ++- .../spark/sql/execution/ExpandExec.scala | 5 +- .../spark/sql/execution/GenerateExec.scala | 13 ++-- .../sql/execution/WholeStageCodegenExec.scala | 13 ++-- .../aggregate/HashAggregateExec.scala | 3 +- .../aggregate/HashMapGenerator.scala | 5 +- .../execution/basicPhysicalOperators.scala | 6 +- .../joins/BroadcastHashJoinExec.scala | 6 +- .../execution/joins/SortMergeJoinExec.scala | 8 +- 35 files changed, 294 insertions(+), 120 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 89ffbb0016916..5021a567592e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types._ /** @@ -76,7 +76,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false") + ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 38caf67d465d8..7a5e49cb5206b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -104,7 +104,9 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val eval = doGenCode(ctx, ExprCode("", isNull, value)) + val eval = doGenCode(ctx, ExprCode("", + VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(dataType)))) reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -118,10 +120,10 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { - val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { + val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull - eval.isNull = globalIsNull + eval.isNull = GlobalValue(globalIsNull, CodeGenerator.JAVA_BOOLEAN) s"$globalIsNull = $localIsNull;" } else { "" @@ -140,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] { |} """.stripMargin) - eval.value = newValue + eval.value = VariableValue(newValue, javaType) eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -446,7 +448,7 @@ abstract class UnaryExpression extends Expression { boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = FalseLiteral) } } } @@ -546,7 +548,7 @@ abstract class BinaryExpression extends Expression { ${leftGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = FalseLiteral) } } } @@ -690,7 +692,7 @@ abstract class TernaryExpression extends Expression { ${midGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index dd523d312e3b4..ad1e7bdb31987 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, LongType} /** @@ -73,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { ev.copy(code = s""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; - $countTerm++;""", isNull = "false") + $countTerm++;""", isNull = FalseLiteral) } override def prettyName: String = "monotonically_increasing_id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index cc6a769d032d3..787bcaf5e81de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -47,6 +47,6 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", - isNull = "false") + isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 508bdd5050b54..478ff3a7c1011 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -601,7 +601,8 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), + CodeGenerator.JAVA_BOOLEAN) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -680,7 +681,8 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), + CodeGenerator.JAVA_BOOLEAN) val evals = evalChildren.map(eval => s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 84b1e3fbda876..c9c60ef1be640 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -56,16 +56,17 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class ExprCode(var code: String, var isNull: String, var value: String) +case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) object ExprCode { def forNullValue(dataType: DataType): ExprCode = { val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true) - ExprCode(code = "", isNull = "true", value = defaultValueLiteral) + ExprCode(code = "", isNull = TrueLiteral, + value = LiteralValue(defaultValueLiteral, CodeGenerator.javaType(dataType))) } - def forNonNullValue(value: String): ExprCode = { - ExprCode(code = "", isNull = "false", value = value) + def forNonNullValue(value: ExprValue): ExprCode = { + ExprCode(code = "", isNull = FalseLiteral, value = value) } } @@ -77,7 +78,7 @@ object ExprCode { * @param value A term for a value of a common sub-expression. Not valid if `isNull` * is set to `true`. */ -case class SubExprEliminationState(isNull: String, value: String) +case class SubExprEliminationState(isNull: ExprValue, value: ExprValue) /** * Codes and common subexpressions mapping used for subexpression elimination. @@ -330,7 +331,7 @@ class CodegenContext { case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } - ExprCode(code, "false", value) + ExprCode(code, FalseLiteral, GlobalValue(value, javaType(dataType))) } def declareMutableStates(): String = { @@ -1003,7 +1004,8 @@ class CodegenContext { // at least two nodes) as the cost of doing it is expected to be low. subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(isNull, value) + val state = SubExprEliminationState(GlobalValue(isNull, JAVA_BOOLEAN), + GlobalValue(value, javaType(expr.dataType))) subExprEliminationExprs ++= e.map(_ -> state).toMap } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index e12420bb5dfdd..a91989e129664 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -59,7 +59,7 @@ trait CodegenFallback extends Expression { $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; - """, isNull = "false") + """, isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala new file mode 100644 index 0000000000000..df5f1c58b1b2d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.language.implicitConversions + +import org.apache.spark.sql.types.DataType + +// An abstraction that represents the evaluation result of [[ExprCode]]. +abstract class ExprValue { + + val javaType: String + + // Whether we can directly access the evaluation value anywhere. + // For example, a variable created outside a method can not be accessed inside the method. + // For such cases, we may need to pass the evaluation as parameter. + val canDirectAccess: Boolean + + def isPrimitive: Boolean = CodeGenerator.isPrimitiveType(javaType) +} + +object ExprValue { + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString +} + +// A literal evaluation of [[ExprCode]]. +class LiteralValue(val value: String, val javaType: String) extends ExprValue { + override def toString: String = value + override val canDirectAccess: Boolean = true +} + +object LiteralValue { + def apply(value: String, javaType: String): LiteralValue = new LiteralValue(value, javaType) + def unapply(literal: LiteralValue): Option[(String, String)] = + Some((literal.value, literal.javaType)) +} + +// A variable evaluation of [[ExprCode]]. +case class VariableValue( + val variableName: String, + val javaType: String) extends ExprValue { + override def toString: String = variableName + override val canDirectAccess: Boolean = false +} + +// A statement evaluation of [[ExprCode]]. +case class StatementValue( + val statement: String, + val javaType: String, + val canDirectAccess: Boolean = false) extends ExprValue { + override def toString: String = statement +} + +// A global variable evaluation of [[ExprCode]]. +case class GlobalValue(val value: String, val javaType: String) extends ExprValue { + override def toString: String = value + override val canDirectAccess: Boolean = true +} + +case object TrueLiteral extends LiteralValue("true", "boolean") +case object FalseLiteral extends LiteralValue("false", "boolean") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d35fd8ecb4d63..3ae0b54c754cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -59,7 +59,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) - val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { + val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value") @@ -69,7 +69,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP |${ev.code} |$isNull = ${ev.isNull}; |$value = ${ev.value}; - """.stripMargin, isNull, value, i) + """.stripMargin, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), value, i) } else { (s""" |${ev.code} @@ -83,7 +83,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(projectionCodes).map { case (e, (_, isNull, value, i)) => - val ev = ExprCode("", isNull, value) + val ev = ExprCode("", isNull, GlobalValue(value, CodeGenerator.javaType(e.dataType))) CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f92f70ee71fef..a30a0b22cd305 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -53,7 +53,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, CodeGenerator.getValue(tmpInput, dt, i.toString), dt) + val converter = convertToSafe(ctx, + StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), + CodeGenerator.javaType(dt)), dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -74,7 +76,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] |final InternalRow $output = new $rowClass($values); """.stripMargin - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, VariableValue(output, "InternalRow")) } private def createCodeForArray( @@ -89,8 +91,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val index = ctx.freshName("index") val arrayClass = classOf[GenericArrayData].getName - val elementConverter = convertToSafe( - ctx, CodeGenerator.getValue(tmpInput, elementType, index), elementType) + val elementConverter = convertToSafe(ctx, + StatementValue(CodeGenerator.getValue(tmpInput, elementType, index), + CodeGenerator.javaType(elementType)), elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -104,7 +107,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, VariableValue(output, "ArrayData")) } private def createCodeForMap( @@ -125,19 +128,19 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, VariableValue(output, "MapData")) } @tailrec private def convertToSafe( ctx: CodegenContext, - input: String, + input: ExprValue, dataType: DataType): ExprCode = dataType match { case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) - case _ => ExprCode("", "false", input) + case _ => ExprCode("", FalseLiteral, input) } protected def create(expressions: Seq[Expression]): Projection = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index ab2254cd9f70a..4a4d76313a543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,7 +52,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) + ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", CodeGenerator.JAVA_BOOLEAN), + StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), + CodeGenerator.javaType(dt))) } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -334,7 +336,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $evalSubexpr $writeExpressions """ - ExprCode(code, "false", s"$rowWriter.getRow()") + // `rowWriter` is declared as a class field, so we can access it directly in methods. + ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow", + canDirectAccess = true)) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index beb84694c44e8..91188da8b0bd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,7 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -55,7 +55,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : - (${childGen.value}).numElements();""", isNull = "false") + (${childGen.value}).numElements();""", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 85facdad43db7..49a8d12057188 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -64,8 +64,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( code = preprocess + assigns + postprocess, - value = arrayData, - isNull = "false") + value = VariableValue(arrayData, CodeGenerator.javaType(dataType)), + isNull = FalseLiteral) } override def prettyName: String = "array" @@ -378,7 +378,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); |$values = null; - """.stripMargin, isNull = "false") + """.stripMargin, isNull = FalseLiteral) } override def prettyName: String = "named_struct" @@ -394,7 +394,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) - ExprCode(code = eval.code, isNull = "false", value = eval.value) + ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.value) } override def prettyName: String = "named_struct_unsafe" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index f4e9619bac59d..409c0b6b79b81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -191,7 +191,8 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value) + ev.value = GlobalValue(ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value), + CodeGenerator.javaType(dataType)) // these blocks are meant to be inside a // do { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 1ae4e5a2f716b..49dd988b4b53c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -813,7 +813,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ val df = classOf[DateFormat].getName if (format.foldable) { if (formatter == null) { - ExprCode("", "true", "(UTF8String) null") + ExprCode("", TrueLiteral, LiteralValue("(UTF8String) null", + CodeGenerator.javaType(dataType))) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 4f4d49166e88c..3af4bfebad45e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -218,7 +218,7 @@ case class Stack(children: Seq[Expression]) extends Generator { s""" |$code |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); - """.stripMargin, isNull = "false") + """.stripMargin, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index b76b64ab5096f..df29c38d64d3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -270,7 +270,7 @@ abstract class HashExpression[E] extends Expression { protected def computeHash(value: Any, dataType: DataType, seed: E): E override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = "false" + ev.isNull = FalseLiteral val childrenHash = children.map { child => val childGen = child.genCode(ctx) @@ -644,7 +644,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = "false" + ev.isNull = FalseLiteral val childHash = ctx.freshName("childHash") val childrenHash = children.map { child => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 07785e7448586..2a3cc580273ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -43,7 +43,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = "false") + s"$className.getInputFilePath();", isNull = FalseLiteral) } } @@ -66,7 +66,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = "false") + s"$className.getStartOffset();", isNull = FalseLiteral) } } @@ -89,6 +89,6 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = "false") + s"$className.getLength();", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 7395609a04ba5..742a650eb445d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -283,36 +283,36 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { } else { dataType match { case BooleanType | IntegerType | DateType => - ExprCode.forNonNullValue(value.toString) + ExprCode.forNonNullValue(LiteralValue(value.toString, javaType)) case FloatType => value.asInstanceOf[Float] match { case v if v.isNaN => - ExprCode.forNonNullValue("Float.NaN") + ExprCode.forNonNullValue(LiteralValue("Float.NaN", javaType)) case Float.PositiveInfinity => - ExprCode.forNonNullValue("Float.POSITIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Float.POSITIVE_INFINITY", javaType)) case Float.NegativeInfinity => - ExprCode.forNonNullValue("Float.NEGATIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Float.NEGATIVE_INFINITY", javaType)) case _ => - ExprCode.forNonNullValue(s"${value}F") + ExprCode.forNonNullValue(LiteralValue(s"${value}F", javaType)) } case DoubleType => value.asInstanceOf[Double] match { case v if v.isNaN => - ExprCode.forNonNullValue("Double.NaN") + ExprCode.forNonNullValue(LiteralValue("Double.NaN", javaType)) case Double.PositiveInfinity => - ExprCode.forNonNullValue("Double.POSITIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Double.POSITIVE_INFINITY", javaType)) case Double.NegativeInfinity => - ExprCode.forNonNullValue("Double.NEGATIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Double.NEGATIVE_INFINITY", javaType)) case _ => - ExprCode.forNonNullValue(s"${value}D") + ExprCode.forNonNullValue(LiteralValue(s"${value}D", javaType)) } case ByteType | ShortType => - ExprCode.forNonNullValue(s"($javaType)$value") + ExprCode.forNonNullValue(LiteralValue(s"($javaType)$value", javaType)) case TimestampType | LongType => - ExprCode.forNonNullValue(s"${value}L") + ExprCode.forNonNullValue(LiteralValue(s"${value}L", javaType)) case _ => val constRef = ctx.addReferenceObj("literal", value, javaType) - ExprCode.forNonNullValue(constRef) + ExprCode.forNonNullValue(GlobalValue(constRef, javaType)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index a390f8ef7fd9a..7081a5e096d56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -91,7 +91,8 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); - |}""".stripMargin, isNull = "true", value = "null") + |}""".stripMargin, isNull = TrueLiteral, + value = LiteralValue("null", CodeGenerator.javaType(dataType))) } override def sql: String = s"assert_true(${child.sql})" @@ -150,7 +151,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + s"${randomSeed.get}L + partitionIndex);") ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", - isNull = "false") + isNull = FalseLiteral) } override def freshCopy(): Uuid = Uuid(randomSeed) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index b35fa72e95d1e..55b6e346be82a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -72,7 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), + CodeGenerator.JAVA_BOOLEAN) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -235,7 +236,7 @@ case class IsNaN(child: Expression) extends UnaryExpression ev.copy(code = s""" ${eval.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false") + ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) } } } @@ -320,7 +321,12 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ExprCode(code = eval.code, isNull = "false", value = eval.isNull) + val value = if (eval.isNull.isInstanceOf[LiteralValue]) { + LiteralValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) + } else { + VariableValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) + } + ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } override def sql: String = s"(${child.sql} IS NULL)" @@ -346,7 +352,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ExprCode(code = eval.code, isNull = "false", value = s"(!(${eval.isNull}))") + val value = if (eval.isNull == TrueLiteral) { + FalseLiteral + } else if (eval.isNull == FalseLiteral) { + TrueLiteral + } else { + StatementValue(s"(!(${eval.isNull}))", CodeGenerator.javaType(dataType)) + } + ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } override def sql: String = s"(${child.sql} IS NOT NULL)" @@ -441,6 +454,6 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate | $codes |} while (false); |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; - """.stripMargin, isNull = "false") + """.stripMargin, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9252425f86473..b2cca3178cd2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -61,13 +61,13 @@ trait InvokeLike extends Expression with NonSQLExpression { * @param ctx a [[CodegenContext]] * @return (code to prepare arguments, argument string, result of argument null check) */ - def prepareArguments(ctx: CodegenContext): (String, String, String) = { + def prepareArguments(ctx: CodegenContext): (String, String, ExprValue) = { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") - resultIsNull + GlobalValue(resultIsNull, CodeGenerator.JAVA_BOOLEAN) } else { - "false" + FalseLiteral } val argValues = arguments.map { e => val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue") @@ -244,7 +244,7 @@ case class StaticInvoke( val prepareIsNull = if (nullable) { s"boolean ${ev.isNull} = $resultIsNull;" } else { - ev.isNull = "false" + ev.isNull = FalseLiteral "" } @@ -546,7 +546,7 @@ case class WrapOption(child: Expression, optType: DataType) ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); """ - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } @@ -568,7 +568,13 @@ case class LambdaVariable( } override def genCode(ctx: CodegenContext): ExprCode = { - ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false") + val isNullValue = if (nullable) { + VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) + } else { + FalseLiteral + } + ExprCode(code = "", value = VariableValue(value, CodeGenerator.javaType(dataType)), + isNull = isNullValue) } // This won't be called as `genCode` is overrided, just overriding it to make @@ -840,7 +846,7 @@ case class MapObjects private( // Make a copy of the data if it's unsafe-backed def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value" - val genFunctionValue = lambdaFunction.dataType match { + val genFunctionValue: String = lambdaFunction.dataType match { case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) @@ -1343,7 +1349,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); """.stripMargin - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } @@ -1538,7 +1544,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) throw new NullPointerException($errMsgField); } """ - ev.copy(code = code, isNull = "false", value = childGen.value) + ev.copy(code = code, isNull = FalseLiteral, value = childGen.value) } } @@ -1589,7 +1595,7 @@ case class GetExternalRowField( final Object ${ev.value} = ${row.value}.get($index); """ - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 4b85d9adbe311..e195ec17f3bcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -405,7 +405,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with if (${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - }""", isNull = "false") + }""", isNull = FalseLiteral) } else { ev.copy(code = s""" ${eval1.code} @@ -461,7 +461,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.isNull = "false" + ev.isNull = FalseLiteral ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = true; @@ -469,7 +469,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P if (!${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - }""", isNull = "false") + }""", isNull = FalseLiteral) } else { ev.copy(code = s""" ${eval1.code} @@ -615,7 +615,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) ev.copy(code = eval1.code + eval2.code + s""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || - (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = "false") + (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index f36633867316e..70186053617f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -83,7 +83,7 @@ case class Rand(child: Expression) extends RDG { s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", - isNull = "false") + isNull = FalseLiteral) } override def freshCopy(): Rand = Rand(child) @@ -120,7 +120,7 @@ case class Randn(child: Expression) extends RDG { s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", - isNull = "false") + isNull = FalseLiteral) } override def freshCopy(): Randn = Randn(child) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 398b6767654fa..8e83b35c3809c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -448,6 +448,8 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val ref = BoundReference(0, IntegerType, true) val add1 = Add(ref, ref) val add2 = Add(add1, add1) + val dummy = SubExprEliminationState(VariableValue("dummy", "boolean"), + VariableValue("dummy", "boolean")) // raw testing of basic functionality { @@ -457,7 +459,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value) assert(ctx.subExprEliminationExprs.contains(ref)) // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(add1 -> SubExprEliminationState("dummy", "dummy"))) { + ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) { assert(ctx.subExprEliminationExprs.contains(add1)) assert(!ctx.subExprEliminationExprs.contains(ref)) Seq.empty @@ -475,7 +477,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE assert(ctx.subExprEliminationExprs.contains(add1)) // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(ref -> SubExprEliminationState("dummy", "dummy"))) { + ctx.withSubExprEliminationExprs(Map(ref -> dummy)) { assert(ctx.subExprEliminationExprs.contains(ref)) assert(!ctx.subExprEliminationExprs.contains(add1)) Seq.empty diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala new file mode 100644 index 0000000000000..c8f4cff7db48d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite + +class ExprValueSuite extends SparkFunSuite { + + test("TrueLiteral and FalseLiteral should be LiteralValue") { + val trueLit = TrueLiteral + val falseLit = FalseLiteral + + assert(trueLit.value == "true") + assert(falseLit.value == "false") + + assert(trueLit.isPrimitive) + assert(falseLit.isPrimitive) + + trueLit match { + case LiteralValue(value, javaType) => + assert(value == "true" && javaType == "boolean") + case _ => fail() + } + + falseLit match { + case LiteralValue(value, javaType) => + assert(value == "false" && javaType == "boolean") + case _ => fail() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 392906a022903..434214a10e1e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -51,7 +51,11 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { nullable: Boolean): ExprCode = { val javaType = CodeGenerator.javaType(dataType) val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) - val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } + val isNullVar = if (nullable) { + VariableValue(ctx.freshName("isNull"), CodeGenerator.JAVA_BOOLEAN) + } else { + FalseLiteral + } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { @@ -62,7 +66,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } else { s"$javaType $valueVar = $value;" }).trim - ExprCode(code, isNullVar, valueVar) + ExprCode(code, isNullVar, VariableValue(valueVar, javaType)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 12ae1ea4a7c13..0d9a62cace62a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, VariableValue} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -157,7 +157,8 @@ case class ExpandExec( |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(firstExpr.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 384f0398a1ec0..85c5ebfdaa689 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} @@ -170,9 +170,10 @@ case class GenerateExec( // Add position val position = if (e.position) { if (outer) { - Seq(ExprCode("", s"$index == -1", index)) + Seq(ExprCode("", StatementValue(s"$index == -1", CodeGenerator.JAVA_BOOLEAN), + VariableValue(index, CodeGenerator.JAVA_INT))) } else { - Seq(ExprCode("", "false", index)) + Seq(ExprCode("", FalseLiteral, VariableValue(index, CodeGenerator.JAVA_INT))) } } else { Seq.empty @@ -315,9 +316,11 @@ case class GenerateExec( |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, javaType)) } else { - ExprCode(s"$javaType $value = $getter;", "false", value) + ExprCode(s"$javaType $value = $getter;", FalseLiteral, + VariableValue(value, javaType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 6ddaacfee1a40..805ff3cf001ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -111,7 +111,7 @@ trait CodegenSupport extends SparkPlan { private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { if (row != null) { - ExprCode("", "false", row) + ExprCode("", FalseLiteral, VariableValue(row, "UnsafeRow")) } else { if (colVars.nonEmpty) { val colExprs = output.zipWithIndex.map { case (attr, i) => @@ -126,10 +126,10 @@ trait CodegenSupport extends SparkPlan { |$evaluateInputs |${ev.code.trim} """.stripMargin.trim - ExprCode(code, "false", ev.value) + ExprCode(code, FalseLiteral, ev.value) } else { // There is no columns - ExprCode("", "false", "unsafeRow") + ExprCode("", FalseLiteral, VariableValue("unsafeRow", "UnsafeRow")) } } } @@ -241,15 +241,16 @@ trait CodegenSupport extends SparkPlan { parameters += s"$paramType $paramName" val paramIsNull = if (!attributes(i).nullable) { // Use constant `false` without passing `isNull` for non-nullable variable. - "false" + FalseLiteral } else { val isNull = ctx.freshName(s"exprIsNull_$i") arguments += ev.isNull parameters += s"boolean $isNull" - isNull + VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) } - paramVars += ExprCode("", paramIsNull, paramName) + paramVars += ExprCode("", paramIsNull, + VariableValue(paramName, CodeGenerator.javaType(attributes(i).dataType))) } (arguments, parameters, paramVars) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 1926e9373bc55..8f7f10243d4cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -194,7 +194,8 @@ case class HashAggregateExec( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), + GlobalValue(value, CodeGenerator.javaType(e.dataType))) } val initBufVar = evaluateVariables(bufVars) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 6b60b414ffe5f..4978954271311 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GlobalValue} import org.apache.spark.sql.types._ /** @@ -54,7 +54,8 @@ abstract class HashMapGenerator( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), + GlobalValue(value, CodeGenerator.javaType(e.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 4707022f74547..cab7081400ce9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, ExpressionCanonicalizer} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType @@ -192,7 +192,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) // generate better code (remove dead branches). val resultVars = input.zipWithIndex.map { case (ev, i) => if (notNullAttributes.contains(child.output(i).exprId)) { - ev.isNull = "false" + ev.isNull = FalseLiteral } ev } @@ -368,7 +368,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") val value = ctx.freshName("value") - val ev = ExprCode("", "false", value) + val ev = ExprCode("", FalseLiteral, VariableValue(value, CodeGenerator.JAVA_LONG)) val BigInt = classOf[java.math.BigInteger].getName // Inline mutable state since not many Range operations in a task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 487d6a2383318..fa62a32d51f3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -192,7 +192,8 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(a.dataType))) } } } @@ -487,7 +488,8 @@ case class BroadcastHashJoinExec( s"$existsVar = true;" } - val resultVar = input ++ Seq(ExprCode("", "false", existsVar)) + val resultVar = input ++ Seq(ExprCode("", FalseLiteral, + VariableValue(existsVar, CodeGenerator.JAVA_BOOLEAN))) if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 5a511b30e4fd9..b61acb8d4fda9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, @@ -531,11 +531,13 @@ case class SortMergeJoinExec( |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin - (ExprCode(code, isNull, value), leftVarsDecl) + (ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) } else { val code = s"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" - (ExprCode(code, "false", value), leftVarsDecl) + (ExprCode(code, FalseLiteral, + VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) } }.unzip } From 252468a744b95082400ba9e8b2e3b3d9d50ab7fa Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 9 Apr 2018 12:18:07 -0700 Subject: [PATCH 588/774] [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes ## What changes were proposed in this pull request? API: ``` trait ClassificationNode extends Node def getLabelCount(label: Int): Double trait RegressionNode extends Node def getCount(): Double def getSum(): Double def getSquareSum(): Double // turn LeafNode to be trait trait LeafNode extends Node { def prediction: Double def impurity: Double ... } class ClassificationLeafNode extends ClassificationNode with LeafNode class RegressionLeafNode extends RegressionNode with LeafNode // turn InternalNode to be trait trait InternalNode extends Node{ def gain: Double def leftChild: Node def rightChild: Node def split: Split ... } class ClassificationInternalNode extends ClassificationNode with InternalNode override def leftChild: ClassificationNode override def rightChild: ClassificationNode class RegressionInternalNode extends RegressionNode with InternalNode override val leftChild: RegressionNode override val rightChild: RegressionNode class DecisionTreeClassificationModel override val rootNode: ClassificationNode class DecisionTreeRegressionModel override val rootNode: RegressionNode ``` Closes #17466 ## How was this patch tested? UT will be added soon. Author: WeichenXu Author: jkbradley Closes #20786 from WeichenXu123/tree_stat_api_2. --- .../DecisionTreeClassifier.scala | 14 +- .../ml/classification/GBTClassifier.scala | 6 +- .../RandomForestClassifier.scala | 6 +- .../ml/regression/DecisionTreeRegressor.scala | 13 +- .../spark/ml/regression/GBTRegressor.scala | 6 +- .../ml/regression/RandomForestRegressor.scala | 6 +- .../scala/org/apache/spark/ml/tree/Node.scala | 247 ++++++++++++++---- .../spark/ml/tree/impl/RandomForest.scala | 10 +- .../org/apache/spark/ml/tree/treeModels.scala | 36 ++- .../DecisionTreeClassifierSuite.scala | 31 ++- .../classification/GBTClassifierSuite.scala | 4 +- .../RandomForestClassifierSuite.scala | 5 +- .../DecisionTreeRegressorSuite.scala | 14 + .../ml/tree/impl/RandomForestSuite.scala | 22 +- .../apache/spark/ml/tree/impl/TreeTests.scala | 12 +- project/MimaExcludes.scala | 9 +- 16 files changed, 333 insertions(+), 108 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 65cce697d8202..771cd4fe91dcf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -165,7 +165,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi @Since("1.4.0") class DecisionTreeClassificationModel private[ml] ( @Since("1.4.0")override val uid: String, - @Since("1.4.0")override val rootNode: Node, + @Since("1.4.0")override val rootNode: ClassificationNode, @Since("1.6.0")override val numFeatures: Int, @Since("1.5.0")override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] @@ -178,7 +178,7 @@ class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = + private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) override def predict(features: Vector): Double = { @@ -276,8 +276,9 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] - val root = loadTreeNodes(path, metadata, sparkSession) - val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses) + val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true) + val model = new DecisionTreeClassificationModel(metadata.uid, + root.asInstanceOf[ClassificationNode], numFeatures, numClasses) DefaultParamsReader.getAndSetParams(model, metadata) model } @@ -292,9 +293,10 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica require(oldModel.algo == OldAlgo.Classification, s"Cannot convert non-classification DecisionTreeModel (old API) to" + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") - val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = true) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") // Can't infer number of features from old model, so default to -1 - new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1) + new DecisionTreeClassificationModel(uid, + rootNode.asInstanceOf[ClassificationNode], numFeatures, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index cd44489f618b2..c0255103bc313 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -371,14 +371,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { override def load(path: String): GBTClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int] val numTrees = (metadata.metadata \ numTreesKey).extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 78a4972adbdbb..bb972e9706fc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -310,15 +310,15 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica override def load(path: String): RandomForestClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, true) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeClassificationModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses) + val tree = new DecisionTreeClassificationModel(treeMetadata.uid, + root.asInstanceOf[ClassificationNode], numFeatures, numClasses) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index ad154fcd010cc..5cef5c9f21f1e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -160,7 +160,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor @Since("1.4.0") class DecisionTreeRegressionModel private[ml] ( override val uid: String, - override val rootNode: Node, + override val rootNode: RegressionNode, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable { @@ -175,7 +175,7 @@ class DecisionTreeRegressionModel private[ml] ( * Construct a decision tree regression model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: Node, numFeatures: Int) = + private[ml] def this(rootNode: RegressionNode, numFeatures: Int) = this(Identifiable.randomUID("dtr"), rootNode, numFeatures) override def predict(features: Vector): Double = { @@ -279,8 +279,9 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] - val root = loadTreeNodes(path, metadata, sparkSession) - val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures) + val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false) + val model = new DecisionTreeRegressionModel(metadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(model, metadata) model } @@ -295,8 +296,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode require(oldModel.algo == OldAlgo.Regression, s"Cannot convert non-regression DecisionTreeModel (old API) to" + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") - val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = false) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr") - new DecisionTreeRegressionModel(uid, rootNode, numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.asInstanceOf[RegressionNode], numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 6569ff2a5bfc1..834aaa0e362d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -302,15 +302,15 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { override def load(path: String): GBTRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 2d594460c2475..7f77398ba2a22 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -269,13 +269,13 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode override def load(path: String): RandomForestRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d30be452a436e..0242bc76698d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -17,14 +17,16 @@ package org.apache.spark.ml.tree +import org.apache.spark.annotation.Since import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} +import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, + Node => OldNode, Predict => OldPredict} /** * Decision tree node interface. */ -sealed abstract class Node extends Serializable { +sealed trait Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree // code into the new API and deprecate the old API. SPARK-3727 @@ -84,35 +86,86 @@ private[ml] object Node { /** * Create a new Node from the old Node format, recursively creating child nodes as needed. */ - def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = { + def fromOld( + oldNode: OldNode, + categoricalFeatures: Map[Int, Int], + isClassification: Boolean): Node = { if (oldNode.isLeaf) { // TODO: Once the implementation has been moved to this API, then include sufficient // statistics here. - new LeafNode(prediction = oldNode.predict.predict, - impurity = oldNode.impurity, impurityStats = null) + if (isClassification) { + new ClassificationLeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) + } else { + new RegressionLeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) + } } else { val gain = if (oldNode.stats.nonEmpty) { oldNode.stats.get.gain } else { 0.0 } - new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, - gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), - rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), - split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + if (isClassification) { + new ClassificationInternalNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, gain = gain, + leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, true) + .asInstanceOf[ClassificationNode], + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, true) + .asInstanceOf[ClassificationNode], + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + } else { + new RegressionInternalNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, gain = gain, + leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, false) + .asInstanceOf[RegressionNode], + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, false) + .asInstanceOf[RegressionNode], + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + } } } } -/** - * Decision tree leaf node. - * @param prediction Prediction this node makes - * @param impurity Impurity measure at this node (for training data) - */ -class LeafNode private[ml] ( - override val prediction: Double, - override val impurity: Double, - override private[ml] val impurityStats: ImpurityCalculator) extends Node { +@Since("2.4.0") +sealed trait ClassificationNode extends Node { + + /** + * Get count of training examples for specified label in this node + * @param label label number in the range [0, numClasses) + */ + @Since("2.4.0") + def getLabelCount(label: Int): Double = { + require(label >= 0 && label < impurityStats.stats.length, + "label should be in the range between 0 (inclusive) " + + s"and ${impurityStats.stats.length} (exclusive).") + impurityStats.stats(label) + } +} + +@Since("2.4.0") +sealed trait RegressionNode extends Node { + + /** Number of training data points in this node */ + @Since("2.4.0") + def getCount: Double = impurityStats.stats(0) + + /** Sum over training data points of the labels in this node */ + @Since("2.4.0") + def getSum: Double = impurityStats.stats(1) + + /** Sum over training data points of the square of the labels in this node */ + @Since("2.4.0") + def getSumOfSquares: Double = impurityStats.stats(2) +} + +@Since("2.4.0") +sealed trait LeafNode extends Node { + + /** Prediction this node makes. */ + def prediction: Double + + def impurity: Double override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" @@ -135,32 +188,58 @@ class LeafNode private[ml] ( override private[ml] def maxSplitFeatureIndex(): Int = -1 +} + +/** + * Decision tree leaf node for classification. + */ +@Since("2.4.0") +class ClassificationLeafNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override private[ml] val impurityStats: ImpurityCalculator) + extends ClassificationNode with LeafNode { + override private[tree] def deepCopy(): Node = { - new LeafNode(prediction, impurity, impurityStats) + new ClassificationLeafNode(prediction, impurity, impurityStats) } } /** - * Internal Decision Tree node. - * @param prediction Prediction this node would make if it were a leaf node - * @param impurity Impurity measure at this node (for training data) - * @param gain Information gain value. Values less than 0 indicate missing values; - * this quirk will be removed with future updates. - * @param leftChild Left-hand child node - * @param rightChild Right-hand child node - * @param split Information about the test used to split to the left or right child. + * Decision tree leaf node for regression. */ -class InternalNode private[ml] ( +@Since("2.4.0") +class RegressionLeafNode private[ml] ( override val prediction: Double, override val impurity: Double, - val gain: Double, - val leftChild: Node, - val rightChild: Node, - val split: Split, - override private[ml] val impurityStats: ImpurityCalculator) extends Node { + override private[ml] val impurityStats: ImpurityCalculator) + extends RegressionNode with LeafNode { - // Note to developers: The constructor argument impurityStats should be reconsidered before we - // make the constructor public. We may be able to improve the representation. + override private[tree] def deepCopy(): Node = { + new RegressionLeafNode(prediction, impurity, impurityStats) + } +} + +/** + * Internal Decision Tree node. + */ +@Since("2.4.0") +sealed trait InternalNode extends Node { + + /** + * Information gain value. Values less than 0 indicate missing values; + * this quirk will be removed with future updates. + */ + def gain: Double + + /** Left-hand child node */ + def leftChild: Node + + /** Right-hand child node */ + def rightChild: Node + + /** Information about the test used to split to the left or right child. */ + def split: Split override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" @@ -205,11 +284,6 @@ class InternalNode private[ml] ( math.max(split.featureIndex, math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex())) } - - override private[tree] def deepCopy(): Node = { - new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(), - split, impurityStats) - } } private object InternalNode { @@ -240,6 +314,57 @@ private object InternalNode { } } +/** + * Internal Decision Tree node for regression. + */ +@Since("2.4.0") +class ClassificationInternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override val gain: Double, + override val leftChild: ClassificationNode, + override val rightChild: ClassificationNode, + override val split: Split, + override private[ml] val impurityStats: ImpurityCalculator) + extends ClassificationNode with InternalNode { + + // Note to developers: The constructor argument impurityStats should be reconsidered before we + // make the constructor public. We may be able to improve the representation. + + override private[tree] def deepCopy(): Node = { + new ClassificationInternalNode(prediction, impurity, gain, + leftChild.deepCopy().asInstanceOf[ClassificationNode], + rightChild.deepCopy().asInstanceOf[ClassificationNode], + split, impurityStats) + } +} + +/** + * Internal Decision Tree node for regression. + */ +@Since("2.4.0") +class RegressionInternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override val gain: Double, + override val leftChild: RegressionNode, + override val rightChild: RegressionNode, + override val split: Split, + override private[ml] val impurityStats: ImpurityCalculator) + extends RegressionNode with InternalNode { + + // Note to developers: The constructor argument impurityStats should be reconsidered before we + // make the constructor public. We may be able to improve the representation. + + override private[tree] def deepCopy(): Node = { + new RegressionInternalNode(prediction, impurity, gain, + leftChild.deepCopy().asInstanceOf[RegressionNode], + rightChild.deepCopy().asInstanceOf[RegressionNode], + split, impurityStats) + } +} + + /** * Version of a node used in learning. This uses vars so that we can modify nodes as we split the * tree by adding children, etc. @@ -265,30 +390,52 @@ private[tree] class LearningNode( var isLeaf: Boolean, var stats: ImpurityStats) extends Serializable { - def toNode: Node = toNode(prune = true) + def toNode(isClassification: Boolean): Node = toNode(isClassification, prune = true) + + def toClassificationNode(prune: Boolean = true): ClassificationNode = { + toNode(true, prune).asInstanceOf[ClassificationNode] + } + + def toRegressionNode(prune: Boolean = true): RegressionNode = { + toNode(false, prune).asInstanceOf[RegressionNode] + } /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ - def toNode(prune: Boolean = true): Node = { + def toNode(isClassification: Boolean, prune: Boolean): Node = { if (!leftChild.isEmpty || !rightChild.isEmpty) { assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") - (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match { + (leftChild.get.toNode(isClassification, prune), + rightChild.get.toNode(isClassification, prune)) match { case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction => - new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + if (isClassification) { + new ClassificationLeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + } else { + new RegressionLeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + } case (l, r) => - new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, - l, r, split.get, stats.impurityCalculator) + if (isClassification) { + new ClassificationInternalNode(stats.impurityCalculator.predict, stats.impurity, + stats.gain, l.asInstanceOf[ClassificationNode], r.asInstanceOf[ClassificationNode], + split.get, stats.impurityCalculator) + } else { + new RegressionInternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + l.asInstanceOf[RegressionNode], r.asInstanceOf[RegressionNode], + split.get, stats.impurityCalculator) + } } } else { - if (stats.valid) { - new LeafNode(stats.impurityCalculator.predict, stats.impurity, + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + val impurity = if (stats.valid) stats.impurity else -1.0 + if (isClassification) { + new ClassificationLeafNode(stats.impurityCalculator.predict, impurity, stats.impurityCalculator) } else { - // Here we want to keep same behavior with the old mllib.DecisionTreeModel - new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) + new RegressionLeafNode(stats.impurityCalculator.predict, impurity, + stats.impurityCalculator) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 16f32d76b9984..056a94b351f79 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -224,23 +224,23 @@ private[spark] object RandomForest extends Logging { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures, - strategy.getNumClasses) + new DecisionTreeClassificationModel(uid, rootNode.toClassificationNode(prune), + numFeatures, strategy.getNumClasses) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toRegressionNode(prune), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures, + new DecisionTreeClassificationModel(rootNode.toClassificationNode(prune), numFeatures, strategy.getNumClasses) } } else { topNodes.map(rootNode => - new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures)) + new DecisionTreeRegressionModel(rootNode.toRegressionNode(prune), numFeatures)) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 4aa4c3617e7fd..f027b14f1d476 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -219,8 +219,10 @@ private[ml] object TreeEnsembleModel { importances.changeValue(feature, scaledGain, _ + scaledGain) computeFeatureImportance(n.leftChild, importances) computeFeatureImportance(n.rightChild, importances) - case n: LeafNode => + case _: LeafNode => // do nothing + case _ => + throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}") } } @@ -317,6 +319,8 @@ private[ml] object DecisionTreeModelReadWrite { (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats, -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), id) + case _ => + throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}") } } @@ -327,7 +331,7 @@ private[ml] object DecisionTreeModelReadWrite { def loadTreeNodes( path: String, metadata: DefaultParamsReader.Metadata, - sparkSession: SparkSession): Node = { + sparkSession: SparkSession, isClassification: Boolean): Node = { import sparkSession.implicits._ implicit val format = DefaultFormats @@ -339,7 +343,7 @@ private[ml] object DecisionTreeModelReadWrite { val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath).as[NodeData] - buildTreeFromNodes(data.collect(), impurityType) + buildTreeFromNodes(data.collect(), impurityType, isClassification) } /** @@ -348,7 +352,8 @@ private[ml] object DecisionTreeModelReadWrite { * @param impurityType Impurity type for this tree * @return Root node of reconstructed tree */ - def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = { + def buildTreeFromNodes(data: Array[NodeData], impurityType: String, + isClassification: Boolean): Node = { // Load all nodes, sorted by ID. val nodes = data.sortBy(_.id) // Sanity checks; could remove @@ -364,10 +369,21 @@ private[ml] object DecisionTreeModelReadWrite { val node = if (n.leftChild != -1) { val leftChild = finalNodes(n.leftChild) val rightChild = finalNodes(n.rightChild) - new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild, - n.split.getSplit, impurityStats) + if (isClassification) { + new ClassificationInternalNode(n.prediction, n.impurity, n.gain, + leftChild.asInstanceOf[ClassificationNode], rightChild.asInstanceOf[ClassificationNode], + n.split.getSplit, impurityStats) + } else { + new RegressionInternalNode(n.prediction, n.impurity, n.gain, + leftChild.asInstanceOf[RegressionNode], rightChild.asInstanceOf[RegressionNode], + n.split.getSplit, impurityStats) + } } else { - new LeafNode(n.prediction, n.impurity, impurityStats) + if (isClassification) { + new ClassificationLeafNode(n.prediction, n.impurity, impurityStats) + } else { + new RegressionLeafNode(n.prediction, n.impurity, impurityStats) + } } finalNodes(n.id) = node } @@ -421,7 +437,8 @@ private[ml] object EnsembleModelReadWrite { path: String, sql: SparkSession, className: String, - treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { + treeClassName: String, + isClassification: Boolean): (Metadata, Array[(Metadata, Node)], Array[Double]) = { import sql.implicits._ implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) @@ -449,7 +466,8 @@ private[ml] object EnsembleModelReadWrite { val rootNodesRDD: RDD[(Int, Node)] = nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map { case (treeID: Int, nodeData: Iterable[NodeData]) => - treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) + treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes( + nodeData.toArray, impurityType, isClassification) } val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() (metadata, treesMetadata.zip(rootNodes), treesWeights) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 2930f4900d50e..d3dbb4e754d3d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.ClassificationLeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -61,7 +61,8 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new DecisionTreeClassifier) - val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2) + val model = new DecisionTreeClassificationModel("dtc", + new ClassificationLeafNode(0.0, 0.0, null), 1, 2) ParamsSuite.checkParams(model) } @@ -375,6 +376,32 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { testDefaultReadWrite(model) } + + test("label/impurity stats") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + val df = TreeTests.setMetadata(rdd, Map.empty[Int, Int], 2) + val dt1 = new DecisionTreeClassifier() + .setImpurity("entropy") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val model1 = dt1.fit(df) + + val rootNode1 = model1.rootNode + assert(Array(rootNode1.getLabelCount(0), rootNode1.getLabelCount(1)) === Array(2.0, 1.0)) + + val dt2 = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val model2 = dt2.fit(df) + + val rootNode2 = model2.rootNode + assert(Array(rootNode2.getLabelCount(0), rootNode2.getLabelCount(1)) === Array(2.0, 1.0)) + } } private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 57796069f6052..f0ee5496f9d1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.RegressionLeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ @@ -69,7 +69,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", - Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)), + Array(new DecisionTreeRegressionModel("dtr", new RegressionLeafNode(0.0, 0.0, null), 1)), Array(1.0), 1, 2) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index ba4a9cf082785..3062aa9f3d274 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.ClassificationLeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -71,7 +71,8 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2) + Array(new DecisionTreeClassificationModel("dtc", + new ClassificationLeafNode(0.0, 0.0, null), 1, 2)), 2, 2) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 29a438396516b..9ae27339b11d5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -191,6 +191,20 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { TreeTests.allParamSettings ++ Map("maxDepth" -> 0), TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } + + test("label/impurity stats") { + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val dtr = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(8) + val model = dtr.fit(df) + val statInfo = model.rootNode + + assert(statInfo.getCount == 1000.0 && statInfo.getSum == 600.0 + && statInfo.getSumOfSquares == 600.0) + } } private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 743dacf146fe7..4dbbd75d2466d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -340,8 +340,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats.impurity > 0.0) // set impurity and predict for child nodes - assert(topNode.leftChild.get.toNode.prediction === 0.0) - assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0) + assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0) assert(topNode.leftChild.get.stats.impurity === 0.0) assert(topNode.rightChild.get.stats.impurity === 0.0) } @@ -382,8 +382,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats.impurity > 0.0) // set impurity and predict for child nodes - assert(topNode.leftChild.get.toNode.prediction === 0.0) - assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0) + assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0) assert(topNode.leftChild.get.stats.impurity === 0.0) assert(topNode.rightChild.get.stats.impurity === 0.0) } @@ -582,18 +582,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { left right */ val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0)) - val left = new LeafNode(0.0, leftImp.calculate(), leftImp) + val left = new ClassificationLeafNode(0.0, leftImp.calculate(), leftImp) val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0)) - val right = new LeafNode(2.0, rightImp.calculate(), rightImp) + val right = new ClassificationLeafNode(2.0, rightImp.calculate(), rightImp) - val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5)) + val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5), true) val parentImp = parent.impurityStats val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0)) - val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp) + val left2 = new ClassificationLeafNode(0.0, left2Imp.calculate(), left2Imp) - val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0)) + val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0), true) val grandImp = grandParent.impurityStats // Test feature importance computed at different subtrees. @@ -618,8 +618,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // Forest consisting of (full tree) + (internal node with 2 leafs) val trees = Array(parent, grandParent).map { root => - new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3) - .asInstanceOf[DecisionTreeModel] + new DecisionTreeClassificationModel(root.asInstanceOf[ClassificationNode], + numFeatures = 2, numClasses = 3).asInstanceOf[DecisionTreeModel] } val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2) val tree2norm = feature0importance + feature1importance diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index b6894b30b0c2b..3f03d909d4a4c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -159,7 +159,7 @@ private[ml] object TreeTests extends SparkFunSuite { * @param split Split for parent node * @return Parent node with children attached */ - def buildParentNode(left: Node, right: Node, split: Split): Node = { + def buildParentNode(left: Node, right: Node, split: Split, isClassification: Boolean): Node = { val leftImp = left.impurityStats val rightImp = right.impurityStats val parentImp = leftImp.copy.add(rightImp) @@ -168,7 +168,15 @@ private[ml] object TreeTests extends SparkFunSuite { val gain = parentImp.calculate() - (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) val pred = parentImp.predict - new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) + if (isClassification) { + new ClassificationInternalNode(pred, parentImp.calculate(), gain, + left.asInstanceOf[ClassificationNode], right.asInstanceOf[ClassificationNode], + split, parentImp) + } else { + new RegressionInternalNode(pred, parentImp.calculate(), gain, + left.asInstanceOf[RegressionNode], right.asInstanceOf[RegressionNode], + split, parentImp) + } } /** diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1b6d1dec69d49..b37b4d51775e8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -55,7 +55,14 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocksById"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.memUsedByRdd"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"), + + // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this") ) // Exclude rules for 2.3.x From 61b724724cc4a18818774ecaaa5a45b70fdb8dae Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 9 Apr 2018 14:07:33 -0700 Subject: [PATCH 589/774] [INFRA] Close stale PRs. Closes #20957 Closes #20792 From f94f3624ea81053653a06560808cb71f510c6828 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Mon, 9 Apr 2018 21:07:28 -0700 Subject: [PATCH 590/774] [SPARK-23947][SQL] Add hashUTF8String convenience method to hasher classes ## What changes were proposed in this pull request? Add `hashUTF8String()` to the hasher classes to allow Spark SQL codegen to generate cleaner code for hashing `UTF8String`s. No change in behavior otherwise. Although with the introduction of SPARK-10399, the code size for hashing `UTF8String` is already smaller, it's still good to extract a separate function in the hasher classes so that the generated code can stay clean. ## How was this patch tested? Existing tests. Author: Kris Mok Closes #21016 from rednaxelafx/hashutf8. --- .../apache/spark/sql/catalyst/expressions/HiveHasher.java | 5 +++++ .../java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java | 7 ++++++- .../org/apache/spark/sql/catalyst/expressions/XXH64.java | 5 +++++ .../org/apache/spark/sql/catalyst/expressions/hash.scala | 6 ++---- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index c34e36903a93e..62b75ae8aa01d 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; /** * Simulates Hive's hashing function from Hive v1.2.1 @@ -51,4 +52,8 @@ public static int hashUnsafeBytesBlock(MemoryBlock mb) { public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes)); } + + public static int hashUTF8String(UTF8String str) { + return hashUnsafeBytesBlock(str.getMemoryBlock()); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index f372b19fac119..aff6e93d647fe 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -20,6 +20,7 @@ import com.google.common.primitives.Ints; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -82,6 +83,10 @@ public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) { return fmix(h1, lengthInBytes); } + public static int hashUTF8String(UTF8String str, int seed) { + return hashUnsafeBytesBlock(str.getMemoryBlock(), seed); + } + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); } @@ -91,7 +96,7 @@ public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, } public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) { - // This is compatible with original and another implementations. + // This is compatible with original and other implementations. // Use this method for new components after Spark 2.3. int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index fe727f6011cbf..8e9c0a2e9dc81 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; // scalastyle: off /** @@ -107,6 +108,10 @@ public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) { return fmix(hash); } + public static long hashUTF8String(UTF8String str, long seed) { + return hashUnsafeBytesBlock(str.getMemoryBlock(), seed); + } + public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, length), seed); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index df29c38d64d3d..ef790338bdd27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -361,8 +361,7 @@ abstract class HashExpression[E] extends Expression { } protected def genHashString(input: String, result: String): String = { - val mb = s"$input.getMemoryBlock()" - s"$result = $hasherClassName.hashUnsafeBytesBlock($mb, $result);" + s"$result = $hasherClassName.hashUTF8String($input, $result);" } protected def genHashForMap( @@ -725,8 +724,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """ override protected def genHashString(input: String, result: String): String = { - val mb = s"$input.getMemoryBlock()" - s"$result = $hasherClassName.hashUnsafeBytesBlock($mb);" + s"$result = $hasherClassName.hashUTF8String($input);" } override protected def genHashForArray( From 64988841540464e261b0cbaede43058e7bd36261 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 9 Apr 2018 21:49:49 -0700 Subject: [PATCH 591/774] [SPARK-23898][SQL] Simplify add & subtract code generation ## What changes were proposed in this pull request? Code generation for the `Add` and `Subtract` expressions was not done using the `BinaryArithmetic.doCodeGen` method because these expressions also support `CalendarInterval`. This leads to a bit of duplication. This PR gets rid of that duplication by adding `calendarIntervalMethod` to `BinaryArithmetic` and doing the code generation for `CalendarInterval` in `BinaryArithmetic` instead. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #21005 from hvanhovell/SPARK-23898. --- .../sql/catalyst/expressions/arithmetic.scala | 50 ++++++++----------- 1 file changed, 20 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 478ff3a7c1011..defd6f3cd8849 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -43,7 +43,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression private lazy val numeric = TypeUtils.getNumeric(dataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") + case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { val originValue = ctx.freshName("origin") // codegen would fail to compile if we just write (-($c)) @@ -52,7 +52,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression ${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval); ${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue)); """}) - case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") + case _: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } protected override def nullSafeEval(input: Any): Any = { @@ -104,7 +104,7 @@ case class Abs(child: Expression) private lazy val numeric = TypeUtils.getNumeric(dataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => + case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))") @@ -117,15 +117,21 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { override def dataType: DataType = left.dataType - override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess + override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess /** Name of the function for this expression on a [[Decimal]] type. */ def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") + /** Name of the function for this expression on a [[CalendarInterval]] type. */ + def calendarIntervalMethod: String = + sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode") + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => + case _: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") + case CalendarIntervalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => defineCodeGen(ctx, ev, @@ -152,6 +158,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" + override def decimalMethod: String = "$plus" + + override def calendarIntervalMethod: String = "add" + private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -161,18 +171,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { numeric.plus(input1, input2) } } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") - case ByteType | ShortType => - defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") - case CalendarIntervalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") - case _ => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") - } } @ExpressionDescription( @@ -188,6 +186,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti override def symbol: String = "-" + override def decimalMethod: String = "$minus" + + override def calendarIntervalMethod: String = "subtract" + private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -197,18 +199,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti numeric.minus(input1, input2) } } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") - case ByteType | ShortType => - defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") - case CalendarIntervalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") - case _ => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") - } } @ExpressionDescription( @@ -416,7 +406,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "pmod" - protected def checkTypesInternal(t: DataType) = + protected def checkTypesInternal(t: DataType): TypeCheckResult = TypeUtils.checkForNumericExpr(t, "pmod") override def inputType: AbstractDataType = NumericType From 95034af69623bb8be5b9f5eabf50980bdeca48e6 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 10 Apr 2018 08:51:35 -0500 Subject: [PATCH 592/774] [SPARK-23841][ML] NodeIdCache should unpersist the last cached nodeIdsForInstances ## What changes were proposed in this pull request? unpersist the last cached nodeIdsForInstances in `deleteAllCheckpoints` ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #20956 from zhengruifeng/NodeIdCache_cleanup. --- .../scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index a7c5f489dea86..5b14a63ada4ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -95,7 +95,7 @@ private[spark] class NodeIdCache( splits: Array[Array[Split]]): Unit = { if (prevNodeIdsForInstances != null) { // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + prevNodeIdsForInstances.unpersist(false) } prevNodeIdsForInstances = nodeIdsForInstances @@ -166,9 +166,13 @@ private[spark] class NodeIdCache( } } } + if (nodeIdsForInstances != null) { + // Unpersist current one if one exists. + nodeIdsForInstances.unpersist(false) + } if (prevNodeIdsForInstances != null) { // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + prevNodeIdsForInstances.unpersist(false) } } } From 3323b156f9c0beb0b3c2b724a6faddc6ffdfe99a Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 10 Apr 2018 17:32:00 +0200 Subject: [PATCH 593/774] [SPARK-23864][SQL] Add unsafe object writing to UnsafeWriter ## What changes were proposed in this pull request? This PR moves writing of `UnsafeRow`, `UnsafeArrayData` & `UnsafeMapData` out of the `GenerateUnsafeProjection`/`InterpretedUnsafeProjection` classes into the `UnsafeWriter` interface. This cleans up the code a little bit, and it should also result in less byte code for the code generated path. ## How was this patch tested? Existing tests Author: Herman van Hovell Closes #20986 from hvanhovell/SPARK-23864. --- .../expressions/codegen/UnsafeWriter.java | 72 ++-- .../InterpretedUnsafeProjection.scala | 46 +-- .../codegen/GenerateUnsafeProjection.scala | 322 ++++++++---------- .../spark/sql/types/UserDefinedType.scala | 10 + 4 files changed, 204 insertions(+), 246 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index de0eb6dbb76be..2781655002000 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -16,6 +16,9 @@ */ package org.apache.spark.sql.catalyst.expressions.codegen; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.expressions.UnsafeMapData; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -103,21 +106,7 @@ protected final void zeroOutPaddingBytes(int numBytes) { public abstract void write(int ordinal, Decimal input, int precision, int scale); public final void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(getBuffer(), cursor()); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - increaseCursor(roundedSize); + writeUnalignedBytes(ordinal, input.getBaseObject(), input.getBaseOffset(), input.numBytes()); } public final void write(int ordinal, byte[] input) { @@ -125,20 +114,19 @@ public final void write(int ordinal, byte[] input) { } public final void write(int ordinal, byte[] input, int offset, int numBytes) { - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); + writeUnalignedBytes(ordinal, input, Platform.BYTE_ARRAY_OFFSET + offset, numBytes); + } - // grow the global buffer before writing data. + private void writeUnalignedBytes( + int ordinal, + Object baseObject, + long baseOffset, + int numBytes) { + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); grow(roundedSize); - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), numBytes); - + Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes); setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. increaseCursor(roundedSize); } @@ -156,6 +144,40 @@ public final void write(int ordinal, CalendarInterval input) { increaseCursor(16); } + public final void write(int ordinal, UnsafeRow row) { + writeAlignedBytes(ordinal, row.getBaseObject(), row.getBaseOffset(), row.getSizeInBytes()); + } + + public final void write(int ordinal, UnsafeMapData map) { + writeAlignedBytes(ordinal, map.getBaseObject(), map.getBaseOffset(), map.getSizeInBytes()); + } + + public final void write(UnsafeArrayData array) { + // Unsafe arrays both can be written as a regular array field or as part of a map. This makes + // updating the offset and size dependent on the code path, this is why we currently do not + // provide an method for writing unsafe arrays that also updates the size and offset. + int numBytes = array.getSizeInBytes(); + grow(numBytes); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + getBuffer(), + cursor(), + numBytes); + increaseCursor(numBytes); + } + + private void writeAlignedBytes( + int ordinal, + Object baseObject, + long baseOffset, + int numBytes) { + grow(numBytes); + Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes); + setOffsetAndSize(ordinal, numBytes); + increaseCursor(numBytes); + } + protected final void writeBoolean(long offset, boolean value) { Platform.putBoolean(getBuffer(), offset, value); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index b31466f5c92d1..6d69d69b1c802 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -173,21 +173,17 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { val rowWriter = new UnsafeRowWriter(writer, numFields) val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - val previousCursor = writer.cursor() v.getStruct(i, fields.length) match { case row: UnsafeRow => - writeUnsafeData( - rowWriter, - row.getBaseObject, - row.getBaseOffset, - row.getSizeInBytes) + writer.write(i, row) case row => + val previousCursor = writer.cursor() // Nested struct. We don't know where this will start because a row can be // variable length, so we need to update the offsets and zero out the bit mask. rowWriter.resetRowWriter() structWriter.apply(row) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } - writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case ArrayType(elementType, containsNull) => @@ -214,15 +210,12 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { valueType, valueContainsNull) (v, i) => { - val previousCursor = writer.cursor() v.getMap(i) match { case map: UnsafeMapData => - writeUnsafeData( - valueArrayWriter, - map.getBaseObject, - map.getBaseOffset, - map.getSizeInBytes) + writer.write(i, map) case map => + val previousCursor = writer.cursor() + // preserve 8 bytes to write the key array numBytes later. valueArrayWriter.grow(8) valueArrayWriter.increaseCursor(8) @@ -237,8 +230,8 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { // Write the values. writeArray(valueArrayWriter, valueWriter, map.valueArray()) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } - writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case udt: UserDefinedType[_] => @@ -318,11 +311,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { elementWriter: (SpecializedGetters, Int) => Unit, array: ArrayData): Unit = array match { case unsafe: UnsafeArrayData => - writeUnsafeData( - arrayWriter, - unsafe.getBaseObject, - unsafe.getBaseOffset, - unsafe.getSizeInBytes) + arrayWriter.write(unsafe) case _ => val numElements = array.numElements() arrayWriter.initialize(numElements) @@ -332,23 +321,4 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { i += 1 } } - - /** - * Write an opaque block of data to the buffer. This is used to copy - * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. - */ - private def writeUnsafeData( - writer: UnsafeWriter, - baseObject: AnyRef, - baseOffset: Long, - sizeInBytes: Int) : Unit = { - writer.grow(sizeInBytes) - Platform.copyMemory( - baseObject, - baseOffset, - writer.getBuffer, - writer.cursor, - sizeInBytes) - writer.increaseCursor(sizeInBytes) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4a4d76313a543..2fb441ac4500e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -32,14 +32,13 @@ import org.apache.spark.sql.types._ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { /** Returns true iff we support this data type. */ - def canSupport(dataType: DataType): Boolean = dataType match { + def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match { case NullType => true - case t: AtomicType => true + case _: AtomicType => true case _: CalendarIntervalType => true case t: StructType => t.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true - case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } @@ -47,6 +46,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private def writeStructToBuffer( ctx: CodegenContext, input: String, + index: String, fieldTypes: Seq[DataType], rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. @@ -60,15 +60,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val rowWriterClass = classOf[UnsafeRowWriter].getName val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") - + val previousCursor = ctx.freshName("previousCursor") s""" - final InternalRow $tmpInput = $input; - if ($tmpInput instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", structRowWriter)} - } else { - ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} - } - """ + |final InternalRow $tmpInput = $input; + |if ($tmpInput instanceof UnsafeRow) { + | $rowWriter.write($index, (UnsafeRow) $tmpInput); + |} else { + | // Remember the current cursor so that we can calculate how many bytes are + | // written later. + | final int $previousCursor = $rowWriter.cursor(); + | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} + | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + |} + """.stripMargin } private def writeExpressionsToBuffer( @@ -95,10 +99,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val writeFields = inputs.zip(inputTypes).zipWithIndex.map { case ((input, dataType), index) => - val dt = dataType match { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - } + val dt = UserDefinedType.sqlType(dataType) val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => @@ -106,58 +107,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" case _ => s"$rowWriter.setNullAt($index);" } - val previousCursor = ctx.freshName("previousCursor") - - val writeField = dt match { - case t: StructType => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $previousCursor = $rowWriter.cursor(); - ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case a @ ArrayType(et, _) => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $previousCursor = $rowWriter.cursor(); - ${writeArrayToBuffer(ctx, input.value, et, rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case m @ MapType(kt, vt, _) => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $previousCursor = $rowWriter.cursor(); - ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case t: DecimalType => - s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" - - case NullType => "" - - case _ => s"$rowWriter.write($index, ${input.value});" - } + val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) if (input.isNull == "false") { s""" - ${input.code} - ${writeField.trim} - """ + |${input.code} + |${writeField.trim} + """.stripMargin } else { s""" - ${input.code} - if (${input.isNull}) { - ${setNull.trim} - } else { - ${writeField.trim} - } - """ + |${input.code} + |if (${input.isNull}) { + | ${setNull.trim} + |} else { + | ${writeField.trim} + |} + """.stripMargin } } @@ -171,11 +136,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro funcName = "writeFields", arguments = Seq("InternalRow" -> row)) } - s""" - $resetWriter - $writeFieldsCode - """.trim + |$resetWriter + |$writeFieldsCode + """.stripMargin } // TODO: if the nullability of array element is correct, we can use it to save null check. @@ -189,10 +153,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") - val et = elementType match { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - } + val et = UserDefinedType.sqlType(elementType) val jt = CodeGenerator.javaType(et) @@ -205,106 +166,100 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") - val previousCursor = ctx.freshName("previousCursor") val element = CodeGenerator.getValue(tmpInput, et, index) - val writeElement = et match { - case t: StructType => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case a @ ArrayType(et, _) => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeArrayToBuffer(ctx, element, et, arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case m @ MapType(kt, vt, _) => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case t: DecimalType => - s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});" - - case NullType => "" - - case _ => s"$arrayWriter.write($index, $element);" - } - val primitiveTypeName = - if (CodeGenerator.isPrimitiveType(jt)) CodeGenerator.primitiveTypeName(et) else "" s""" - final ArrayData $tmpInput = $input; - if ($tmpInput instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)} - } else { - final int $numElements = $tmpInput.numElements(); - $arrayWriter.initialize($numElements); - - for (int $index = 0; $index < $numElements; $index++) { - if ($tmpInput.isNullAt($index)) { - $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); - } else { - $writeElement - } - } - } - """ + |final ArrayData $tmpInput = $input; + |if ($tmpInput instanceof UnsafeArrayData) { + | $rowWriter.write((UnsafeArrayData) $tmpInput); + |} else { + | final int $numElements = $tmpInput.numElements(); + | $arrayWriter.initialize($numElements); + | + | for (int $index = 0; $index < $numElements; $index++) { + | if ($tmpInput.isNullAt($index)) { + | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); + | } else { + | ${writeElement(ctx, element, index, et, arrayWriter)} + | } + | } + |} + """.stripMargin } // TODO: if the nullability of value element is correct, we can use it to save null check. private def writeMapToBuffer( ctx: CodegenContext, input: String, + index: String, keyType: DataType, valueType: DataType, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val tmpCursor = ctx.freshName("tmpCursor") + val previousCursor = ctx.freshName("previousCursor") // Writes out unsafe map according to the format described in `UnsafeMapData`. s""" - final MapData $tmpInput = $input; - if ($tmpInput instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", rowWriter)} - } else { - // preserve 8 bytes to write the key array numBytes later. - $rowWriter.grow(8); - $rowWriter.increaseCursor(8); + |final MapData $tmpInput = $input; + |if ($tmpInput instanceof UnsafeMapData) { + | $rowWriter.write($index, (UnsafeMapData) $tmpInput); + |} else { + | // Remember the current cursor so that we can calculate how many bytes are + | // written later. + | final int $previousCursor = $rowWriter.cursor(); + | + | // preserve 8 bytes to write the key array numBytes later. + | $rowWriter.grow(8); + | $rowWriter.increaseCursor(8); + | + | // Remember the current cursor so that we can write numBytes of key array later. + | final int $tmpCursor = $rowWriter.cursor(); + | + | ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} + | + | // Write the numBytes of key array into the first 8 bytes. + | Platform.putLong( + | $rowWriter.getBuffer(), + | $tmpCursor - 8, + | $rowWriter.cursor() - $tmpCursor); + | + | ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} + | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + |} + """.stripMargin + } - // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $rowWriter.cursor(); + private def writeElement( + ctx: CodegenContext, + input: String, + index: String, + dt: DataType, + writer: String): String = dt match { + case t: StructType => + writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer) + + case ArrayType(et, _) => + val previousCursor = ctx.freshName("previousCursor") + s""" + |// Remember the current cursor so that we can calculate how many bytes are + |// written later. + |final int $previousCursor = $writer.cursor(); + |${writeArrayToBuffer(ctx, input, et, writer)} + |$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + """.stripMargin - ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} - // Write the numBytes of key array into the first 8 bytes. - Platform.putLong($rowWriter.getBuffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor); + case MapType(kt, vt, _) => + writeMapToBuffer(ctx, input, index, kt, vt, writer) - ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} - } - """ - } + case DecimalType.Fixed(precision, scale) => + s"$writer.write($index, $input, $precision, $scale);" - /** - * If the input is already in unsafe format, we don't need to go through all elements/fields, - * we can directly write it. - */ - private def writeUnsafeData(ctx: CodegenContext, input: String, rowWriter: String) = { - val sizeInBytes = ctx.freshName("sizeInBytes") - s""" - final int $sizeInBytes = $input.getSizeInBytes(); - // grow the global buffer before writing data. - $rowWriter.grow($sizeInBytes); - $input.writeToMemory($rowWriter.getBuffer(), $rowWriter.cursor()); - $rowWriter.increaseCursor($sizeInBytes); - """ + case NullType => "" + + case _ => s"$writer.write($index, $input);" } def createCode( @@ -332,10 +287,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val code = s""" - $rowWriter.reset(); - $evalSubexpr - $writeExpressions - """ + |$rowWriter.reset(); + |$evalSubexpr + |$writeExpressions + """.stripMargin // `rowWriter` is declared as a class field, so we can access it directly in methods. ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow", canDirectAccess = true)) @@ -363,38 +318,39 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val ctx = newCodeGenContext() val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) - val codeBody = s""" - public java.lang.Object generate(Object[] references) { - return new SpecificUnsafeProjection(references); - } - - class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { - - private Object[] references; - ${ctx.declareMutableStates()} - - public SpecificUnsafeProjection(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} - } - - public void initialize(int partitionIndex) { - ${ctx.initPartition()} - } - - // Scala.Function1 need this - public java.lang.Object apply(java.lang.Object row) { - return apply((InternalRow) row); - } - - public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - ${eval.code.trim} - return ${eval.value}; - } - - ${ctx.declareAddedFunctions()} - } - """ + val codeBody = + s""" + |public java.lang.Object generate(Object[] references) { + | return new SpecificUnsafeProjection(references); + |} + | + |class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { + | + | private Object[] references; + | ${ctx.declareMutableStates()} + | + | public SpecificUnsafeProjection(Object[] references) { + | this.references = references; + | ${ctx.initMutableStates()} + | } + | + | public void initialize(int partitionIndex) { + | ${ctx.initPartition()} + | } + | + | // Scala.Function1 need this + | public java.lang.Object apply(java.lang.Object row) { + | return apply((InternalRow) row); + | } + | + | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { + | ${eval.code.trim} + | return ${eval.value}; + | } + | + | ${ctx.declareAddedFunctions()} + |} + """.stripMargin val code = CodeFormatter.stripOverlappingComments( new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 5a944e763e099..6af16e2dba105 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -97,6 +97,16 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa override def catalogString: String = sqlType.simpleString } +private[spark] object UserDefinedType { + /** + * Get the sqlType of a (potential) [[UserDefinedType]]. + */ + def sqlType(dt: DataType): DataType = dt match { + case udt: UserDefinedType[_] => udt.sqlType + case _ => dt + } +} + /** * The user defined type in Python. * From e179658914963de472120a81621396706584c949 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 10 Apr 2018 09:33:09 -0700 Subject: [PATCH 594/774] [SPARK-19724][SQL][FOLLOW-UP] Check location of managed table when ignoreIfExists is true ## What changes were proposed in this pull request? In the PR #20886, I mistakenly check the table location only when `ignoreIfExists` is false, which was following the original deprecated PR. That was wrong. When `ignoreIfExists` is true and the target table doesn't exist, we should also check the table location. In other word, **`ignoreIfExists` has nothing to do with table location validation**. This is a follow-up PR to fix the mistake. ## How was this patch tested? Add one unit test. Author: Gengliang Wang Closes #21001 from gengliangwang/SPARK-19724-followup. --- .../spark/sql/catalyst/catalog/SessionCatalog.scala | 11 +++++++++-- .../execution/command/createDataSourceTables.scala | 2 +- .../apache/spark/sql/execution/command/DDLSuite.scala | 9 +++++++++ .../spark/sql/hive/execution/HiveDDLSuite.scala | 2 +- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 52ed89ef8d781..c390337c03ff5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -286,7 +286,10 @@ class SessionCatalog( * Create a metastore table in the database specified in `tableDefinition`. * If no such database is specified, create it in the current database. */ - def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + def createTable( + tableDefinition: CatalogTable, + ignoreIfExists: Boolean, + validateLocation: Boolean = true): Unit = { val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableDefinition.identifier.table) val tableIdentifier = TableIdentifier(table, Some(db)) @@ -305,7 +308,11 @@ class SessionCatalog( } requireDbExists(db) - if (!ignoreIfExists) { + if (tableExists(newTableDefinition.identifier)) { + if (!ignoreIfExists) { + throw new TableAlreadyExistsException(db = db, table = table) + } + } else if (validateLocation) { validateTableLocation(newTableDefinition) } externalCatalog.createTable(newTableDefinition, ignoreIfExists) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index f7c3e9b019258..f6ef433f2ce15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -182,7 +182,7 @@ case class CreateDataSourceTableAsSelectCommand( // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). schema = result.schema) // Table location is already validated. No need to check it again during table creation. - sessionState.catalog.createTable(newTable, ignoreIfExists = true) + sessionState.catalog.createTable(newTable, ignoreIfExists = false, validateLocation = false) result match { case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4304d0b6f6b16..cbd7f9d6f67be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -425,6 +425,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") }.getMessage assert(ex.contains(exMsgWithDefaultDB)) + + // Always check location of managed table, with or without (IF NOT EXISTS) + withTable("tab2") { + sql(s"CREATE TABLE tab2 (col1 int, col2 string) USING ${dataSource}") + ex = intercept[AnalysisException] { + sql(s"CREATE TABLE IF NOT EXISTS tab1 LIKE tab2") + }.getMessage + assert(ex.contains(exMsgWithDefaultDB)) + } } finally { waitForTasksToFinish() Utils.deleteRecursively(tableLoc) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index db76ec9d084cb..c85db78c732de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1461,7 +1461,7 @@ class HiveDDLSuite assert(e2.getMessage.contains(forbiddenPrefix + "foo")) val e3 = intercept[AnalysisException] { - sql(s"CREATE TABLE tbl (a INT) TBLPROPERTIES ('${forbiddenPrefix}foo'='anything')") + sql(s"CREATE TABLE tbl2 (a INT) TBLPROPERTIES ('${forbiddenPrefix}foo'='anything')") } assert(e3.getMessage.contains(forbiddenPrefix + "foo")) } From adb222b957f327a69929b8f16fa5ebc071fa99e3 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 10 Apr 2018 11:18:14 -0700 Subject: [PATCH 595/774] [SPARK-23751][ML][PYSPARK] Kolmogorov-Smirnoff test Python API in pyspark.ml ## What changes were proposed in this pull request? Kolmogorov-Smirnoff test Python API in `pyspark.ml` **Note** API with `CDF` is a little difficult to support in python. We can add it in following PR. ## How was this patch tested? doctest Author: WeichenXu Closes #20904 from WeichenXu123/ks-test-py. --- .../spark/ml/stat/KolmogorovSmirnovTest.scala | 29 +-- python/pyspark/ml/stat.py | 181 ++++++++++++------ 2 files changed, 138 insertions(+), 72 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala index c62d7463288f7..af8ff64d33ffe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -24,7 +24,7 @@ import org.apache.spark.api.java.function.Function import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.stat.{Statistics => OldStatistics} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.col /** @@ -59,7 +59,7 @@ object KolmogorovSmirnovTest { * distribution of the sample data and the theoretical distribution we can provide a test for the * the null hypothesis that the sample data comes from that theoretical distribution. * - * @param dataset a `DataFrame` containing the sample of data to test + * @param dataset A `Dataset` or a `DataFrame` containing the sample of data to test * @param sampleCol Name of sample column in dataset, of any numerical type * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value * @return DataFrame containing the test result for the input sampled data. @@ -68,10 +68,10 @@ object KolmogorovSmirnovTest { * - `statistic: Double` */ @Since("2.4.0") - def test(dataset: DataFrame, sampleCol: String, cdf: Double => Double): DataFrame = { + def test(dataset: Dataset[_], sampleCol: String, cdf: Double => Double): DataFrame = { val spark = dataset.sparkSession - val rdd = getSampleRDD(dataset, sampleCol) + val rdd = getSampleRDD(dataset.toDF(), sampleCol) val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, cdf) spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( testResult.pValue, testResult.statistic))) @@ -81,10 +81,11 @@ object KolmogorovSmirnovTest { * Java-friendly version of `test(dataset: DataFrame, sampleCol: String, cdf: Double => Double)` */ @Since("2.4.0") - def test(dataset: DataFrame, sampleCol: String, - cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { - val f: Double => Double = x => cdf.call(x) - test(dataset, sampleCol, f) + def test( + dataset: Dataset[_], + sampleCol: String, + cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { + test(dataset, sampleCol, (x: Double) => cdf.call(x)) } /** @@ -92,10 +93,11 @@ object KolmogorovSmirnovTest { * distribution equality. Currently supports the normal distribution, taking as parameters * the mean and standard deviation. * - * @param dataset a `DataFrame` containing the sample of data to test + * @param dataset A `Dataset` or a `DataFrame` containing the sample of data to test * @param sampleCol Name of sample column in dataset, of any numerical type * @param distName a `String` name for a theoretical distribution, currently only support "norm". - * @param params `Double*` specifying the parameters to be used for the theoretical distribution + * @param params `Double*` specifying the parameters to be used for the theoretical distribution. + * For "norm" distribution, the parameters includes mean and variance. * @return DataFrame containing the test result for the input sampled data. * This DataFrame will contain a single Row with the following fields: * - `pValue: Double` @@ -103,10 +105,13 @@ object KolmogorovSmirnovTest { */ @Since("2.4.0") @varargs - def test(dataset: DataFrame, sampleCol: String, distName: String, params: Double*): DataFrame = { + def test( + dataset: Dataset[_], + sampleCol: String, distName: String, + params: Double*): DataFrame = { val spark = dataset.sparkSession - val rdd = getSampleRDD(dataset, sampleCol) + val rdd = getSampleRDD(dataset.toDF(), sampleCol) val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, distName, params: _*) spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( testResult.pValue, testResult.statistic))) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 0eeb5e528434a..93d0f4fd9148f 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -32,32 +32,6 @@ class ChiSquareTest(object): The null hypothesis is that the occurrence of the outcomes is statistically independent. - :param dataset: - DataFrame of categorical labels and categorical features. - Real-valued features will be treated as categorical for each distinct value. - :param featuresCol: - Name of features column in dataset, of type `Vector` (`VectorUDT`). - :param labelCol: - Name of label column in dataset, of any numerical type. - :return: - DataFrame containing the test result for every feature against the label. - This DataFrame will contain a single Row with the following fields: - - `pValues: Vector` - - `degreesOfFreedom: Array[Int]` - - `statistics: Vector` - Each of these fields has one value per feature. - - >>> from pyspark.ml.linalg import Vectors - >>> from pyspark.ml.stat import ChiSquareTest - >>> dataset = [[0, Vectors.dense([0, 0, 1])], - ... [0, Vectors.dense([1, 0, 1])], - ... [1, Vectors.dense([2, 1, 1])], - ... [1, Vectors.dense([3, 1, 1])]] - >>> dataset = spark.createDataFrame(dataset, ["label", "features"]) - >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label') - >>> chiSqResult.select("degreesOfFreedom").collect()[0] - Row(degreesOfFreedom=[3, 1, 0]) - .. versionadded:: 2.2.0 """ @@ -66,6 +40,32 @@ class ChiSquareTest(object): def test(dataset, featuresCol, labelCol): """ Perform a Pearson's independence test using dataset. + + :param dataset: + DataFrame of categorical labels and categorical features. + Real-valued features will be treated as categorical for each distinct value. + :param featuresCol: + Name of features column in dataset, of type `Vector` (`VectorUDT`). + :param labelCol: + Name of label column in dataset, of any numerical type. + :return: + DataFrame containing the test result for every feature against the label. + This DataFrame will contain a single Row with the following fields: + - `pValues: Vector` + - `degreesOfFreedom: Array[Int]` + - `statistics: Vector` + Each of these fields has one value per feature. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import ChiSquareTest + >>> dataset = [[0, Vectors.dense([0, 0, 1])], + ... [0, Vectors.dense([1, 0, 1])], + ... [1, Vectors.dense([2, 1, 1])], + ... [1, Vectors.dense([3, 1, 1])]] + >>> dataset = spark.createDataFrame(dataset, ["label", "features"]) + >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label') + >>> chiSqResult.select("degreesOfFreedom").collect()[0] + Row(degreesOfFreedom=[3, 1, 0]) """ sc = SparkContext._active_spark_context javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest @@ -85,40 +85,6 @@ class Correlation(object): which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'` to avoid recomputing the common lineage. - :param dataset: - A dataset or a dataframe. - :param column: - The name of the column of vectors for which the correlation coefficient needs - to be computed. This must be a column of the dataset, and it must contain - Vector objects. - :param method: - String specifying the method to use for computing correlation. - Supported: `pearson` (default), `spearman`. - :return: - A dataframe that contains the correlation matrix of the column of vectors. This - dataframe contains a single row and a single column of name - '$METHODNAME($COLUMN)'. - - >>> from pyspark.ml.linalg import Vectors - >>> from pyspark.ml.stat import Correlation - >>> dataset = [[Vectors.dense([1, 0, 0, -2])], - ... [Vectors.dense([4, 5, 0, 3])], - ... [Vectors.dense([6, 7, 0, 8])], - ... [Vectors.dense([9, 0, 0, 1])]] - >>> dataset = spark.createDataFrame(dataset, ['features']) - >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0] - >>> print(str(pearsonCorr).replace('nan', 'NaN')) - DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...], - [ 0.0556..., 1. , NaN, 0.9135...], - [ NaN, NaN, 1. , NaN], - [ 0.4004..., 0.9135..., NaN, 1. ]]) - >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0] - >>> print(str(spearmanCorr).replace('nan', 'NaN')) - DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ], - [ 0.1054..., 1. , NaN, 0.9486... ], - [ NaN, NaN, 1. , NaN], - [ 0.4 , 0.9486... , NaN, 1. ]]) - .. versionadded:: 2.2.0 """ @@ -127,6 +93,40 @@ class Correlation(object): def corr(dataset, column, method="pearson"): """ Compute the correlation matrix with specified method using dataset. + + :param dataset: + A Dataset or a DataFrame. + :param column: + The name of the column of vectors for which the correlation coefficient needs + to be computed. This must be a column of the dataset, and it must contain + Vector objects. + :param method: + String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman`. + :return: + A DataFrame that contains the correlation matrix of the column of vectors. This + DataFrame contains a single row and a single column of name + '$METHODNAME($COLUMN)'. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import Correlation + >>> dataset = [[Vectors.dense([1, 0, 0, -2])], + ... [Vectors.dense([4, 5, 0, 3])], + ... [Vectors.dense([6, 7, 0, 8])], + ... [Vectors.dense([9, 0, 0, 1])]] + >>> dataset = spark.createDataFrame(dataset, ['features']) + >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0] + >>> print(str(pearsonCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...], + [ 0.0556..., 1. , NaN, 0.9135...], + [ NaN, NaN, 1. , NaN], + [ 0.4004..., 0.9135..., NaN, 1. ]]) + >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0] + >>> print(str(spearmanCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ], + [ 0.1054..., 1. , NaN, 0.9486... ], + [ NaN, NaN, 1. , NaN], + [ 0.4 , 0.9486... , NaN, 1. ]]) """ sc = SparkContext._active_spark_context javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation @@ -134,6 +134,67 @@ def corr(dataset, column, method="pearson"): return _java2py(sc, javaCorrObj.corr(*args)) +class KolmogorovSmirnovTest(object): + """ + .. note:: Experimental + + Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a continuous + distribution. + + By comparing the largest difference between the empirical cumulative + distribution of the sample data and the theoretical distribution we can provide a test for the + the null hypothesis that the sample data comes from that theoretical distribution. + + .. versionadded:: 2.4.0 + + """ + @staticmethod + @since("2.4.0") + def test(dataset, sampleCol, distName, *params): + """ + Conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability distribution + equality. Currently supports the normal distribution, taking as parameters the mean and + standard deviation. + + :param dataset: + a Dataset or a DataFrame containing the sample of data to test. + :param sampleCol: + Name of sample column in dataset, of any numerical type. + :param distName: + a `string` name for a theoretical distribution, currently only support "norm". + :param params: + a list of `Double` values specifying the parameters to be used for the theoretical + distribution. For "norm" distribution, the parameters includes mean and variance. + :return: + A DataFrame that contains the Kolmogorov-Smirnov test result for the input sampled data. + This DataFrame will contain a single Row with the following fields: + - `pValue: Double` + - `statistic: Double` + + >>> from pyspark.ml.stat import KolmogorovSmirnovTest + >>> dataset = [[-1.0], [0.0], [1.0]] + >>> dataset = spark.createDataFrame(dataset, ['sample']) + >>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 0.0, 1.0).first() + >>> round(ksResult.pValue, 3) + 1.0 + >>> round(ksResult.statistic, 3) + 0.175 + >>> dataset = [[2.0], [3.0], [4.0]] + >>> dataset = spark.createDataFrame(dataset, ['sample']) + >>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 3.0, 1.0).first() + >>> round(ksResult.pValue, 3) + 1.0 + >>> round(ksResult.statistic, 3) + 0.175 + """ + sc = SparkContext._active_spark_context + javaTestObj = _jvm().org.apache.spark.ml.stat.KolmogorovSmirnovTest + dataset = _py2java(sc, dataset) + params = [float(param) for param in params] + return _java2py(sc, javaTestObj.test(dataset, sampleCol, distName, + _jvm().PythonUtils.toSeq(params))) + + if __name__ == "__main__": import doctest import pyspark.ml.stat From 4f1e8b9bb7d795d4ca3d5cd5dcc0f9419e52dfae Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 10 Apr 2018 15:41:45 -0700 Subject: [PATCH 596/774] [SPARK-23871][ML][PYTHON] add python api for VectorAssembler handleInvalid ## What changes were proposed in this pull request? add python api for VectorAssembler handleInvalid ## How was this patch tested? Add doctest Author: Huaxin Gao Closes #21003 from huaxingao/spark-23871. --- .../spark/ml/feature/VectorAssembler.scala | 12 +++--- python/pyspark/ml/feature.py | 42 ++++++++++++++++--- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 6bf4aa38b1fcb..4061154b39c14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -71,12 +71,12 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) */ @Since("2.4.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with - |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the - |output). Column lengths are taken from the size of ML Attribute Group, which can be set using - |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred - |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. - |""".stripMargin.replaceAll("\n", " "), + """Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out + |rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN + |in the output). Column lengths are taken from the size of ML Attribute Group, which can be + |set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also + |be inferred from first rows of the data since it is safe to do so but only in case of 'error' + |or 'skip'.""".stripMargin.replaceAll("\n", " "), ParamValidators.inArray(VectorAssembler.supportedHandleInvalids)) setDefault(handleInvalid, VectorAssembler.ERROR_INVALID) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 5a3e0dd655150..cdda30cfab482 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2701,7 +2701,8 @@ def setParams(self, inputCol=None, outputCol=None): @inherit_doc -class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable): +class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): """ A feature transformer that merges multiple columns into a vector column. @@ -2719,25 +2720,56 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl >>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath) >>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs True + >>> dfWithNullsAndNaNs = spark.createDataFrame( + ... [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], ["a", "b", "c"]) + >>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features", + ... handleInvalid="keep") + >>> vecAssembler2.transform(dfWithNullsAndNaNs).show() + +---+---+----+-------------+ + | a| b| c| features| + +---+---+----+-------------+ + |1.0|2.0|null|[1.0,2.0,NaN]| + |3.0|NaN| 4.0|[3.0,NaN,4.0]| + |5.0|6.0| 7.0|[5.0,6.0,7.0]| + +---+---+----+-------------+ + ... + >>> vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show() + +---+---+---+-------------+ + | a| b| c| features| + +---+---+---+-------------+ + |5.0|6.0|7.0|[5.0,6.0,7.0]| + +---+---+---+-------------+ + ... .. versionadded:: 1.4.0 """ + handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data (NULL " + + "and NaN values). Options are 'skip' (filter out rows with invalid " + + "data), 'error' (throw an error), or 'keep' (return relevant number " + + "of NaN in the output). Column lengths are taken from the size of ML " + + "Attribute Group, which can be set using `VectorSizeHint` in a " + + "pipeline before `VectorAssembler`. Column lengths can also be " + + "inferred from first rows of the data since it is safe to do so but " + + "only in case of 'error' or 'skip').", + typeConverter=TypeConverters.toString) + @keyword_only - def __init__(self, inputCols=None, outputCol=None): + def __init__(self, inputCols=None, outputCol=None, handleInvalid="error"): """ - __init__(self, inputCols=None, outputCol=None) + __init__(self, inputCols=None, outputCol=None, handleInvalid="error") """ super(VectorAssembler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid) + self._setDefault(handleInvalid="error") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.4.0") - def setParams(self, inputCols=None, outputCol=None): + def setParams(self, inputCols=None, outputCol=None, handleInvalid="error"): """ - setParams(self, inputCols=None, outputCol=None) + setParams(self, inputCols=None, outputCol=None, handleInvalid="error") Sets params for this VectorAssembler. """ kwargs = self._input_kwargs From 7c7570d466a8ded51e580eb6a28583bd9a9c5337 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Tue, 10 Apr 2018 17:26:06 -0700 Subject: [PATCH 597/774] [SPARK-23944][ML] Add the set method for the two LSHModel ## What changes were proposed in this pull request? Add two set method for LSHModel in LSH.scala, BucketedRandomProjectionLSH.scala, and MinHashLSH.scala ## How was this patch tested? New test for the param setup was added into - BucketedRandomProjectionLSHSuite.scala - MinHashLSHSuite.scala Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21015 from ludatabricks/SPARK-23944. --- .../spark/ml/feature/BucketedRandomProjectionLSH.scala | 8 ++++++++ .../src/main/scala/org/apache/spark/ml/feature/LSH.scala | 6 ++++++ .../scala/org/apache/spark/ml/feature/MinHashLSH.scala | 8 ++++++++ .../ml/feature/BucketedRandomProjectionLSHSuite.scala | 8 ++++++++ .../org/apache/spark/ml/feature/MinHashLSHSuite.scala | 8 ++++++++ 5 files changed, 38 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index 36a46ca6ff4b7..41eaaf9679914 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -73,6 +73,14 @@ class BucketedRandomProjectionLSHModel private[ml]( private[ml] val randUnitVectors: Array[Vector]) extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams { + /** @group setParam */ + @Since("2.4.0") + override def setInputCol(value: String): this.type = super.set(inputCol, value) + + /** @group setParam */ + @Since("2.4.0") + override def setOutputCol(value: String): this.type = super.set(outputCol, value) + @Since("2.1.0") override protected[ml] val hashFunction: Vector => Array[Vector] = { key: Vector => { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala index 1c9f47a0b201d..a70931f783f45 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -65,6 +65,12 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] extends Model[T] with LSHParams with MLWritable { self: T => + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + /** * The hash function of LSH, mapping an input feature vector to multiple hash vectors. * @return The mapping of LSH function. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 145422a059196..556848e45532d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -51,6 +51,14 @@ class MinHashLSHModel private[ml]( private[ml] val randCoefficients: Array[(Int, Int)]) extends LSHModel[MinHashLSHModel] { + /** @group setParam */ + @Since("2.4.0") + override def setInputCol(value: String): this.type = super.set(inputCol, value) + + /** @group setParam */ + @Since("2.4.0") + override def setOutputCol(value: String): this.type = super.set(outputCol, value) + @Since("2.1.0") override protected[ml] val hashFunction: Vector => Array[Vector] = { elems: Vector => { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index ed9a39d8d1512..9b823259b1deb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -48,6 +48,14 @@ class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest ParamsSuite.checkParams(model) } + test("setters") { + val model = new BucketedRandomProjectionLSHModel("brp", Array(Vectors.dense(0.0, 1.0))) + .setInputCol("testkeys") + .setOutputCol("testvalues") + assert(model.getInputCol === "testkeys") + assert(model.getOutputCol === "testvalues") + } + test("BucketedRandomProjectionLSH: default params") { val brp = new BucketedRandomProjectionLSH assert(brp.getNumHashTables === 1.0) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 96df68dbdf053..3da0fb7da01ae 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -43,6 +43,14 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa ParamsSuite.checkParams(model) } + test("setters") { + val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0))) + .setInputCol("testkeys") + .setOutputCol("testvalues") + assert(model.getInputCol === "testkeys") + assert(model.getOutputCol === "testvalues") + } + test("MinHashLSH: default params") { val rp = new MinHashLSH assert(rp.getNumHashTables === 1.0) From c7622befdadfea725797d76e820e3dfc76fec927 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 11 Apr 2018 19:42:09 +0800 Subject: [PATCH 598/774] [SPARK-23847][FOLLOWUP][PYTHON][SQL] Actually test [desc|acs]_nulls_[first|last] functions in PySpark ## What changes were proposed in this pull request? There was a mistake in `tests.py` missing `assertEquals`. ## How was this patch tested? Fixed tests. Author: hyukjinkwon Closes #21035 from HyukjinKwon/SPARK-23847. --- python/pyspark/sql/tests.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index dd04ffb4ed393..96c2a776a5049 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2991,19 +2991,23 @@ def test_create_dateframe_from_pandas_with_dst(self): os.environ['TZ'] = orig_env_tz time.tzset() - def test_2_4_functions(self): + def test_sort_with_nulls_order(self): from pyspark.sql import functions df = self.spark.createDataFrame( [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"]) - df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect() - [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')] - df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect() - [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)] - df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect() - [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')] - df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect() - [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)] + self.assertEquals( + df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(), + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')]) + self.assertEquals( + df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(), + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)]) + self.assertEquals( + df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(), + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')]) + self.assertEquals( + df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(), + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]) class HiveSparkSubmitTests(SparkSubmitTests): From 87611bba222a95158fc5b638a566bdf47346da8e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 11 Apr 2018 19:44:01 +0800 Subject: [PATCH 599/774] [MINOR][DOCS] Fix R documentation generation instruction for roxygen2 ## What changes were proposed in this pull request? This PR proposes to fix `roxygen2` to `5.0.1` in `docs/README.md` for SparkR documentation generation. If I use higher version and creates the doc, it shows the diff below. Not a big deal but it bothered me. ```diff diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 855eb5bf77f..159fca61e06 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION -57,6 +57,6 Collate: 'types.R' 'utils.R' 'window.R' -RoxygenNote: 5.0.1 +RoxygenNote: 6.0.1 VignetteBuilder: knitr NeedsCompilation: no ``` ## How was this patch tested? Manually tested. I met this every time I set the new environment for Spark dev but I have kept forgetting to fix it. Author: hyukjinkwon Closes #21020 from HyukjinKwon/minor-r-doc. --- docs/README.md | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/docs/README.md b/docs/README.md index 9eac4ba35c458..dbea4d64c4298 100644 --- a/docs/README.md +++ b/docs/README.md @@ -22,10 +22,13 @@ $ sudo gem install jekyll jekyll-redirect-from pygments.rb $ sudo pip install Pygments # Following is needed only for generating API docs $ sudo pip install sphinx pypandoc mkdocs -$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "roxygen2", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' +$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' +$ sudo Rscript -e 'devtools::install_version("roxygen2", version = "5.0.1", repos="http://cran.stat.ucla.edu/")' ``` -(Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0) +Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0. + +Note: Other versions of roxygen2 might work in SparkR documentation generation but `RoxygenNote` field in `$SPARK_HOME/R/pkg/DESCRIPTION` is 5.0.1, which is updated if the version is mismatched. ## Generating the Documentation HTML @@ -62,12 +65,12 @@ $ PRODUCTION=1 jekyll build ## API Docs (Scaladoc, Javadoc, Sphinx, roxygen2, MkDocs) -You can build just the Spark scaladoc and javadoc by running `build/sbt unidoc` from the `SPARK_HOME` directory. +You can build just the Spark scaladoc and javadoc by running `build/sbt unidoc` from the `$SPARK_HOME` directory. Similarly, you can build just the PySpark docs by running `make html` from the -`SPARK_HOME/python/docs` directory. Documentation is only generated for classes that are listed as -public in `__init__.py`. The SparkR docs can be built by running `SPARK_HOME/R/create-docs.sh`, and -the SQL docs can be built by running `SPARK_HOME/sql/create-docs.sh` +`$SPARK_HOME/python/docs` directory. Documentation is only generated for classes that are listed as +public in `__init__.py`. The SparkR docs can be built by running `$SPARK_HOME/R/create-docs.sh`, and +the SQL docs can be built by running `$SPARK_HOME/sql/create-docs.sh` after [building Spark](https://github.com/apache/spark#building-spark) first. When you run `jekyll build` in the `docs` directory, it will also copy over the scaladoc and javadoc for the various From c604d659e19c1b2704cdf8c8ea97edaf50d8cb6b Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 11 Apr 2018 20:11:03 +0800 Subject: [PATCH 600/774] [SPARK-23951][SQL] Use actual java class instead of string representation. ## What changes were proposed in this pull request? This PR slightly refactors the newly added `ExprValue` API by quite a bit. The following changes are introduced: 1. `ExprValue` now uses the actual class instead of the class name as its type. This should give some more flexibility with generating code in the future. 2. Renamed `StatementValue` to `SimpleExprValue`. The statement concept is broader then an expression (untyped and it cannot be on the right hand side of an assignment), and this was not really what we were using it for. I have added a top level `JavaCode` trait that can be used in the future to reinstate (no pun intended) a statement a-like code fragment. 3. Added factory methods to the `JavaCode` companion object to make it slightly less verbose to create `JavaCode`/`ExprValue` objects. This is also what makes the diff quite large. 4. Added one more factory method to `ExprCode` to make it easier to create code-less expressions. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #21026 from hvanhovell/SPARK-23951. --- .../sql/catalyst/expressions/Expression.scala | 10 +- .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 35 +++- .../expressions/codegen/ExprValue.scala | 76 -------- .../codegen/GenerateMutableProjection.scala | 36 ++-- .../codegen/GenerateSafeProjection.scala | 25 +-- .../codegen/GenerateUnsafeProjection.scala | 11 +- .../expressions/codegen/javaCode.scala | 166 ++++++++++++++++++ .../expressions/complexTypeCreator.scala | 2 +- .../expressions/conditionalExpressions.scala | 5 +- .../expressions/datetimeExpressions.scala | 3 +- .../sql/catalyst/expressions/literals.scala | 27 +-- .../spark/sql/catalyst/expressions/misc.scala | 2 +- .../expressions/nullExpressions.scala | 20 +-- .../expressions/objects/objects.scala | 7 +- .../expressions/CodeGenerationSuite.scala | 5 +- .../expressions/codegen/ExprValueSuite.scala | 14 +- .../sql/execution/ColumnarBatchScan.scala | 6 +- .../spark/sql/execution/ExpandExec.scala | 8 +- .../spark/sql/execution/GenerateExec.scala | 15 +- .../sql/execution/WholeStageCodegenExec.scala | 11 +- .../aggregate/HashAggregateExec.scala | 6 +- .../aggregate/HashMapGenerator.scala | 8 +- .../execution/basicPhysicalOperators.scala | 2 +- .../joins/BroadcastHashJoinExec.scala | 9 +- .../execution/joins/SortMergeJoinExec.scala | 12 +- 26 files changed, 315 insertions(+), 212 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 7a5e49cb5206b..97dff6ae88299 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -104,9 +104,9 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val eval = doGenCode(ctx, ExprCode("", - VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(dataType)))) + val eval = doGenCode(ctx, ExprCode( + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, dataType))) reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -123,7 +123,7 @@ abstract class Expression extends TreeNode[Expression] { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull - eval.isNull = GlobalValue(globalIsNull, CodeGenerator.JAVA_BOOLEAN) + eval.isNull = JavaCode.isNullGlobal(globalIsNull) s"$globalIsNull = $localIsNull;" } else { "" @@ -142,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] { |} """.stripMargin) - eval.value = VariableValue(newValue, javaType) + eval.value = JavaCode.variable(newValue, dataType) eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index defd6f3cd8849..9212c3de1f814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -591,8 +591,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), - CodeGenerator.JAVA_BOOLEAN) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -671,8 +670,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), - CodeGenerator.JAVA_BOOLEAN) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) val evals = evalChildren.map(eval => s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c9c60ef1be640..0abfc9fa4c465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -59,10 +59,12 @@ import org.apache.spark.util.{ParentClassLoader, Utils} case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) object ExprCode { + def apply(isNull: ExprValue, value: ExprValue): ExprCode = { + ExprCode(code = "", isNull, value) + } + def forNullValue(dataType: DataType): ExprCode = { - val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true) - ExprCode(code = "", isNull = TrueLiteral, - value = LiteralValue(defaultValueLiteral, CodeGenerator.javaType(dataType))) + ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) } def forNonNullValue(value: ExprValue): ExprCode = { @@ -331,7 +333,7 @@ class CodegenContext { case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } - ExprCode(code, FalseLiteral, GlobalValue(value, javaType(dataType))) + ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) } def declareMutableStates(): String = { @@ -1004,8 +1006,9 @@ class CodegenContext { // at least two nodes) as the cost of doing it is expected to be low. subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(GlobalValue(isNull, JAVA_BOOLEAN), - GlobalValue(value, javaType(expr.dataType))) + val state = SubExprEliminationState( + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, expr.dataType)) subExprEliminationExprs ++= e.map(_ -> state).toMap } } @@ -1479,6 +1482,26 @@ object CodeGenerator extends Logging { case _ => "Object" } + def javaClass(dt: DataType): Class[_] = dt match { + case BooleanType => java.lang.Boolean.TYPE + case ByteType => java.lang.Byte.TYPE + case ShortType => java.lang.Short.TYPE + case IntegerType | DateType => java.lang.Integer.TYPE + case LongType | TimestampType => java.lang.Long.TYPE + case FloatType => java.lang.Float.TYPE + case DoubleType => java.lang.Double.TYPE + case _: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayData] + case _: MapType => classOf[MapData] + case udt: UserDefinedType[_] => javaClass(udt.sqlType) + case ObjectType(cls) => cls + case _ => classOf[Object] + } + /** * Returns the boxed type in Java. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala deleted file mode 100644 index df5f1c58b1b2d..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.codegen - -import scala.language.implicitConversions - -import org.apache.spark.sql.types.DataType - -// An abstraction that represents the evaluation result of [[ExprCode]]. -abstract class ExprValue { - - val javaType: String - - // Whether we can directly access the evaluation value anywhere. - // For example, a variable created outside a method can not be accessed inside the method. - // For such cases, we may need to pass the evaluation as parameter. - val canDirectAccess: Boolean - - def isPrimitive: Boolean = CodeGenerator.isPrimitiveType(javaType) -} - -object ExprValue { - implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString -} - -// A literal evaluation of [[ExprCode]]. -class LiteralValue(val value: String, val javaType: String) extends ExprValue { - override def toString: String = value - override val canDirectAccess: Boolean = true -} - -object LiteralValue { - def apply(value: String, javaType: String): LiteralValue = new LiteralValue(value, javaType) - def unapply(literal: LiteralValue): Option[(String, String)] = - Some((literal.value, literal.javaType)) -} - -// A variable evaluation of [[ExprCode]]. -case class VariableValue( - val variableName: String, - val javaType: String) extends ExprValue { - override def toString: String = variableName - override val canDirectAccess: Boolean = false -} - -// A statement evaluation of [[ExprCode]]. -case class StatementValue( - val statement: String, - val javaType: String, - val canDirectAccess: Boolean = false) extends ExprValue { - override def toString: String = statement -} - -// A global variable evaluation of [[ExprCode]]. -case class GlobalValue(val value: String, val javaType: String) extends ExprValue { - override def toString: String = value - override val canDirectAccess: Boolean = true -} - -case object TrueLiteral extends LiteralValue("true", "boolean") -case object FalseLiteral extends LiteralValue("false", "boolean") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 3ae0b54c754cf..33d14329ec95c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -52,43 +52,45 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP expressions: Seq[Expression], useSubexprElimination: Boolean): MutableProjection = { val ctx = newCodeGenContext() - val (validExpr, index) = expressions.zipWithIndex.filter { + val validExpr = expressions.zipWithIndex.filter { case (NoOp, _) => false case _ => true - }.unzip - val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) + } + val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) - val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map { - case (ev, i) => - val e = expressions(i) - val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value") - if (e.nullable) { + val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map { + case ((e, i), ev) => + val value = JavaCode.global( + ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value"), + e.dataType) + val (code, isNull) = if (e.nullable) { val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull") (s""" |${ev.code} |$isNull = ${ev.isNull}; |$value = ${ev.value}; - """.stripMargin, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), value, i) + """.stripMargin, JavaCode.isNullGlobal(isNull)) } else { (s""" |${ev.code} |$value = ${ev.value}; - """.stripMargin, ev.isNull, value, i) + """.stripMargin, FalseLiteral) } + val update = CodeGenerator.updateColumn( + "mutableRow", + e.dataType, + i, + ExprCode(isNull, value), + e.nullable) + (code, update) } // Evaluate all the subexpressions. val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val updates = validExpr.zip(projectionCodes).map { - case (e, (_, isNull, value, i)) => - val ev = ExprCode("", isNull, GlobalValue(value, CodeGenerator.javaType(e.dataType))) - CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) - } - val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) - val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates) + val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2)) val codeBody = s""" public java.lang.Object generate(Object[] references) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index a30a0b22cd305..01c350e9dbf69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions.codegen import scala.annotation.tailrec +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ /** @@ -53,9 +54,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, - StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), - CodeGenerator.javaType(dt)), dt) + val converter = convertToSafe( + ctx, + JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt), + dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -76,7 +78,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] |final InternalRow $output = new $rowClass($values); """.stripMargin - ExprCode(code, FalseLiteral, VariableValue(output, "InternalRow")) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[InternalRow])) } private def createCodeForArray( @@ -91,9 +93,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val index = ctx.freshName("index") val arrayClass = classOf[GenericArrayData].getName - val elementConverter = convertToSafe(ctx, - StatementValue(CodeGenerator.getValue(tmpInput, elementType, index), - CodeGenerator.javaType(elementType)), elementType) + val elementConverter = convertToSafe( + ctx, + JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType), + elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -107,7 +110,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - ExprCode(code, FalseLiteral, VariableValue(output, "ArrayData")) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[ArrayData])) } private def createCodeForMap( @@ -128,7 +131,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - ExprCode(code, FalseLiteral, VariableValue(output, "MapData")) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[MapData])) } @tailrec @@ -140,7 +143,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) - case _ => ExprCode("", FalseLiteral, input) + case _ => ExprCode(FalseLiteral, input) } protected def create(expressions: Seq[Expression]): Projection = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 2fb441ac4500e..01b4d6c4529bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,9 +52,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", CodeGenerator.JAVA_BOOLEAN), - StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), - CodeGenerator.javaType(dt))) + ExprCode( + JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"), + JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -109,7 +109,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) - if (input.isNull == "false") { + if (input.isNull == FalseLiteral) { s""" |${input.code} |${writeField.trim} @@ -292,8 +292,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro |$writeExpressions """.stripMargin // `rowWriter` is declared as a class field, so we can access it directly in methods. - ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow", - canDirectAccess = true)) + ExprCode(code, FalseLiteral, JavaCode.expression(s"$rowWriter.getRow()", classOf[UnsafeRow])) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala new file mode 100644 index 0000000000000..74ff018488863 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import java.lang.{Boolean => JBool} + +import scala.language.{existentials, implicitConversions} + +import org.apache.spark.sql.types.{BooleanType, DataType} + +/** + * Trait representing an opaque fragments of java code. + */ +trait JavaCode { + def code: String + override def toString: String = code +} + +/** + * Utility functions for creating [[JavaCode]] fragments. + */ +object JavaCode { + /** + * Create a java literal. + */ + def literal(v: String, dataType: DataType): LiteralValue = dataType match { + case BooleanType if v == "true" => TrueLiteral + case BooleanType if v == "false" => FalseLiteral + case _ => new LiteralValue(v, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a default literal. This is null for reference types, false for boolean types and + * -1 for other primitive types. + */ + def defaultLiteral(dataType: DataType): LiteralValue = { + new LiteralValue( + CodeGenerator.defaultValue(dataType, typedNull = true), + CodeGenerator.javaClass(dataType)) + } + + /** + * Create a local java variable. + */ + def variable(name: String, dataType: DataType): VariableValue = { + variable(name, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a local java variable. + */ + def variable(name: String, javaClass: Class[_]): VariableValue = { + VariableValue(name, javaClass) + } + + /** + * Create a local isNull variable. + */ + def isNullVariable(name: String): VariableValue = variable(name, BooleanType) + + /** + * Create a global java variable. + */ + def global(name: String, dataType: DataType): GlobalValue = { + global(name, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a global java variable. + */ + def global(name: String, javaClass: Class[_]): GlobalValue = { + GlobalValue(name, javaClass) + } + + /** + * Create a global isNull variable. + */ + def isNullGlobal(name: String): GlobalValue = global(name, BooleanType) + + /** + * Create an expression fragment. + */ + def expression(code: String, dataType: DataType): SimpleExprValue = { + expression(code, CodeGenerator.javaClass(dataType)) + } + + /** + * Create an expression fragment. + */ + def expression(code: String, javaClass: Class[_]): SimpleExprValue = { + SimpleExprValue(code, javaClass) + } + + /** + * Create a isNull expression fragment. + */ + def isNullExpression(code: String): SimpleExprValue = { + expression(code, BooleanType) + } +} + +/** + * A typed java fragment that must be a valid java expression. + */ +trait ExprValue extends JavaCode { + def javaType: Class[_] + def isPrimitive: Boolean = javaType.isPrimitive +} + +object ExprValue { + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString +} + + +/** + * A java expression fragment. + */ +case class SimpleExprValue(expr: String, javaType: Class[_]) extends ExprValue { + override def code: String = s"($expr)" +} + +/** + * A local variable java expression. + */ +case class VariableValue(variableName: String, javaType: Class[_]) extends ExprValue { + override def code: String = variableName +} + +/** + * A global variable java expression. + */ +case class GlobalValue(value: String, javaType: Class[_]) extends ExprValue { + override def code: String = value +} + +/** + * A literal java expression. + */ +class LiteralValue(val value: String, val javaType: Class[_]) extends ExprValue with Serializable { + override def code: String = value + + override def equals(arg: Any): Boolean = arg match { + case l: LiteralValue => l.javaType == javaType && l.value == value + case _ => false + } + + override def hashCode(): Int = value.hashCode() * 31 + javaType.hashCode() +} + +case object TrueLiteral extends LiteralValue("true", JBool.TYPE) +case object FalseLiteral extends LiteralValue("false", JBool.TYPE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 49a8d12057188..67876a8565488 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -64,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( code = preprocess + assigns + postprocess, - value = VariableValue(arrayData, CodeGenerator.javaType(dataType)), + value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 409c0b6b79b81..205d77f6a9acf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -191,8 +191,9 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = GlobalValue(ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value), - CodeGenerator.javaType(dataType)) + ev.value = JavaCode.global( + ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value), + dataType) // these blocks are meant to be inside a // do { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 49dd988b4b53c..32fdb13afbbfa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -813,8 +813,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ val df = classOf[DateFormat].getName if (format.foldable) { if (formatter == null) { - ExprCode("", TrueLiteral, LiteralValue("(UTF8String) null", - CodeGenerator.javaType(dataType))) + ExprCode.forNullValue(StringType) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 742a650eb445d..246025b82d59e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -281,38 +281,41 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { if (value == null) { ExprCode.forNullValue(dataType) } else { + def toExprCode(code: String): ExprCode = { + ExprCode.forNonNullValue(JavaCode.literal(code, dataType)) + } dataType match { case BooleanType | IntegerType | DateType => - ExprCode.forNonNullValue(LiteralValue(value.toString, javaType)) + toExprCode(value.toString) case FloatType => value.asInstanceOf[Float] match { case v if v.isNaN => - ExprCode.forNonNullValue(LiteralValue("Float.NaN", javaType)) + toExprCode("Float.NaN") case Float.PositiveInfinity => - ExprCode.forNonNullValue(LiteralValue("Float.POSITIVE_INFINITY", javaType)) + toExprCode("Float.POSITIVE_INFINITY") case Float.NegativeInfinity => - ExprCode.forNonNullValue(LiteralValue("Float.NEGATIVE_INFINITY", javaType)) + toExprCode("Float.NEGATIVE_INFINITY") case _ => - ExprCode.forNonNullValue(LiteralValue(s"${value}F", javaType)) + toExprCode(s"${value}F") } case DoubleType => value.asInstanceOf[Double] match { case v if v.isNaN => - ExprCode.forNonNullValue(LiteralValue("Double.NaN", javaType)) + toExprCode("Double.NaN") case Double.PositiveInfinity => - ExprCode.forNonNullValue(LiteralValue("Double.POSITIVE_INFINITY", javaType)) + toExprCode("Double.POSITIVE_INFINITY") case Double.NegativeInfinity => - ExprCode.forNonNullValue(LiteralValue("Double.NEGATIVE_INFINITY", javaType)) + toExprCode("Double.NEGATIVE_INFINITY") case _ => - ExprCode.forNonNullValue(LiteralValue(s"${value}D", javaType)) + toExprCode(s"${value}D") } case ByteType | ShortType => - ExprCode.forNonNullValue(LiteralValue(s"($javaType)$value", javaType)) + ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType)) case TimestampType | LongType => - ExprCode.forNonNullValue(LiteralValue(s"${value}L", javaType)) + toExprCode(s"${value}L") case _ => val constRef = ctx.addReferenceObj("literal", value, javaType) - ExprCode.forNonNullValue(GlobalValue(constRef, javaType)) + ExprCode.forNonNullValue(JavaCode.global(constRef, dataType)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 7081a5e096d56..7eda65a867028 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -92,7 +92,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = TrueLiteral, - value = LiteralValue("null", CodeGenerator.javaType(dataType))) + value = JavaCode.defaultLiteral(dataType)) } override def sql: String = s"assert_true(${child.sql})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 55b6e346be82a..0787342bce6bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), - CodeGenerator.JAVA_BOOLEAN) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -321,12 +320,7 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val value = if (eval.isNull.isInstanceOf[LiteralValue]) { - LiteralValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) - } else { - VariableValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) - } - ExprCode(code = eval.code, isNull = FalseLiteral, value = value) + ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.isNull) } override def sql: String = s"(${child.sql} IS NULL)" @@ -352,12 +346,10 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val value = if (eval.isNull == TrueLiteral) { - FalseLiteral - } else if (eval.isNull == FalseLiteral) { - TrueLiteral - } else { - StatementValue(s"(!(${eval.isNull}))", CodeGenerator.javaType(dataType)) + val value = eval.isNull match { + case TrueLiteral => FalseLiteral + case FalseLiteral => TrueLiteral + case v => JavaCode.isNullExpression(s"!$v") } ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index b2cca3178cd2a..50e90ca550807 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -65,7 +65,7 @@ trait InvokeLike extends Expression with NonSQLExpression { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") - GlobalValue(resultIsNull, CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullGlobal(resultIsNull) } else { FalseLiteral } @@ -569,12 +569,11 @@ case class LambdaVariable( override def genCode(ctx: CodegenContext): ExprCode = { val isNullValue = if (nullable) { - VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullVariable(isNull) } else { FalseLiteral } - ExprCode(code = "", value = VariableValue(value, CodeGenerator.javaType(dataType)), - isNull = isNullValue) + ExprCode(value = JavaCode.variable(value, dataType), isNull = isNullValue) } // This won't be called as `genCode` is overrided, just overriding it to make diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 8e83b35c3809c..f7c023111ff59 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -448,8 +448,9 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val ref = BoundReference(0, IntegerType, true) val add1 = Add(ref, ref) val add2 = Add(add1, add1) - val dummy = SubExprEliminationState(VariableValue("dummy", "boolean"), - VariableValue("dummy", "boolean")) + val dummy = SubExprEliminationState( + JavaCode.variable("dummy", BooleanType), + JavaCode.variable("dummy", BooleanType)) // raw testing of basic functionality { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala index c8f4cff7db48d..378b8bc055e34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.BooleanType class ExprValueSuite extends SparkFunSuite { @@ -31,16 +32,7 @@ class ExprValueSuite extends SparkFunSuite { assert(trueLit.isPrimitive) assert(falseLit.isPrimitive) - trueLit match { - case LiteralValue(value, javaType) => - assert(value == "true" && javaType == "boolean") - case _ => fail() - } - - falseLit match { - case LiteralValue(value, javaType) => - assert(value == "false" && javaType == "boolean") - case _ => fail() - } + assert(trueLit === JavaCode.literal("true", BooleanType)) + assert(falseLit === JavaCode.literal("false", BooleanType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 434214a10e1e3..fc3dbc1c5591b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -52,7 +52,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val javaType = CodeGenerator.javaType(dataType) val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { - VariableValue(ctx.freshName("isNull"), CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullVariable(ctx.freshName("isNull")) } else { FalseLiteral } @@ -66,7 +66,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } else { s"$javaType $valueVar = $value;" }).trim - ExprCode(code, isNullVar, VariableValue(valueVar, javaType)) + ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 0d9a62cace62a..e4812f3d338fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -157,8 +157,10 @@ case class ExpandExec( |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin - ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(firstExpr.dataType))) + ExprCode( + code, + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, firstExpr.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 85c5ebfdaa689..f40c50df74ccb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.types._ /** * For lazy computing, be sure the generator.terminate() called in the very last @@ -170,10 +170,11 @@ case class GenerateExec( // Add position val position = if (e.position) { if (outer) { - Seq(ExprCode("", StatementValue(s"$index == -1", CodeGenerator.JAVA_BOOLEAN), - VariableValue(index, CodeGenerator.JAVA_INT))) + Seq(ExprCode( + JavaCode.isNullExpression(s"$index == -1"), + JavaCode.variable(index, IntegerType))) } else { - Seq(ExprCode("", FalseLiteral, VariableValue(index, CodeGenerator.JAVA_INT))) + Seq(ExprCode(FalseLiteral, JavaCode.variable(index, IntegerType))) } } else { Seq.empty @@ -316,11 +317,9 @@ case class GenerateExec( |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin - ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, javaType)) + ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt)) } else { - ExprCode(s"$javaType $value = $getter;", FalseLiteral, - VariableValue(value, javaType)) + ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 805ff3cf001ba..828b51fa199de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -111,7 +111,7 @@ trait CodegenSupport extends SparkPlan { private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { if (row != null) { - ExprCode("", FalseLiteral, VariableValue(row, "UnsafeRow")) + ExprCode.forNonNullValue(JavaCode.variable(row, classOf[UnsafeRow])) } else { if (colVars.nonEmpty) { val colExprs = output.zipWithIndex.map { case (attr, i) => @@ -128,8 +128,8 @@ trait CodegenSupport extends SparkPlan { """.stripMargin.trim ExprCode(code, FalseLiteral, ev.value) } else { - // There is no columns - ExprCode("", FalseLiteral, VariableValue("unsafeRow", "UnsafeRow")) + // There are no columns + ExprCode.forNonNullValue(JavaCode.variable("unsafeRow", classOf[UnsafeRow])) } } } @@ -246,11 +246,10 @@ trait CodegenSupport extends SparkPlan { val isNull = ctx.freshName(s"exprIsNull_$i") arguments += ev.isNull parameters += s"boolean $isNull" - VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullVariable(isNull) } - paramVars += ExprCode("", paramIsNull, - VariableValue(paramName, CodeGenerator.javaType(attributes(i).dataType))) + paramVars += ExprCode(paramIsNull, JavaCode.variable(paramName, attributes(i).dataType)) } (arguments, parameters, paramVars) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8f7f10243d4cc..a5dc6ebf2b0f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -194,8 +194,10 @@ case class HashAggregateExec( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), - GlobalValue(value, CodeGenerator.javaType(e.dataType))) + ExprCode( + ev.code + initVars, + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, e.dataType)) } val initBufVar = evaluateVariables(bufVars) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 4978954271311..de2d630de3fdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GlobalValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ /** @@ -54,8 +54,10 @@ abstract class HashMapGenerator( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), - GlobalValue(value, CodeGenerator.javaType(e.dataType))) + ExprCode( + ev.code + initVars, + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, e.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index cab7081400ce9..1edfdc888afd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -368,7 +368,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") val value = ctx.freshName("value") - val ev = ExprCode("", FalseLiteral, VariableValue(value, CodeGenerator.JAVA_LONG)) + val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType)) val BigInt = classOf[java.math.BigInteger].getName // Inline mutable state since not many Range operations in a task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index fa62a32d51f3e..6fa716d9fadee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.{BooleanType, LongType} import org.apache.spark.util.TaskCompletionListener /** @@ -192,8 +192,7 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(a.dataType))) + ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)) } } } @@ -488,8 +487,8 @@ case class BroadcastHashJoinExec( s"$existsVar = true;" } - val resultVar = input ++ Seq(ExprCode("", FalseLiteral, - VariableValue(existsVar, CodeGenerator.JAVA_BOOLEAN))) + val resultVar = input ++ Seq(ExprCode.forNonNullValue( + JavaCode.variable(existsVar, BooleanType))) if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index b61acb8d4fda9..d8261f0f33b61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,11 +22,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, -ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.collection.BitSet @@ -531,13 +530,12 @@ case class SortMergeJoinExec( |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin - (ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) + (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)), + leftVarsDecl) } else { val code = s"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" - (ExprCode(code, FalseLiteral, - VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) + (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) } }.unzip } From 271c891b91917d660d1f6b995de397c47c7a6058 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Wed, 11 Apr 2018 21:52:48 +0800 Subject: [PATCH 601/774] [SPARK-23960][SQL][MINOR] Mark HashAggregateExec.bufVars as transient ## What changes were proposed in this pull request? Mark `HashAggregateExec.bufVars` as transient to avoid it from being serialized. Also manually null out this field at the end of `doProduceWithoutKeys()` to shorten its lifecycle, because it'll no longer be used after that. ## How was this patch tested? Existing tests. Author: Kris Mok Closes #21039 from rednaxelafx/codegen-improve. --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index a5dc6ebf2b0f2..965950ed94fe8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -174,8 +174,8 @@ case class HashAggregateExec( } } - // The variables used as aggregation buffer. Only used for aggregation without keys. - private var bufVars: Seq[ExprCode] = _ + // The variables used as aggregation buffer. Only used in codegen for aggregation without keys. + @transient private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") @@ -238,6 +238,8 @@ case class HashAggregateExec( | } """.stripMargin) + bufVars = null // explicitly null this field out to allow the referent to be GC'd sooner + val numOutput = metricTerm(ctx, "numOutputRows") val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") From 653fe02415a537299e15f92b56045569864b6183 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 11 Apr 2018 09:49:25 -0500 Subject: [PATCH 602/774] [SPARK-6951][CORE] Speed up parsing of event logs during listing. This change introduces two optimizations to help speed up generation of listing data when parsing events logs. The first one allows the parser to be stopped when enough data to create the listing entry has been read. This is currently the start event plus environment info, to capture UI ACLs. If the end event is needed, the code will skip to the end of the log to try to find that information, instead of parsing the whole log file. Unfortunately this works better with uncompressed logs. Skipping bytes on compressed logs only saves the work of parsing lines and some events, so not a lot of gains are observed. The second optimization deals with in-progress logs. It works in two ways: first, it completely avoids parsing the rest of the log for these apps when enough listing data is read. This, unlike the above, also speeds things up for compressed logs, since only the very beginning of the log has to be read. On top of that, the code that decides whether to re-parse logs to get updated listing data will ignore in-progress applications until they've completed. Both optimizations can be disabled but are enabled by default. I tested this on some fake event logs to see the effect. I created 500 logs of about 60M each (so ~30G uncompressed; each log was 1.7M when compressed with zstd). Below, C = completed, IP = in-progress, the size means the amount of data re-parsed at the end of logs when necessary. ``` none/C none/IP zstd/C zstd/IP On / 16k 2s 2s 22s 2s On / 1m 3s 2s 24s 2s Off 1.1m 1.1m 26s 24s ``` This was with 4 threads on a single local SSD. As expected from the previous explanations, there are considerable gains for in-progress logs, and for uncompressed logs, but not so much when looking at the full compressed log. As a side note, I removed the custom code to get the scan time by creating a file on HDFS; since file mod times are not used to detect changed logs anymore, local time is enough for the current use of the SHS. Author: Marcelo Vanzin Closes #20952 from vanzin/SPARK-6951. --- .../deploy/history/FsHistoryProvider.scala | 251 ++++++++++++------ .../apache/spark/deploy/history/config.scala | 15 ++ .../spark/scheduler/ReplayListenerBus.scala | 11 + .../org/apache/spark/util/ListenerBus.scala | 5 +- .../history/FsHistoryProviderSuite.scala | 78 ++++-- 5 files changed, 264 insertions(+), 96 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index ace6d9e00c838..56db9359e033f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,12 +18,13 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} -import java.util.{Date, ServiceLoader, UUID} +import java.util.{Date, ServiceLoader} import java.util.concurrent.{ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.io.Source import scala.util.Try import scala.xml.Node @@ -58,10 +59,10 @@ import org.apache.spark.util.kvstore._ * * == How new and updated attempts are detected == * - * - New attempts are detected in [[checkForLogs]]: the log dir is scanned, and any - * entries in the log dir whose modification time is greater than the last scan time - * are considered new or updated. These are replayed to create a new attempt info entry - * and update or create a matching application info element in the list of applications. + * - New attempts are detected in [[checkForLogs]]: the log dir is scanned, and any entries in the + * log dir whose size changed since the last scan time are considered new or updated. These are + * replayed to create a new attempt info entry and update or create a matching application info + * element in the list of applications. * - Updated attempts are also found in [[checkForLogs]] -- if the attempt's log file has grown, the * attempt is replaced by another one with a larger log size. * @@ -125,6 +126,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) private val storePath = conf.get(LOCAL_STORE_DIR).map(new File(_)) + private val fastInProgressParsing = conf.get(FAST_IN_PROGRESS_PARSING) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => @@ -402,13 +404,13 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) */ private[history] def checkForLogs(): Unit = { try { - val newLastScanTime = getNewLastScanTime() + val newLastScanTime = clock.getTimeMillis() logDebug(s"Scanning $logDir with lastScanTime==$lastScanTime") val updated = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) .filter { entry => !entry.isDirectory() && - // FsHistoryProvider generates a hidden file which can't be read. Accidentally + // FsHistoryProvider used to generate a hidden file which can't be read. Accidentally // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. !entry.getPath().getName().startsWith(".") && @@ -417,15 +419,24 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) .filter { entry => try { val info = listing.read(classOf[LogInfo], entry.getPath().toString()) - if (info.fileSize < entry.getLen()) { - // Log size has changed, it should be parsed. - true - } else { + + if (info.appId.isDefined) { // If the SHS view has a valid application, update the time the file was last seen so - // that the entry is not deleted from the SHS listing. - if (info.appId.isDefined) { - listing.write(info.copy(lastProcessed = newLastScanTime)) + // that the entry is not deleted from the SHS listing. Also update the file size, in + // case the code below decides we don't need to parse the log. + listing.write(info.copy(lastProcessed = newLastScanTime, fileSize = entry.getLen())) + } + + if (info.fileSize < entry.getLen()) { + if (info.appId.isDefined && fastInProgressParsing) { + // When fast in-progress parsing is on, we don't need to re-parse when the + // size changes, but we do need to invalidate any existing UIs. + invalidateUI(info.appId.get, info.attemptId) + false + } else { + true } + } else { false } } catch { @@ -449,7 +460,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val tasks = updated.map { entry => try { replayExecutor.submit(new Runnable { - override def run(): Unit = mergeApplicationListing(entry, newLastScanTime) + override def run(): Unit = mergeApplicationListing(entry, newLastScanTime, true) }) } catch { // let the iteration over the updated entries break, since an exception on @@ -542,25 +553,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - private[history] def getNewLastScanTime(): Long = { - val fileName = "." + UUID.randomUUID().toString - val path = new Path(logDir, fileName) - val fos = fs.create(path) - - try { - fos.close() - fs.getFileStatus(path).getModificationTime - } catch { - case e: Exception => - logError("Exception encountered when attempting to update last scan time", e) - lastScanTime.get() - } finally { - if (!fs.delete(path, true)) { - logWarning(s"Error deleting ${path}") - } - } - } - override def writeEventLogs( appId: String, attemptId: Option[String], @@ -607,7 +599,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the given log file, saving the application in the listing db. */ - protected def mergeApplicationListing(fileStatus: FileStatus, scanTime: Long): Unit = { + protected def mergeApplicationListing( + fileStatus: FileStatus, + scanTime: Long, + enableOptimizations: Boolean): Unit = { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || eventString.startsWith(APPL_END_EVENT_PREFIX) || @@ -616,32 +611,118 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } val logPath = fileStatus.getPath() + val appCompleted = isCompleted(logPath.getName()) + val reparseChunkSize = conf.get(END_EVENT_REPARSE_CHUNK_SIZE) + + // Enable halt support in listener if: + // - app in progress && fast parsing enabled + // - skipping to end event is enabled (regardless of in-progress state) + val shouldHalt = enableOptimizations && + ((!appCompleted && fastInProgressParsing) || reparseChunkSize > 0) + val bus = new ReplayListenerBus() - val listener = new AppListingListener(fileStatus, clock) + val listener = new AppListingListener(fileStatus, clock, shouldHalt) bus.addListener(listener) - replay(fileStatus, bus, eventsFilter = eventsFilter) - - val (appId, attemptId) = listener.applicationInfo match { - case Some(app) => - // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a - // discussion on the UI lifecycle. - synchronized { - activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => - ui.invalidate() - ui.ui.store.close() + + logInfo(s"Parsing $logPath for listing data...") + Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => + bus.replay(in, logPath.toString, !appCompleted, eventsFilter) + } + + // If enabled above, the listing listener will halt parsing when there's enough information to + // create a listing entry. When the app is completed, or fast parsing is disabled, we still need + // to replay until the end of the log file to try to find the app end event. Instead of reading + // and parsing line by line, this code skips bytes from the underlying stream so that it is + // positioned somewhere close to the end of the log file. + // + // Because the application end event is written while some Spark subsystems such as the + // scheduler are still active, there is no guarantee that the end event will be the last + // in the log. So, to be safe, the code uses a configurable chunk to be re-parsed at + // the end of the file, and retries parsing the whole log later if the needed data is + // still not found. + // + // Note that skipping bytes in compressed files is still not cheap, but there are still some + // minor gains over the normal log parsing done by the replay bus. + // + // This code re-opens the file so that it knows where it's skipping to. This isn't as cheap as + // just skipping from the current position, but there isn't a a good way to detect what the + // current position is, since the replay listener bus buffers data internally. + val lookForEndEvent = shouldHalt && (appCompleted || !fastInProgressParsing) + if (lookForEndEvent && listener.applicationInfo.isDefined) { + Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => + val target = fileStatus.getLen() - reparseChunkSize + if (target > 0) { + logInfo(s"Looking for end event; skipping $target bytes from $logPath...") + var skipped = 0L + while (skipped < target) { + skipped += in.skip(target - skipped) } } + val source = Source.fromInputStream(in).getLines() + + // Because skipping may leave the stream in the middle of a line, read the next line + // before replaying. + if (target > 0) { + source.next() + } + + bus.replay(source, logPath.toString, !appCompleted, eventsFilter) + } + } + + logInfo(s"Finished parsing $logPath") + + listener.applicationInfo match { + case Some(app) if !lookForEndEvent || app.attempts.head.info.completed => + // In this case, we either didn't care about the end event, or we found it. So the + // listing data is good. + invalidateUI(app.info.id, app.attempts.head.info.attemptId) addListing(app) - (Some(app.info.id), app.attempts.head.info.attemptId) + listing.write(LogInfo(logPath.toString(), scanTime, Some(app.info.id), + app.attempts.head.info.attemptId, fileStatus.getLen())) + + // For a finished log, remove the corresponding "in progress" entry from the listing DB if + // the file is really gone. + if (appCompleted) { + val inProgressLog = logPath.toString() + EventLoggingListener.IN_PROGRESS + try { + // Fetch the entry first to avoid an RPC when it's already removed. + listing.read(classOf[LogInfo], inProgressLog) + if (!fs.isFile(new Path(inProgressLog))) { + listing.delete(classOf[LogInfo], inProgressLog) + } + } catch { + case _: NoSuchElementException => + } + } + + case Some(_) => + // In this case, the attempt is still not marked as finished but was expected to. This can + // mean the end event is before the configured threshold, so call the method again to + // re-parse the whole log. + logInfo(s"Reparsing $logPath since end event was not found.") + mergeApplicationListing(fileStatus, scanTime, false) case _ => // If the app hasn't written down its app ID to the logs, still record the entry in the // listing db, with an empty ID. This will make the log eligible for deletion if the app // does not make progress after the configured max log age. - (None, None) + listing.write(LogInfo(logPath.toString(), scanTime, None, None, fileStatus.getLen())) + } + } + + /** + * Invalidate an existing UI for a given app attempt. See LoadedAppUI for a discussion on the + * UI lifecycle. + */ + private def invalidateUI(appId: String, attemptId: Option[String]): Unit = { + synchronized { + activeUIs.get((appId, attemptId)).foreach { ui => + ui.invalidate() + ui.ui.store.close() + } } - listing.write(LogInfo(logPath.toString(), scanTime, appId, attemptId, fileStatus.getLen())) } /** @@ -696,29 +777,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - /** - * Replays the events in the specified log file on the supplied `ReplayListenerBus`. - * `ReplayEventsFilter` determines what events are replayed. - */ - private def replay( - eventLog: FileStatus, - bus: ReplayListenerBus, - eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = { - val logPath = eventLog.getPath() - val isCompleted = !logPath.getName().endsWith(EventLoggingListener.IN_PROGRESS) - logInfo(s"Replaying log path: $logPath") - // Note that the eventLog may have *increased* in size since when we grabbed the filestatus, - // and when we read the file here. That is OK -- it may result in an unnecessary refresh - // when there is no update, but will not result in missing an update. We *must* prevent - // an error the other way -- if we report a size bigger (ie later) than the file that is - // actually read, we may never refresh the app. FileStatus is guaranteed to be static - // after it's created, so we get a file size that is no bigger than what is actually read. - Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => - bus.replay(in, logPath.toString, !isCompleted, eventsFilter) - logInfo(s"Finished parsing $logPath") - } - } - /** * Rebuilds the application state store from its event log. */ @@ -741,8 +799,13 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } replayBus.addListener(listener) try { - replay(eventLog, replayBus) + val path = eventLog.getPath() + logInfo(s"Parsing $path to re-build UI...") + Utils.tryWithResource(EventLoggingListener.openEventLog(path, fs)) { in => + replayBus.replay(in, path.toString(), maybeTruncated = !isCompleted(path.toString())) + } trackingStore.close(false) + logInfo(s"Finished parsing $path") } catch { case e: Exception => Utils.tryLogNonFatalError { @@ -881,6 +944,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + private def isCompleted(name: String): Boolean = { + !name.endsWith(EventLoggingListener.IN_PROGRESS) + } + } private[history] object FsHistoryProvider { @@ -945,11 +1012,17 @@ private[history] class ApplicationInfoWrapper( } -private[history] class AppListingListener(log: FileStatus, clock: Clock) extends SparkListener { +private[history] class AppListingListener( + log: FileStatus, + clock: Clock, + haltEnabled: Boolean) extends SparkListener { private val app = new MutableApplicationInfo() private val attempt = new MutableAttemptInfo(log.getPath().getName(), log.getLen()) + private var gotEnvUpdate = false + private var halted = false + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { app.id = event.appId.orNull app.name = event.appName @@ -958,6 +1031,8 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends attempt.startTime = new Date(event.time) attempt.lastUpdated = new Date(clock.getTimeMillis()) attempt.sparkUser = event.sparkUser + + checkProgress() } override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { @@ -968,11 +1043,18 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends } override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { - val allProperties = event.environmentDetails("Spark Properties").toMap - attempt.viewAcls = allProperties.get("spark.ui.view.acls") - attempt.adminAcls = allProperties.get("spark.admin.acls") - attempt.viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") - attempt.adminAclsGroups = allProperties.get("spark.admin.acls.groups") + // Only parse the first env update, since any future changes don't have any effect on + // the ACLs set for the UI. + if (!gotEnvUpdate) { + val allProperties = event.environmentDetails("Spark Properties").toMap + attempt.viewAcls = allProperties.get("spark.ui.view.acls") + attempt.adminAcls = allProperties.get("spark.admin.acls") + attempt.viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") + attempt.adminAclsGroups = allProperties.get("spark.admin.acls.groups") + + gotEnvUpdate = true + checkProgress() + } } override def onOtherEvent(event: SparkListenerEvent): Unit = event match { @@ -989,6 +1071,17 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends } } + /** + * Throws a halt exception to stop replay if enough data to create the app listing has been + * read. + */ + private def checkProgress(): Unit = { + if (haltEnabled && !halted && app.id != null && gotEnvUpdate) { + halted = true + throw new HaltReplayException() + } + } + private class MutableApplicationInfo { var id: String = null var name: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala index efdbf672bb52f..25ba9edb9e014 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -49,4 +49,19 @@ private[spark] object config { .intConf .createWithDefault(18080) + val FAST_IN_PROGRESS_PARSING = + ConfigBuilder("spark.history.fs.inProgressOptimization.enabled") + .doc("Enable optimized handling of in-progress logs. This option may leave finished " + + "applications that fail to rename their event logs listed as in-progress.") + .booleanConf + .createWithDefault(true) + + val END_EVENT_REPARSE_CHUNK_SIZE = + ConfigBuilder("spark.history.fs.endEventReparseChunkSize") + .doc("How many bytes to parse at the end of log files looking for the end event. " + + "This is used to speed up generation of application listings by skipping unnecessary " + + "parts of event log files. It can be disabled by setting this config to 0.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1m") + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index c9cd662f5709d..226c23733c870 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -115,6 +115,8 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { } } } catch { + case e: HaltReplayException => + // Just stop replay. case _: EOFException if maybeTruncated => case ioe: IOException => throw ioe @@ -124,8 +126,17 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { } } + override protected def isIgnorableException(e: Throwable): Boolean = { + e.isInstanceOf[HaltReplayException] + } + } +/** + * Exception that can be thrown by listeners to halt replay. This is handled by ReplayListenerBus + * only, and will cause errors if thrown when using other bus implementations. + */ +private[spark] class HaltReplayException extends RuntimeException private[spark] object ReplayListenerBus { diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index 76a56298aaebc..b25a731401f23 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -81,7 +81,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { try { doPostEvent(listener, event) } catch { - case NonFatal(e) => + case NonFatal(e) if !isIgnorableException(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) } finally { if (maybeTimerContext != null) { @@ -97,6 +97,9 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { */ protected def doPostEvent(listener: L, event: E): Unit + /** Allows bus implementations to prevent error logging for certain exceptions. */ + protected def isIgnorableException(e: Throwable): Boolean = false + private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { val c = implicitly[ClassTag[T]].runtimeClass listeners.asScala.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 0ba57bf4563c1..77b239489d489 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any -import org.mockito.Mockito.{doReturn, mock, spy, verify} +import org.mockito.Mockito.{mock, spy, verify} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ @@ -151,8 +151,9 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc var mergeApplicationListingCall = 0 override protected def mergeApplicationListing( fileStatus: FileStatus, - lastSeen: Long): Unit = { - super.mergeApplicationListing(fileStatus, lastSeen) + lastSeen: Long, + enableSkipToEnd: Boolean): Unit = { + super.mergeApplicationListing(fileStatus, lastSeen, enableSkipToEnd) mergeApplicationListingCall += 1 } } @@ -256,14 +257,13 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc ) updateAndCheck(provider) { list => - list should not be (null) list.size should be (1) list.head.attempts.size should be (3) list.head.attempts.head.attemptId should be (Some("attempt3")) } val app2Attempt1 = newLogFile("app2", Some("attempt1"), inProgress = false) - writeFile(attempt1, true, None, + writeFile(app2Attempt1, true, None, SparkListenerApplicationStart("app2", Some("app2"), 5L, "test", Some("attempt1")), SparkListenerApplicationEnd(6L) ) @@ -649,8 +649,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Add more info to the app log, and trigger the provider to update things. writeFile(appLog, true, None, SparkListenerApplicationStart(appId, Some(appId), 1L, "test", None), - SparkListenerJobStart(0, 1L, Nil, null), - SparkListenerApplicationEnd(5L) + SparkListenerJobStart(0, 1L, Nil, null) ) provider.checkForLogs() @@ -668,11 +667,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc test("clean up stale app information") { val storeDir = Utils.createTempDir() val conf = createTestConf().set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) - val provider = spy(new FsHistoryProvider(conf)) + val clock = new ManualClock() + val provider = spy(new FsHistoryProvider(conf, clock)) val appId = "new1" // Write logs for two app attempts. - doReturn(1L).when(provider).getNewLastScanTime() + clock.advance(1) val attempt1 = newLogFile(appId, Some("1"), inProgress = false) writeFile(attempt1, true, None, SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("1")), @@ -697,7 +697,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Delete the underlying log file for attempt 1 and rescan. The UI should go away, but since // attempt 2 still exists, listing data should be there. - doReturn(2L).when(provider).getNewLastScanTime() + clock.advance(1) attempt1.delete() updateAndCheck(provider) { list => assert(list.size === 1) @@ -708,7 +708,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(provider.getAppUI(appId, None) === None) // Delete the second attempt's log file. Now everything should go away. - doReturn(3L).when(provider).getNewLastScanTime() + clock.advance(1) attempt2.delete() updateAndCheck(provider) { list => assert(list.isEmpty) @@ -718,9 +718,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc test("SPARK-21571: clean up removes invalid history files") { val clock = new ManualClock() val conf = createTestConf().set(MAX_LOG_AGE_S.key, s"2d") - val provider = new FsHistoryProvider(conf, clock) { - override def getNewLastScanTime(): Long = clock.getTimeMillis() - } + val provider = new FsHistoryProvider(conf, clock) // Create 0-byte size inprogress and complete files var logCount = 0 @@ -772,6 +770,54 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(new File(testDir.toURI).listFiles().size === validLogCount) } + test("always find end event for finished apps") { + // Create a log file where the end event is before the configure chunk to be reparsed at + // the end of the file. The correct listing should still be generated. + val log = newLogFile("end-event-test", None, inProgress = false) + writeFile(log, true, None, + Seq( + SparkListenerApplicationStart("end-event-test", Some("end-event-test"), 1L, "test", None), + SparkListenerEnvironmentUpdate(Map( + "Spark Properties" -> Seq.empty, + "JVM Information" -> Seq.empty, + "System Properties" -> Seq.empty, + "Classpath Entries" -> Seq.empty + )), + SparkListenerApplicationEnd(5L) + ) ++ (1 to 1000).map { i => SparkListenerJobStart(i, i, Nil) }: _*) + + val conf = createTestConf().set(END_EVENT_REPARSE_CHUNK_SIZE.key, s"1k") + val provider = new FsHistoryProvider(conf) + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).attempts.size === 1) + assert(list(0).attempts(0).completed) + } + } + + test("parse event logs with optimizations off") { + val conf = createTestConf() + .set(END_EVENT_REPARSE_CHUNK_SIZE, 0L) + .set(FAST_IN_PROGRESS_PARSING, false) + val provider = new FsHistoryProvider(conf) + + val complete = newLogFile("complete", None, inProgress = false) + writeFile(complete, true, None, + SparkListenerApplicationStart("complete", Some("complete"), 1L, "test", None), + SparkListenerApplicationEnd(5L) + ) + + val incomplete = newLogFile("incomplete", None, inProgress = true) + writeFile(incomplete, true, None, + SparkListenerApplicationStart("incomplete", Some("incomplete"), 1L, "test", None) + ) + + updateAndCheck(provider) { list => + list.size should be (2) + list.count(_.attempts.head.completed) should be (1) + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: @@ -815,7 +861,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc private def createTestConf(inMemory: Boolean = false): SparkConf = { val conf = new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + .set(EVENT_LOG_DIR, testDir.getAbsolutePath()) + .set(FAST_IN_PROGRESS_PARSING, true) if (!inMemory) { conf.set(LOCAL_STORE_DIR, Utils.createTempDir().getAbsolutePath()) @@ -848,4 +895,3 @@ class TestGroupsMappingProvider extends GroupMappingServiceProvider { mappings.get(username).map(Set(_)).getOrElse(Set.empty) } } - From 3cb82047f2f51af553df09b9323796af507d36f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 11 Apr 2018 10:13:44 -0500 Subject: [PATCH 603/774] [SPARK-22941][CORE] Do not exit JVM when submit fails with in-process launcher. The current in-process launcher implementation just calls the SparkSubmit object, which, in case of errors, will more often than not exit the JVM. This is not desirable since this launcher is meant to be used inside other applications, and that would kill the application. The change turns SparkSubmit into a class, and abstracts aways some of the functionality used to print error messages and abort the submission process. The default implementation uses the logging system for messages, and throws exceptions for errors. As part of that I also moved some code that doesn't really belong in SparkSubmit to a better location. The command line invocation of spark-submit now uses a special implementation of the SparkSubmit class that overrides those behaviors to do what is expected from the command line version (print to the terminal, exit the JVM, etc). A lot of the changes are to replace calls to methods such as "printErrorAndExit" with the new API. As part of adding tests for this, I had to fix some small things in the launcher option parser so that things like "--version" can work when used in the launcher library. There is still code that prints directly to the terminal, like all the Ivy-related code in SparkSubmitUtils, and other areas where some re-factoring would help, like the CommandLineUtils class, but I chose to leave those alone to keep this change more focused. Aside from existing and added unit tests, I ran command line tools with a bunch of different arguments to make sure messages and errors behave like before. Author: Marcelo Vanzin Closes #20925 from vanzin/SPARK-22941. --- .../apache/spark/deploy/DependencyUtils.scala | 30 +- .../org/apache/spark/deploy/SparkSubmit.scala | 318 +++++++++--------- .../spark/deploy/SparkSubmitArguments.scala | 90 +++-- .../spark/deploy/worker/DriverWrapper.scala | 4 +- .../apache/spark/util/CommandLineUtils.scala | 18 +- .../spark/launcher/SparkLauncherSuite.java | 37 +- .../spark/deploy/SparkSubmitSuite.scala | 69 ++-- .../rest/StandaloneRestSubmitSuite.scala | 2 +- .../spark/launcher/AbstractLauncher.java | 6 +- .../spark/launcher/InProcessLauncher.java | 14 +- .../launcher/SparkSubmitCommandBuilder.java | 82 +++-- project/MimaExcludes.scala | 7 +- .../deploy/mesos/MesosClusterDispatcher.scala | 10 +- .../MesosClusterDispatcherArguments.scala | 6 +- 14 files changed, 401 insertions(+), 292 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index fac834a70b893..178bdcfccb603 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -25,9 +25,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.{SecurityManager, SparkConf, SparkException} +import org.apache.spark.internal.Logging import org.apache.spark.util.{MutableURLClassLoader, Utils} -private[deploy] object DependencyUtils { +private[deploy] object DependencyUtils extends Logging { def resolveMavenDependencies( packagesExclusions: String, @@ -75,7 +76,7 @@ private[deploy] object DependencyUtils { def addJarsToClassPath(jars: String, loader: MutableURLClassLoader): Unit = { if (jars != null) { for (jar <- jars.split(",")) { - SparkSubmit.addJarToClasspath(jar, loader) + addJarToClasspath(jar, loader) } } } @@ -151,6 +152,31 @@ private[deploy] object DependencyUtils { }.mkString(",") } + def addJarToClasspath(localJar: String, loader: MutableURLClassLoader): Unit = { + val uri = Utils.resolveURI(localJar) + uri.getScheme match { + case "file" | "local" => + val file = new File(uri.getPath) + if (file.exists()) { + loader.addURL(file.toURI.toURL) + } else { + logWarning(s"Local jar $file does not exist, skipping.") + } + case _ => + logWarning(s"Skip remote jar $uri.") + } + } + + /** + * Merge a sequence of comma-separated file lists, some of which may be null to indicate + * no files, into a single comma-separated string. + */ + def mergeFileLists(lists: String*): String = { + val merged = lists.filterNot(StringUtils.isBlank) + .flatMap(Utils.stringToSeq) + if (merged.nonEmpty) merged.mkString(",") else null + } + private def splitOnFragment(path: String): (URI, Option[String]) = { val uri = Utils.resolveURI(path) val withoutFragment = new URI(uri.getScheme, uri.getSchemeSpecificPart, null) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index eddbedeb1024d..427c797755b84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -58,7 +58,7 @@ import org.apache.spark.util._ */ private[deploy] object SparkSubmitAction extends Enumeration { type SparkSubmitAction = Value - val SUBMIT, KILL, REQUEST_STATUS = Value + val SUBMIT, KILL, REQUEST_STATUS, PRINT_VERSION = Value } /** @@ -67,78 +67,32 @@ private[deploy] object SparkSubmitAction extends Enumeration { * This program handles setting up the classpath with relevant Spark dependencies and provides * a layer over the different cluster managers and deploy modes that Spark supports. */ -object SparkSubmit extends CommandLineUtils with Logging { +private[spark] class SparkSubmit extends Logging { import DependencyUtils._ + import SparkSubmit._ - // Cluster managers - private val YARN = 1 - private val STANDALONE = 2 - private val MESOS = 4 - private val LOCAL = 8 - private val KUBERNETES = 16 - private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | KUBERNETES - - // Deploy modes - private val CLIENT = 1 - private val CLUSTER = 2 - private val ALL_DEPLOY_MODES = CLIENT | CLUSTER - - // Special primary resource names that represent shells rather than application jars. - private val SPARK_SHELL = "spark-shell" - private val PYSPARK_SHELL = "pyspark-shell" - private val SPARKR_SHELL = "sparkr-shell" - private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" - private val R_PACKAGE_ARCHIVE = "rpkg.zip" - - private val CLASS_NOT_FOUND_EXIT_STATUS = 101 - - // Following constants are visible for testing. - private[deploy] val YARN_CLUSTER_SUBMIT_CLASS = - "org.apache.spark.deploy.yarn.YarnClusterApplication" - private[deploy] val REST_CLUSTER_SUBMIT_CLASS = classOf[RestSubmissionClientApp].getName() - private[deploy] val STANDALONE_CLUSTER_SUBMIT_CLASS = classOf[ClientApp].getName() - private[deploy] val KUBERNETES_CLUSTER_SUBMIT_CLASS = - "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication" - - // scalastyle:off println - private[spark] def printVersionAndExit(): Unit = { - printStream.println("""Welcome to - ____ __ - / __/__ ___ _____/ /__ - _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version %s - /_/ - """.format(SPARK_VERSION)) - printStream.println("Using Scala %s, %s, %s".format( - Properties.versionString, Properties.javaVmName, Properties.javaVersion)) - printStream.println("Branch %s".format(SPARK_BRANCH)) - printStream.println("Compiled by user %s on %s".format(SPARK_BUILD_USER, SPARK_BUILD_DATE)) - printStream.println("Revision %s".format(SPARK_REVISION)) - printStream.println("Url %s".format(SPARK_REPO_URL)) - printStream.println("Type --help for more information.") - exitFn(0) - } - // scalastyle:on println - - override def main(args: Array[String]): Unit = { + def doSubmit(args: Array[String]): Unit = { // Initialize logging if it hasn't been done yet. Keep track of whether logging needs to // be reset before the application starts. val uninitLog = initializeLogIfNecessary(true, silent = true) - val appArgs = new SparkSubmitArguments(args) + val appArgs = parseArguments(args) if (appArgs.verbose) { - // scalastyle:off println - printStream.println(appArgs) - // scalastyle:on println + logInfo(appArgs.toString) } appArgs.action match { case SparkSubmitAction.SUBMIT => submit(appArgs, uninitLog) case SparkSubmitAction.KILL => kill(appArgs) case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs) + case SparkSubmitAction.PRINT_VERSION => printVersion() } } + protected def parseArguments(args: Array[String]): SparkSubmitArguments = { + new SparkSubmitArguments(args) + } + /** * Kill an existing submission using the REST protocol. Standalone and Mesos cluster mode only. */ @@ -156,6 +110,24 @@ object SparkSubmit extends CommandLineUtils with Logging { .requestSubmissionStatus(args.submissionToRequestStatusFor) } + /** Print version information to the log. */ + private def printVersion(): Unit = { + logInfo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + logInfo("Using Scala %s, %s, %s".format( + Properties.versionString, Properties.javaVmName, Properties.javaVersion)) + logInfo(s"Branch $SPARK_BRANCH") + logInfo(s"Compiled by user $SPARK_BUILD_USER on $SPARK_BUILD_DATE") + logInfo(s"Revision $SPARK_REVISION") + logInfo(s"Url $SPARK_REPO_URL") + logInfo("Type --help for more information.") + } + /** * Submit the application using the provided parameters. * @@ -185,10 +157,7 @@ object SparkSubmit extends CommandLineUtils with Logging { // makes the message printed to the output by the JVM not very helpful. Instead, // detect exceptions with empty stack traces here, and treat them differently. if (e.getStackTrace().length == 0) { - // scalastyle:off println - printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") - // scalastyle:on println - exitFn(1) + error(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") } else { throw e } @@ -210,14 +179,11 @@ object SparkSubmit extends CommandLineUtils with Logging { // to use the legacy gateway if the master endpoint turns out to be not a REST server. if (args.isStandaloneCluster && args.useRest) { try { - // scalastyle:off println - printStream.println("Running Spark using the REST application submission protocol.") - // scalastyle:on println - doRunMain() + logInfo("Running Spark using the REST application submission protocol.") } catch { // Fail over to use the legacy submission gateway case e: SubmitRestConnectionException => - printWarning(s"Master endpoint ${args.master} was not a REST server. " + + logWarning(s"Master endpoint ${args.master} was not a REST server. " + "Falling back to legacy submission gateway instead.") args.useRest = false submit(args, false) @@ -245,19 +211,6 @@ object SparkSubmit extends CommandLineUtils with Logging { args: SparkSubmitArguments, conf: Option[HadoopConfiguration] = None) : (Seq[String], Seq[String], SparkConf, String) = { - try { - doPrepareSubmitEnvironment(args, conf) - } catch { - case e: SparkException => - printErrorAndExit(e.getMessage) - throw e - } - } - - private def doPrepareSubmitEnvironment( - args: SparkSubmitArguments, - conf: Option[HadoopConfiguration] = None) - : (Seq[String], Seq[String], SparkConf, String) = { // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() @@ -268,7 +221,7 @@ object SparkSubmit extends CommandLineUtils with Logging { val clusterManager: Int = args.master match { case "yarn" => YARN case "yarn-client" | "yarn-cluster" => - printWarning(s"Master ${args.master} is deprecated since 2.0." + + logWarning(s"Master ${args.master} is deprecated since 2.0." + " Please use master \"yarn\" with specified deploy mode instead.") YARN case m if m.startsWith("spark") => STANDALONE @@ -276,7 +229,7 @@ object SparkSubmit extends CommandLineUtils with Logging { case m if m.startsWith("k8s") => KUBERNETES case m if m.startsWith("local") => LOCAL case _ => - printErrorAndExit("Master must either be yarn or start with spark, mesos, k8s, or local") + error("Master must either be yarn or start with spark, mesos, k8s, or local") -1 } @@ -284,7 +237,9 @@ object SparkSubmit extends CommandLineUtils with Logging { var deployMode: Int = args.deployMode match { case "client" | null => CLIENT case "cluster" => CLUSTER - case _ => printErrorAndExit("Deploy mode must be either client or cluster"); -1 + case _ => + error("Deploy mode must be either client or cluster") + -1 } // Because the deprecated way of specifying "yarn-cluster" and "yarn-client" encapsulate both @@ -296,16 +251,16 @@ object SparkSubmit extends CommandLineUtils with Logging { deployMode = CLUSTER args.master = "yarn" case ("yarn-cluster", "client") => - printErrorAndExit("Client deploy mode is not compatible with master \"yarn-cluster\"") + error("Client deploy mode is not compatible with master \"yarn-cluster\"") case ("yarn-client", "cluster") => - printErrorAndExit("Cluster deploy mode is not compatible with master \"yarn-client\"") + error("Cluster deploy mode is not compatible with master \"yarn-client\"") case (_, mode) => args.master = "yarn" } // Make sure YARN is included in our build if we're trying to use it if (!Utils.classIsLoadable(YARN_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { - printErrorAndExit( + error( "Could not load YARN classes. " + "This copy of Spark may not have been compiled with YARN support.") } @@ -315,7 +270,7 @@ object SparkSubmit extends CommandLineUtils with Logging { args.master = Utils.checkAndGetK8sMasterUrl(args.master) // Make sure KUBERNETES is included in our build if we're trying to use it if (!Utils.classIsLoadable(KUBERNETES_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { - printErrorAndExit( + error( "Could not load KUBERNETES classes. " + "This copy of Spark may not have been compiled with KUBERNETES support.") } @@ -324,23 +279,23 @@ object SparkSubmit extends CommandLineUtils with Logging { // Fail fast, the following modes are not supported or applicable (clusterManager, deployMode) match { case (STANDALONE, CLUSTER) if args.isPython => - printErrorAndExit("Cluster deploy mode is currently not supported for python " + + error("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") case (STANDALONE, CLUSTER) if args.isR => - printErrorAndExit("Cluster deploy mode is currently not supported for R " + + error("Cluster deploy mode is currently not supported for R " + "applications on standalone clusters.") case (KUBERNETES, _) if args.isPython => - printErrorAndExit("Python applications are currently not supported for Kubernetes.") + error("Python applications are currently not supported for Kubernetes.") case (KUBERNETES, _) if args.isR => - printErrorAndExit("R applications are currently not supported for Kubernetes.") + error("R applications are currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => - printErrorAndExit("Cluster deploy mode is not compatible with master \"local\"") + error("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") + error("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark SQL shell.") + error("Cluster deploy mode is not applicable to Spark SQL shell.") case (_, CLUSTER) if isThriftServer(args.mainClass) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark Thrift server.") + error("Cluster deploy mode is not applicable to Spark Thrift server.") case _ => } @@ -493,11 +448,11 @@ object SparkSubmit extends CommandLineUtils with Logging { if (args.isR && clusterManager == YARN) { val sparkRPackagePath = RUtils.localSparkRPackagePath if (sparkRPackagePath.isEmpty) { - printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") + error("SPARK_HOME does not exist for R application in YARN mode.") } val sparkRPackageFile = new File(sparkRPackagePath.get, SPARKR_PACKAGE_ARCHIVE) if (!sparkRPackageFile.exists()) { - printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") + error(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") } val sparkRPackageURI = Utils.resolveURI(sparkRPackageFile.getAbsolutePath).toString @@ -510,7 +465,7 @@ object SparkSubmit extends CommandLineUtils with Logging { val rPackageFile = RPackageUtils.zipRLibraries(new File(RUtils.rPackages.get), R_PACKAGE_ARCHIVE) if (!rPackageFile.exists()) { - printErrorAndExit("Failed to zip all the built R packages.") + error("Failed to zip all the built R packages.") } val rPackageURI = Utils.resolveURI(rPackageFile.getAbsolutePath).toString @@ -521,12 +476,12 @@ object SparkSubmit extends CommandLineUtils with Logging { // TODO: Support distributing R packages with standalone cluster if (args.isR && clusterManager == STANDALONE && !RUtils.rPackages.isEmpty) { - printErrorAndExit("Distributing R packages with standalone cluster is not supported.") + error("Distributing R packages with standalone cluster is not supported.") } // TODO: Support distributing R packages with mesos cluster if (args.isR && clusterManager == MESOS && !RUtils.rPackages.isEmpty) { - printErrorAndExit("Distributing R packages with mesos cluster is not supported.") + error("Distributing R packages with mesos cluster is not supported.") } // If we're running an R app, set the main class to our specific R runner @@ -799,9 +754,7 @@ object SparkSubmit extends CommandLineUtils with Logging { private def setRMPrincipal(sparkConf: SparkConf): Unit = { val shortUserName = UserGroupInformation.getCurrentUser.getShortUserName val key = s"spark.hadoop.${YarnConfiguration.RM_PRINCIPAL}" - // scalastyle:off println - printStream.println(s"Setting ${key} to ${shortUserName}") - // scalastyle:off println + logInfo(s"Setting ${key} to ${shortUserName}") sparkConf.set(key, shortUserName) } @@ -817,16 +770,14 @@ object SparkSubmit extends CommandLineUtils with Logging { sparkConf: SparkConf, childMainClass: String, verbose: Boolean): Unit = { - // scalastyle:off println if (verbose) { - printStream.println(s"Main class:\n$childMainClass") - printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") + logInfo(s"Main class:\n$childMainClass") + logInfo(s"Arguments:\n${childArgs.mkString("\n")}") // sysProps may contain sensitive information, so redact before printing - printStream.println(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") - printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") - printStream.println("\n") + logInfo(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") + logInfo(s"Classpath elements:\n${childClasspath.mkString("\n")}") + logInfo("\n") } - // scalastyle:on println val loader = if (sparkConf.get(DRIVER_USER_CLASS_PATH_FIRST)) { @@ -848,23 +799,19 @@ object SparkSubmit extends CommandLineUtils with Logging { mainClass = Utils.classForName(childMainClass) } catch { case e: ClassNotFoundException => - e.printStackTrace(printStream) + logWarning(s"Failed to load $childMainClass.", e) if (childMainClass.contains("thriftserver")) { - // scalastyle:off println - printStream.println(s"Failed to load main class $childMainClass.") - printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") - // scalastyle:on println + logInfo(s"Failed to load main class $childMainClass.") + logInfo("You need to build Spark with -Phive and -Phive-thriftserver.") } - System.exit(CLASS_NOT_FOUND_EXIT_STATUS) + throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS) case e: NoClassDefFoundError => - e.printStackTrace(printStream) + logWarning(s"Failed to load $childMainClass: ${e.getMessage()}") if (e.getMessage.contains("org/apache/hadoop/hive")) { - // scalastyle:off println - printStream.println(s"Failed to load hive class.") - printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") - // scalastyle:on println + logInfo(s"Failed to load hive class.") + logInfo("You need to build Spark with -Phive and -Phive-thriftserver.") } - System.exit(CLASS_NOT_FOUND_EXIT_STATUS) + throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS) } val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) { @@ -872,7 +819,7 @@ object SparkSubmit extends CommandLineUtils with Logging { } else { // SPARK-4170 if (classOf[scala.App].isAssignableFrom(mainClass)) { - printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") + logWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") } new JavaMainApplication(mainClass) } @@ -891,29 +838,90 @@ object SparkSubmit extends CommandLineUtils with Logging { app.start(childArgs.toArray, sparkConf) } catch { case t: Throwable => - findCause(t) match { - case SparkUserAppException(exitCode) => - System.exit(exitCode) - - case t: Throwable => - throw t - } + throw findCause(t) } } - private[deploy] def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { - val uri = Utils.resolveURI(localJar) - uri.getScheme match { - case "file" | "local" => - val file = new File(uri.getPath) - if (file.exists()) { - loader.addURL(file.toURI.toURL) - } else { - printWarning(s"Local jar $file does not exist, skipping.") + /** Throw a SparkException with the given error message. */ + private def error(msg: String): Unit = throw new SparkException(msg) + +} + + +/** + * This entry point is used by the launcher library to start in-process Spark applications. + */ +private[spark] object InProcessSparkSubmit { + + def main(args: Array[String]): Unit = { + val submit = new SparkSubmit() + submit.doSubmit(args) + } + +} + +object SparkSubmit extends CommandLineUtils with Logging { + + // Cluster managers + private val YARN = 1 + private val STANDALONE = 2 + private val MESOS = 4 + private val LOCAL = 8 + private val KUBERNETES = 16 + private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | KUBERNETES + + // Deploy modes + private val CLIENT = 1 + private val CLUSTER = 2 + private val ALL_DEPLOY_MODES = CLIENT | CLUSTER + + // Special primary resource names that represent shells rather than application jars. + private val SPARK_SHELL = "spark-shell" + private val PYSPARK_SHELL = "pyspark-shell" + private val SPARKR_SHELL = "sparkr-shell" + private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" + private val R_PACKAGE_ARCHIVE = "rpkg.zip" + + private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + + // Following constants are visible for testing. + private[deploy] val YARN_CLUSTER_SUBMIT_CLASS = + "org.apache.spark.deploy.yarn.YarnClusterApplication" + private[deploy] val REST_CLUSTER_SUBMIT_CLASS = classOf[RestSubmissionClientApp].getName() + private[deploy] val STANDALONE_CLUSTER_SUBMIT_CLASS = classOf[ClientApp].getName() + private[deploy] val KUBERNETES_CLUSTER_SUBMIT_CLASS = + "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication" + + override def main(args: Array[String]): Unit = { + val submit = new SparkSubmit() { + self => + + override protected def parseArguments(args: Array[String]): SparkSubmitArguments = { + new SparkSubmitArguments(args) { + override protected def logInfo(msg: => String): Unit = self.logInfo(msg) + + override protected def logWarning(msg: => String): Unit = self.logWarning(msg) } - case _ => - printWarning(s"Skip remote jar $uri.") + } + + override protected def logInfo(msg: => String): Unit = printMessage(msg) + + override protected def logWarning(msg: => String): Unit = printMessage(s"Warning: $msg") + + override def doSubmit(args: Array[String]): Unit = { + try { + super.doSubmit(args) + } catch { + case e: SparkUserAppException => + exitFn(e.exitCode) + case e: SparkException => + printErrorAndExit(e.getMessage()) + } + } + } + + submit.doSubmit(args) } /** @@ -962,17 +970,6 @@ object SparkSubmit extends CommandLineUtils with Logging { res == SparkLauncher.NO_RESOURCE } - /** - * Merge a sequence of comma-separated file lists, some of which may be null to indicate - * no files, into a single comma-separated string. - */ - private[deploy] def mergeFileLists(lists: String*): String = { - val merged = lists.filterNot(StringUtils.isBlank) - .flatMap(_.split(",")) - .mkString(",") - if (merged == "") null else merged - } - } /** Provides utility functions to be used inside SparkSubmit. */ @@ -1000,12 +997,12 @@ private[spark] object SparkSubmitUtils { override def toString: String = s"$groupId:$artifactId:$version" } -/** - * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided - * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. - * @param coordinates Comma-delimited string of maven coordinates - * @return Sequence of Maven coordinates - */ + /** + * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided + * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. + * @param coordinates Comma-delimited string of maven coordinates + * @return Sequence of Maven coordinates + */ def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = { coordinates.split(",").map { p => val splits = p.replace("/", ":").split(":") @@ -1304,6 +1301,13 @@ private[spark] object SparkSubmitUtils { rule } + def parseSparkConfProperty(pair: String): (String, String) = { + pair.split("=", 2).toSeq match { + case Seq(k, v) => (k, v) + case _ => throw new SparkException(s"Spark config without '=': $pair") + } + } + } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 8e7070593687b..0733fdb72cafb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -29,7 +29,9 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.io.Source import scala.util.Try +import org.apache.spark.{SparkException, SparkUserAppException} import org.apache.spark.deploy.SparkSubmitAction._ +import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkSubmitArgumentsParser import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils @@ -40,7 +42,7 @@ import org.apache.spark.util.Utils * The env argument is used for testing. */ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) - extends SparkSubmitArgumentsParser { + extends SparkSubmitArgumentsParser with Logging { var master: String = null var deployMode: String = null var executorMemory: String = null @@ -85,8 +87,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S /** Default properties present in the currently defined defaults file. */ lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() - // scalastyle:off println - if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") + if (verbose) { + logInfo(s"Using properties file: $propertiesFile") + } Option(propertiesFile).foreach { filename => val properties = Utils.getPropertiesFromFile(filename) properties.foreach { case (k, v) => @@ -95,21 +98,16 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Property files may contain sensitive information, so redact before printing if (verbose) { Utils.redact(properties).foreach { case (k, v) => - SparkSubmit.printStream.println(s"Adding default property: $k=$v") + logInfo(s"Adding default property: $k=$v") } } } - // scalastyle:on println defaultProperties } // Set parameters from command line arguments - try { - parse(args.asJava) - } catch { - case e: IllegalArgumentException => - SparkSubmit.printErrorAndExit(e.getMessage()) - } + parse(args.asJava) + // Populate `sparkProperties` map from properties file mergeDefaultSparkProperties() // Remove keys that don't start with "spark." from `sparkProperties`. @@ -141,7 +139,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S sparkProperties.foreach { case (k, v) => if (!k.startsWith("spark.")) { sparkProperties -= k - SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v") + logWarning(s"Ignoring non-spark config property: $k=$v") } } } @@ -215,10 +213,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } } catch { case _: Exception => - SparkSubmit.printErrorAndExit(s"Cannot load main class from JAR $primaryResource") + error(s"Cannot load main class from JAR $primaryResource") } case _ => - SparkSubmit.printErrorAndExit( + error( s"Cannot load main class from JAR $primaryResource with URI $uriScheme. " + "Please specify a class through --class.") } @@ -248,6 +246,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case SUBMIT => validateSubmitArguments() case KILL => validateKillArguments() case REQUEST_STATUS => validateStatusRequestArguments() + case PRINT_VERSION => } } @@ -256,62 +255,61 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S printUsageAndExit(-1) } if (primaryResource == null) { - SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python or R file)") + error("Must specify a primary resource (JAR or Python or R file)") } if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) { - SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class") + error("No main class set in JAR; please specify one with --class") } if (driverMemory != null && Try(JavaUtils.byteStringAsBytes(driverMemory)).getOrElse(-1L) <= 0) { - SparkSubmit.printErrorAndExit("Driver Memory must be a positive number") + error("Driver memory must be a positive number") } if (executorMemory != null && Try(JavaUtils.byteStringAsBytes(executorMemory)).getOrElse(-1L) <= 0) { - SparkSubmit.printErrorAndExit("Executor Memory cores must be a positive number") + error("Executor memory must be a positive number") } if (executorCores != null && Try(executorCores.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Executor cores must be a positive number") + error("Executor cores must be a positive number") } if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Total executor cores must be a positive number") + error("Total executor cores must be a positive number") } if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Number of executors must be a positive number") + error("Number of executors must be a positive number") } if (pyFiles != null && !isPython) { - SparkSubmit.printErrorAndExit("--py-files given but primary resource is not a Python script") + error("--py-files given but primary resource is not a Python script") } if (master.startsWith("yarn")) { val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR") if (!hasHadoopEnv && !Utils.isTesting) { - throw new Exception(s"When running with master '$master' " + + error(s"When running with master '$master' " + "either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.") } } if (proxyUser != null && principal != null) { - SparkSubmit.printErrorAndExit("Only one of --proxy-user or --principal can be provided.") + error("Only one of --proxy-user or --principal can be provided.") } } private def validateKillArguments(): Unit = { if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { - SparkSubmit.printErrorAndExit( - "Killing submissions is only supported in standalone or Mesos mode!") + error("Killing submissions is only supported in standalone or Mesos mode!") } if (submissionToKill == null) { - SparkSubmit.printErrorAndExit("Please specify a submission to kill.") + error("Please specify a submission to kill.") } } private def validateStatusRequestArguments(): Unit = { if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { - SparkSubmit.printErrorAndExit( + error( "Requesting submission statuses is only supported in standalone or Mesos mode!") } if (submissionToRequestStatusFor == null) { - SparkSubmit.printErrorAndExit("Please specify a submission to request status for.") + error("Please specify a submission to request status for.") } } @@ -368,7 +366,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case DEPLOY_MODE => if (value != "client" && value != "cluster") { - SparkSubmit.printErrorAndExit("--deploy-mode must be either \"client\" or \"cluster\"") + error("--deploy-mode must be either \"client\" or \"cluster\"") } deployMode = value @@ -405,14 +403,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case KILL_SUBMISSION => submissionToKill = value if (action != null) { - SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $KILL.") + error(s"Action cannot be both $action and $KILL.") } action = KILL case STATUS => submissionToRequestStatusFor = value if (action != null) { - SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $REQUEST_STATUS.") + error(s"Action cannot be both $action and $REQUEST_STATUS.") } action = REQUEST_STATUS @@ -444,7 +442,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S repositories = value case CONF => - val (confName, confValue) = SparkSubmit.parseSparkConfProperty(value) + val (confName, confValue) = SparkSubmitUtils.parseSparkConfProperty(value) sparkProperties(confName) = confValue case PROXY_USER => @@ -463,15 +461,15 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S verbose = true case VERSION => - SparkSubmit.printVersionAndExit() + action = SparkSubmitAction.PRINT_VERSION case USAGE_ERROR => printUsageAndExit(1) case _ => - throw new IllegalArgumentException(s"Unexpected argument '$opt'.") + error(s"Unexpected argument '$opt'.") } - true + action != SparkSubmitAction.PRINT_VERSION } /** @@ -482,7 +480,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S */ override protected def handleUnknown(opt: String): Boolean = { if (opt.startsWith("-")) { - SparkSubmit.printErrorAndExit(s"Unrecognized option '$opt'.") + error(s"Unrecognized option '$opt'.") } primaryResource = @@ -501,20 +499,18 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { - // scalastyle:off println - val outStream = SparkSubmit.printStream if (unknownParam != null) { - outStream.println("Unknown/unsupported param " + unknownParam) + logInfo("Unknown/unsupported param " + unknownParam) } val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] |Usage: spark-submit --status [submission ID] --master [spark://...] |Usage: spark-submit run-example [options] example-class [example args]""".stripMargin) - outStream.println(command) + logInfo(command) val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB - outStream.println( + logInfo( s""" |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, @@ -596,12 +592,11 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S ) if (SparkSubmit.isSqlShell(mainClass)) { - outStream.println("CLI options:") - outStream.println(getSqlShellOptions()) + logInfo("CLI options:") + logInfo(getSqlShellOptions()) } - // scalastyle:on println - SparkSubmit.exitFn(exitCode) + throw new SparkUserAppException(exitCode) } /** @@ -655,4 +650,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S System.setErr(currentErr) } } + + private def error(msg: String): Unit = throw new SparkException(msg) + } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 3f71237164a15..8d6a2b80ef5f2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -25,7 +25,7 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{DependencyUtils, SparkHadoopUtil, SparkSubmit} import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEnv -import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.apache.spark.util._ /** * Utility object for launching driver programs such that they share fate with the Worker process. @@ -93,7 +93,7 @@ object DriverWrapper extends Logging { val jars = { val jarsProp = sys.props.get("spark.jars").orNull if (!StringUtils.isBlank(resolvedMavenCoordinates)) { - SparkSubmit.mergeFileLists(jarsProp, resolvedMavenCoordinates) + DependencyUtils.mergeFileLists(jarsProp, resolvedMavenCoordinates) } else { jarsProp } diff --git a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala index d73901686b705..4b6602b50aa1c 100644 --- a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala @@ -33,24 +33,14 @@ private[spark] trait CommandLineUtils { private[spark] var printStream: PrintStream = System.err // scalastyle:off println - - private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) + private[spark] def printMessage(str: String): Unit = printStream.println(str) + // scalastyle:on println private[spark] def printErrorAndExit(str: String): Unit = { - printStream.println("Error: " + str) - printStream.println("Run with --help for usage help or --verbose for debug output") + printMessage("Error: " + str) + printMessage("Run with --help for usage help or --verbose for debug output") exitFn(1) } - // scalastyle:on println - - private[spark] def parseSparkConfProperty(pair: String): (String, String) = { - pair.split("=", 2).toSeq match { - case Seq(k, v) => (k, v) - case _ => printErrorAndExit(s"Spark config without '=': $pair") - throw new SparkException(s"Spark config without '=': $pair") - } - } - def main(args: Array[String]): Unit } diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 2225591a4ff75..6a1a38c1a54f4 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -109,7 +109,7 @@ public void testChildProcLauncher() throws Exception { .addSparkArg(opts.CONF, String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, - "-Dfoo=bar -Dtest.appender=childproc") + "-Dfoo=bar -Dtest.appender=console") .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") .setMainClass(SparkLauncherTestApp.class.getName()) @@ -192,6 +192,41 @@ private void inProcessLauncherTestImpl() throws Exception { } } + @Test + public void testInProcessLauncherDoesNotKillJvm() throws Exception { + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); + List wrongArgs = Arrays.asList( + new String[] { "--unknown" }, + new String[] { opts.DEPLOY_MODE, "invalid" }); + + for (String[] args : wrongArgs) { + InProcessLauncher launcher = new InProcessLauncher() + .setAppResource(SparkLauncher.NO_RESOURCE); + switch (args.length) { + case 2: + launcher.addSparkArg(args[0], args[1]); + break; + + case 1: + launcher.addSparkArg(args[0]); + break; + + default: + fail("FIXME: invalid test."); + } + + SparkAppHandle handle = launcher.startApplication(); + waitFor(handle); + assertEquals(SparkAppHandle.State.FAILED, handle.getState()); + } + + // Run --version, which is useless as a use case, but should succeed and not exit the JVM. + // The expected state is "LOST" since "--version" doesn't report state back to the handle. + SparkAppHandle handle = new InProcessLauncher().addSparkArg(opts.VERSION).startApplication(); + waitFor(handle); + assertEquals(SparkAppHandle.State.LOST, handle.getState()); + } + public static class SparkLauncherTestApp { public static void main(String[] args) throws Exception { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 0d7c342a5eacd..7451e07b25a1f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -42,6 +42,7 @@ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher import org.apache.spark.scheduler.EventLoggingListener import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} @@ -109,6 +110,8 @@ class SparkSubmitSuite private val emptyIvySettings = File.createTempFile("ivy", ".xml") FileUtils.write(emptyIvySettings, "", StandardCharsets.UTF_8) + private val submit = new SparkSubmit() + override def beforeEach() { super.beforeEach() } @@ -128,13 +131,16 @@ class SparkSubmitSuite } test("handle binary specified but not class") { - testPrematureExit(Array("foo.jar"), "No main class") + val jar = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + testPrematureExit(Array(jar.toString()), "No main class") } test("handles arguments with --key=val") { val clArgs = Seq( "--jars=one.jar,two.jar,three.jar", - "--name=myApp") + "--name=myApp", + "--class=org.FooBar", + SparkLauncher.NO_RESOURCE) val appArgs = new SparkSubmitArguments(clArgs) appArgs.jars should include regex (".*one.jar,.*two.jar,.*three.jar") appArgs.name should be ("myApp") @@ -182,7 +188,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) appArgs.deployMode should be ("client") conf.get("spark.submit.deployMode") should be ("client") @@ -192,11 +198,11 @@ class SparkSubmitSuite "--master", "yarn", "--deploy-mode", "cluster", "--conf", "spark.submit.deployMode=client", - "-class", "org.SomeClass", + "--class", "org.SomeClass", "thejar.jar" ) val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1) appArgs1.deployMode should be ("cluster") conf1.get("spark.submit.deployMode") should be ("cluster") @@ -210,7 +216,7 @@ class SparkSubmitSuite val appArgs2 = new SparkSubmitArguments(clArgs2) appArgs2.deployMode should be (null) - val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) appArgs2.deployMode should be ("client") conf2.get("spark.submit.deployMode") should be ("client") } @@ -233,7 +239,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") childArgsStr should include ("--arg arg1 --arg arg2") @@ -276,7 +282,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (4) @@ -322,7 +328,7 @@ class SparkSubmitSuite "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) appArgs.useRest = useRest - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") @@ -359,7 +365,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -381,7 +387,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -403,7 +409,7 @@ class SparkSubmitSuite "/home/thejar.jar", "arg1") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsMap = childArgs.grouped(2).map(a => a(0) -> a(1)).toMap childArgsMap.get("--primary-java-resource") should be (Some("file:/home/thejar.jar")) @@ -428,7 +434,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (_, _, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.executor.memory") should be ("5g") conf.get("spark.master") should be ("yarn") conf.get("spark.submit.deployMode") should be ("cluster") @@ -441,12 +447,12 @@ class SparkSubmitSuite val clArgs1 = Seq("--class", "org.apache.spark.repl.Main", "spark-shell") val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1) conf1.get(UI_SHOW_CONSOLE_PROGRESS) should be (true) val clArgs2 = Seq("--class", "org.SomeClass", "thejar.jar") val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) assert(!conf2.contains(UI_SHOW_CONSOLE_PROGRESS)) } @@ -625,7 +631,7 @@ class SparkSubmitSuite "--files", files, "thejar.jar") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) appArgs.jars should be (Utils.resolveURIs(jars)) appArgs.files should be (Utils.resolveURIs(files)) conf.get("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) @@ -640,7 +646,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) appArgs2.files should be (Utils.resolveURIs(files)) appArgs2.archives should fullyMatch regex ("file:/archive1,file:.*#archive3") conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) @@ -656,7 +662,7 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) + val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3) appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) conf3.get("spark.submit.pyFiles") should be ( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) @@ -708,7 +714,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) conf.get("spark.files") should be(Utils.resolveURIs(files)) @@ -725,7 +731,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) conf2.get("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) conf2.get("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) @@ -740,7 +746,7 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) + val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3) conf3.get("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) @@ -757,7 +763,7 @@ class SparkSubmitSuite "hdfs:///tmp/mister.py" ) val appArgs4 = new SparkSubmitArguments(clArgs4) - val (_, _, conf4, _) = SparkSubmit.prepareSubmitEnvironment(appArgs4) + val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4) // Should not format python path for yarn cluster mode conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) } @@ -778,17 +784,17 @@ class SparkSubmitSuite } test("SPARK_CONF_DIR overrides spark-defaults.conf") { - forConfDir(Map("spark.executor.memory" -> "2.3g")) { path => + forConfDir(Map("spark.executor.memory" -> "3g")) { path => val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", unusedJar.toString) - val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path)) + val appArgs = new SparkSubmitArguments(args, env = Map("SPARK_CONF_DIR" -> path)) assert(appArgs.propertiesFile != null) assert(appArgs.propertiesFile.startsWith(path)) - appArgs.executorMemory should be ("2.3g") + appArgs.executorMemory should be ("3g") } } @@ -809,6 +815,9 @@ class SparkSubmitSuite val archive1 = File.createTempFile("archive1", ".zip", tmpArchiveDir) val archive2 = File.createTempFile("archive2", ".zip", tmpArchiveDir) + val tempPyFile = File.createTempFile("tmpApp", ".py") + tempPyFile.deleteOnExit() + val args = Seq( "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), "--name", "testApp", @@ -818,10 +827,10 @@ class SparkSubmitSuite "--files", s"${tmpFileDir.getAbsolutePath}/tmpFile*", "--py-files", s"${tmpPyFileDir.getAbsolutePath}/tmpPy*", "--archives", s"${tmpArchiveDir.getAbsolutePath}/*.zip", - jar2.toString) + tempPyFile.toURI().toString()) val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.yarn.dist.jars").split(",").toSet should be (Set(jar1.toURI.toString, jar2.toURI.toString)) conf.get("spark.yarn.dist.files").split(",").toSet should be @@ -947,7 +956,7 @@ class SparkSubmitSuite ) val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) // All the resources should still be remote paths, so that YARN client will not upload again. conf.get("spark.yarn.dist.jars") should be (tmpJarPath) @@ -1007,7 +1016,7 @@ class SparkSubmitSuite ) ++ forceDownloadArgs ++ Seq(s"s3a://$mainResource") val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) val jars = conf.get("spark.yarn.dist.jars").split(",").toSet @@ -1058,7 +1067,7 @@ class SparkSubmitSuite "hello") val exception = intercept[SparkException] { - SparkSubmit.main(args) + submit.doSubmit(args) } assert(exception.getMessage() === "hello") diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index e505bc018857d..54c168a8218f3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -445,7 +445,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { "--class", mainClass, mainJar) ++ appArgs val args = new SparkSubmitArguments(commandLineArgs) - val (_, _, sparkConf, _) = SparkSubmit.prepareSubmitEnvironment(args) + val (_, _, sparkConf, _) = new SparkSubmit().prepareSubmitEnvironment(args) new RestSubmissionClient("spark://host:port").constructSubmitRequest( mainJar, mainClass, appArgs, sparkConf.getAll.toMap, Map.empty) } diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java index 44e69fc45dffa..4e02843480e8f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java @@ -139,7 +139,7 @@ public T setMainClass(String mainClass) { public T addSparkArg(String arg) { SparkSubmitOptionParser validator = new ArgumentValidator(false); validator.parse(Arrays.asList(arg)); - builder.sparkArgs.add(arg); + builder.userArgs.add(arg); return self(); } @@ -187,8 +187,8 @@ public T addSparkArg(String name, String value) { } } else { validator.parse(Arrays.asList(name, value)); - builder.sparkArgs.add(name); - builder.sparkArgs.add(value); + builder.userArgs.add(name); + builder.userArgs.add(value); } return self(); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java index 6d726b4a69a86..688e1f763c205 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java @@ -89,10 +89,18 @@ Method findSparkSubmit() throws IOException { } Class sparkSubmit; + // SPARK-22941: first try the new SparkSubmit interface that has better error handling, + // but fall back to the old interface in case someone is mixing & matching launcher and + // Spark versions. try { - sparkSubmit = cl.loadClass("org.apache.spark.deploy.SparkSubmit"); - } catch (Exception e) { - throw new IOException("Cannot find SparkSubmit; make sure necessary jars are available.", e); + sparkSubmit = cl.loadClass("org.apache.spark.deploy.InProcessSparkSubmit"); + } catch (Exception e1) { + try { + sparkSubmit = cl.loadClass("org.apache.spark.deploy.SparkSubmit"); + } catch (Exception e2) { + throw new IOException("Cannot find SparkSubmit; make sure necessary jars are available.", + e2); + } } Method main; diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index e0ef22d7d5058..5cb6457bf5c21 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -88,8 +88,9 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { SparkLauncher.NO_RESOURCE); } - final List sparkArgs; - private final boolean isAppResourceReq; + final List userArgs; + private final List parsedArgs; + private final boolean requiresAppResource; private final boolean isExample; /** @@ -99,17 +100,27 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { */ private boolean allowsMixedArguments; + /** + * This constructor is used when creating a user-configurable launcher. It allows the + * spark-submit argument list to be modified after creation. + */ SparkSubmitCommandBuilder() { - this.sparkArgs = new ArrayList<>(); - this.isAppResourceReq = true; + this.requiresAppResource = true; this.isExample = false; + this.parsedArgs = new ArrayList<>(); + this.userArgs = new ArrayList<>(); } + /** + * This constructor is used when invoking spark-submit; it parses and validates arguments + * provided by the user on the command line. + */ SparkSubmitCommandBuilder(List args) { this.allowsMixedArguments = false; - this.sparkArgs = new ArrayList<>(); + this.parsedArgs = new ArrayList<>(); boolean isExample = false; List submitArgs = args; + this.userArgs = Collections.emptyList(); if (args.size() > 0) { switch (args.get(0)) { @@ -131,21 +142,21 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } this.isExample = isExample; - OptionParser parser = new OptionParser(); + OptionParser parser = new OptionParser(true); parser.parse(submitArgs); - this.isAppResourceReq = parser.isAppResourceReq; - } else { + this.requiresAppResource = parser.requiresAppResource; + } else { this.isExample = isExample; - this.isAppResourceReq = false; + this.requiresAppResource = false; } } @Override public List buildCommand(Map env) throws IOException, IllegalArgumentException { - if (PYSPARK_SHELL.equals(appResource) && isAppResourceReq) { + if (PYSPARK_SHELL.equals(appResource) && requiresAppResource) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL.equals(appResource) && isAppResourceReq) { + } else if (SPARKR_SHELL.equals(appResource) && requiresAppResource) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -154,9 +165,19 @@ public List buildCommand(Map env) List buildSparkSubmitArgs() { List args = new ArrayList<>(); - SparkSubmitOptionParser parser = new SparkSubmitOptionParser(); + OptionParser parser = new OptionParser(false); + final boolean requiresAppResource; + + // If the user args array is not empty, we need to parse it to detect exactly what + // the user is trying to run, so that checks below are correct. + if (!userArgs.isEmpty()) { + parser.parse(userArgs); + requiresAppResource = parser.requiresAppResource; + } else { + requiresAppResource = this.requiresAppResource; + } - if (!allowsMixedArguments && isAppResourceReq) { + if (!allowsMixedArguments && requiresAppResource) { checkArgument(appResource != null, "Missing application resource."); } @@ -208,15 +229,16 @@ List buildSparkSubmitArgs() { args.add(join(",", pyFiles)); } - if (isAppResourceReq) { - checkArgument(!isExample || mainClass != null, "Missing example class name."); + if (isExample) { + checkArgument(mainClass != null, "Missing example class name."); } + if (mainClass != null) { args.add(parser.CLASS); args.add(mainClass); } - args.addAll(sparkArgs); + args.addAll(parsedArgs); if (appResource != null) { args.add(appResource); } @@ -399,7 +421,12 @@ private List findExamplesJars() { private class OptionParser extends SparkSubmitOptionParser { - boolean isAppResourceReq = true; + boolean requiresAppResource = true; + private final boolean errorOnUnknownArgs; + + OptionParser(boolean errorOnUnknownArgs) { + this.errorOnUnknownArgs = errorOnUnknownArgs; + } @Override protected boolean handle(String opt, String value) { @@ -443,23 +470,23 @@ protected boolean handle(String opt, String value) { break; case KILL_SUBMISSION: case STATUS: - isAppResourceReq = false; - sparkArgs.add(opt); - sparkArgs.add(value); + requiresAppResource = false; + parsedArgs.add(opt); + parsedArgs.add(value); break; case HELP: case USAGE_ERROR: - isAppResourceReq = false; - sparkArgs.add(opt); + requiresAppResource = false; + parsedArgs.add(opt); break; case VERSION: - isAppResourceReq = false; - sparkArgs.add(opt); + requiresAppResource = false; + parsedArgs.add(opt); break; default: - sparkArgs.add(opt); + parsedArgs.add(opt); if (value != null) { - sparkArgs.add(value); + parsedArgs.add(value); } break; } @@ -483,12 +510,13 @@ protected boolean handleUnknown(String opt) { mainClass = className; appResource = SparkLauncher.NO_RESOURCE; return false; - } else { + } else if (errorOnUnknownArgs) { checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt); checkState(appResource == null, "Found unrecognized argument but resource is already set."); appResource = opt; return false; } + return true; } @Override diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b37b4d51775e8..a87fa68422c34 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,12 +36,17 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher. + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printVersionAndExit"), + // [SPARK-23412][ML] Add cosine distance measure to BisectingKmeans ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.org$apache$spark$ml$param$shared$HasDistanceMeasure$_setter_$distanceMeasure_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.getDistanceMeasure"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.distanceMeasure"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel#SaveLoadV1_0.load"), - + // [SPARK-20659] Remove StorageStatus, or make it private ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index aa378c9d340f1..ccf33e8d4283c 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.mesos import java.util.concurrent.CountDownLatch -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer @@ -100,7 +100,13 @@ private[mesos] object MesosClusterDispatcher Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler) Utils.initDaemon(log) val conf = new SparkConf - val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) + val dispatcherArgs = try { + new MesosClusterDispatcherArguments(args, conf) + } catch { + case e: SparkException => + printErrorAndExit(e.getMessage()) + null + } conf.setMaster(dispatcherArgs.masterUrl) conf.setAppName(dispatcherArgs.name) dispatcherArgs.zookeeperUrl.foreach { z => diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index 096bb4e1af688..267a4283db9e6 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -21,6 +21,7 @@ import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkSubmitUtils import org.apache.spark.util.{IntParam, Utils} private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) { @@ -95,9 +96,8 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: parse(tail) case ("--conf") :: value :: tail => - val pair = MesosClusterDispatcher. - parseSparkConfProperty(value) - confProperties(pair._1) = pair._2 + val (k, v) = SparkSubmitUtils.parseSparkConfProperty(value) + confProperties(k) = v parse(tail) case ("--help") :: tail => From 75a183071c4ed2e407c930edfdf721779662b3ee Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 11 Apr 2018 09:59:38 -0700 Subject: [PATCH 604/774] [SPARK-22883] ML test for StructuredStreaming: spark.ml.feature, I-M ## What changes were proposed in this pull request? Adds structured streaming tests using testTransformer for these suites: * IDF * Imputer * Interaction * MaxAbsScaler * MinHashLSH * MinMaxScaler * NGram ## How was this patch tested? It is a bunch of tests! Author: Joseph K. Bradley Closes #20964 from jkbradley/SPARK-22883-part2. --- .../apache/spark/ml/feature/IDFSuite.scala | 14 +++--- .../spark/ml/feature/ImputerSuite.scala | 31 ++++++++++--- .../spark/ml/feature/InteractionSuite.scala | 46 ++++++++++--------- .../spark/ml/feature/MaxAbsScalerSuite.scala | 14 +++--- .../spark/ml/feature/MinHashLSHSuite.scala | 25 ++++++++-- .../spark/ml/feature/MinMaxScalerSuite.scala | 14 +++--- .../apache/spark/ml/feature/NGramSuite.scala | 2 +- 7 files changed, 89 insertions(+), 57 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 005edf73d29be..cdd62be43b54c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class IDFSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -57,7 +55,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(0.0, 1.0, 2.0, 3.0), Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) ) - val numOfData = data.size + val numOfData = data.length val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => math.log((numOfData + 1.0) / (x + 1.0)) }) @@ -72,7 +70,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead MLTestingUtils.checkCopyAndUids(idfEst, idfModel) - idfModel.transform(df).select("idfValue", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, idfModel, "idfValue", "expected") { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } @@ -85,7 +83,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(0.0, 1.0, 2.0, 3.0), Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) ) - val numOfData = data.size + val numOfData = data.length val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0 }) @@ -99,7 +97,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead .setMinDocFreq(1) .fit(df) - idfModel.transform(df).select("idfValue", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, idfModel, "idfValue", "expected") { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index c08b35b419266..75f63a623e6d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -16,13 +16,12 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.SparkException +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class ImputerSuite extends MLTest with DefaultReadWriteTest { test("Imputer for Double with default missing Value NaN") { val df = spark.createDataFrame( Seq( @@ -76,6 +75,28 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default ImputerSuite.iterateStrategyTest(imputer, df) } + test("Imputer should work with Structured Streaming") { + val localSpark = spark + import localSpark.implicits._ + val df = Seq[(java.lang.Double, Double)]( + (4.0, 4.0), + (10.0, 10.0), + (10.0, 10.0), + (Double.NaN, 8.0), + (null, 8.0) + ).toDF("value", "expected_mean_value") + val imputer = new Imputer() + .setInputCols(Array("value")) + .setOutputCols(Array("out")) + .setStrategy("mean") + val model = imputer.fit(df) + testTransformer[(java.lang.Double, Double)](df, model, "expected_mean_value", "out") { + case Row(exp: java.lang.Double, out: Double) => + assert((exp.isNaN && out.isNaN) || (exp == out), + s"Imputed values differ. Expected: $exp, actual: $out") + } + } + test("Imputer throws exception when surrogate cannot be computed") { val df = spark.createDataFrame( Seq( (0, Double.NaN, 1.0, 1.0), @@ -164,8 +185,6 @@ object ImputerSuite { * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median" */ def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = { - val inputCols = imputer.getInputCols - Seq("mean", "median").foreach { strategy => imputer.setStrategy(strategy) val model = imputer.fit(df) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 54f059e5f143e..eea31fc7ae3f2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class InteractionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -63,9 +63,9 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def test("numeric interaction") { val data = Seq( - (2, Vectors.dense(3.0, 4.0)), - (1, Vectors.dense(1.0, 5.0)) - ).toDF("a", "b") + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)) + ).toDF("a", "b", "expected") val groupAttr = new AttributeGroup( "b", Array[Attribute]( @@ -73,14 +73,15 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def NumericAttribute.defaultAttr.withName("bar"))) val df = data.select( col("a").as("a", NumericAttribute.defaultAttr.toMetadata()), - col("b").as("b", groupAttr.toMetadata())) + col("b").as("b", groupAttr.toMetadata()), + col("expected")) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + testTransformer[(Int, Vector, Vector)](df, trans, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features === expected) + } + val res = trans.transform(df) - val expected = Seq( - (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), - (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)) - ).toDF("a", "b", "features") - assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -92,9 +93,9 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def test("nominal interaction") { val data = Seq( - (2, Vectors.dense(3.0, 4.0)), - (1, Vectors.dense(1.0, 5.0)) - ).toDF("a", "b") + (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)) + ).toDF("a", "b", "expected") val groupAttr = new AttributeGroup( "b", Array[Attribute]( @@ -103,14 +104,15 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def val df = data.select( col("a").as( "a", NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()), - col("b").as("b", groupAttr.toMetadata())) + col("b").as("b", groupAttr.toMetadata()), + col("expected")) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + testTransformer[(Int, Vector, Vector)](df, trans, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features === expected) + } + val res = trans.transform(df) - val expected = Seq( - (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), - (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)) - ).toDF("a", "b", "features") - assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( "features", diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala index 918da4f9388d4..8dd0f0cb91e37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -14,15 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row -class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MaxAbsScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -45,9 +44,10 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De .setOutputCol("scaled") val model = scaler.fit(df) - model.transform(df).select("expected", "scaled").collect() - .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), s"MaxAbsScaler ut error: $vector2 should be $vector1") + testTransformer[(Vector, Vector)](df, model, "expected", "scaled") { + case Row(expectedVec: Vector, actualVec: Vector) => + assert(expectedVec === actualVec, + s"MaxAbsScaler error: Expected $expectedVec but computed $actualVec") } MLTestingUtils.checkCopyAndUids(scaler, model) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 3da0fb7da01ae..1c2956cb82908 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.sql.{Dataset, Row} -class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + +class MinHashLSHSuite extends MLTest with DefaultReadWriteTest { @transient var dataset: Dataset[_] = _ @@ -175,4 +174,20 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa assert(precision == 1.0) assert(recall >= 0.7) } + + test("MinHashLSHModel.transform should work with Structured Streaming") { + val localSpark = spark + import localSpark.implicits._ + + val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0))) + model.set(model.inputCol, "keys") + testTransformer[Tuple1[Vector]](dataset.toDF(), model, "keys", model.getOutputCol) { + case Row(_: Vector, output: Seq[_]) => + assert(output.length === model.randCoefficients.length) + // no AND-amplification yet: SPARK-18450, so each hash output is of length 1 + output.foreach { + case hashOutput: Vector => assert(hashOutput.size === 1) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index 51db74eb739ca..2d965f2ca2c54 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row -class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MinMaxScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -48,9 +46,9 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De .setMax(5) val model = scaler.fit(df) - model.transform(df).select("expected", "scaled").collect() - .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), "Transformed vector is different with expected.") + testTransformer[(Vector, Vector)](df, model, "expected", "scaled") { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 === vector2, "Transformed vector is different with expected.") } MLTestingUtils.checkCopyAndUids(scaler, model) @@ -114,7 +112,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De val model = scaler.fit(df) model.transform(df).select("expected", "scaled").collect() .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), "Transformed vector is different with expected.") + assert(vector1 === vector2, "Transformed vector is different with expected.") } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index e5956ee9942aa..201a335e0d7be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -84,7 +84,7 @@ class NGramSuite extends MLTest with DefaultReadWriteTest { def testNGram(t: NGram, dataFrame: DataFrame): Unit = { testTransformer[(Seq[String], Seq[String])](dataFrame, t, "nGrams", "wantedNGrams") { - case Row(actualNGrams : Seq[String], wantedNGrams: Seq[String]) => + case Row(actualNGrams : Seq[_], wantedNGrams: Seq[_]) => assert(actualNGrams === wantedNGrams) } } From 9d960de0814a1128318676cc2e91f447cdf0137f Mon Sep 17 00:00:00 2001 From: JBauerKogentix <37910022+JBauerKogentix@users.noreply.github.com> Date: Wed, 11 Apr 2018 15:52:13 -0700 Subject: [PATCH 605/774] typo rawPredicition changed to rawPrediction MultilayerPerceptronClassifier had 4 occurrences ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: JBauerKogentix <37910022+JBauerKogentix@users.noreply.github.com> Closes #21030 from JBauerKogentix/patch-1. --- python/pyspark/ml/classification.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index fbbe3d0307c81..ec17653a1adf9 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1543,12 +1543,12 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, solver="l-bfgs", initialWeights=None, probabilityCol="probability", - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \ solver="l-bfgs", initialWeights=None, probabilityCol="probability", \ - rawPredicitionCol="rawPrediction") + rawPredictionCol="rawPrediction") """ super(MultilayerPerceptronClassifier, self).__init__() self._java_obj = self._new_java_obj( @@ -1562,12 +1562,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, solver="l-bfgs", initialWeights=None, probabilityCol="probability", - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \ solver="l-bfgs", initialWeights=None, probabilityCol="probability", \ - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): Sets params for MultilayerPerceptronClassifier. """ kwargs = self._input_kwargs From e904dfaf0d16f9fa0cc4d2f46a3dec1b1d77de75 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 11 Apr 2018 17:04:34 -0700 Subject: [PATCH 606/774] Revert "[SPARK-23960][SQL][MINOR] Mark HashAggregateExec.bufVars as transient" This reverts commit 271c891b91917d660d1f6b995de397c47c7a6058. --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 965950ed94fe8..a5dc6ebf2b0f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -174,8 +174,8 @@ case class HashAggregateExec( } } - // The variables used as aggregation buffer. Only used in codegen for aggregation without keys. - @transient private var bufVars: Seq[ExprCode] = _ + // The variables used as aggregation buffer. Only used for aggregation without keys. + private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") @@ -238,8 +238,6 @@ case class HashAggregateExec( | } """.stripMargin) - bufVars = null // explicitly null this field out to allow the referent to be GC'd sooner - val numOutput = metricTerm(ctx, "numOutputRows") val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") From 6a2289ecf020a99cd9b3bcea7da5e78fb4e0303a Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 12 Apr 2018 15:58:04 +0800 Subject: [PATCH 607/774] [SPARK-23962][SQL][TEST] Fix race in currentExecutionIds(). SQLMetricsTestUtils.currentExecutionIds() was racing with the listener bus, which lead to some flaky tests. We should wait till the listener bus is empty. I tested by adding some Thread.sleep()s in SQLAppStatusListener, which reproduced the exceptions I saw on Jenkins. With this change, they went away. Author: Imran Rashid Closes #21041 from squito/SPARK-23962. --- .../apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 534d8bb629b8c..dcc540fc4f109 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -34,6 +34,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { import testImplicits._ protected def currentExecutionIds(): Set[Long] = { + spark.sparkContext.listenerBus.waitUntilEmpty(10000) statusStore.executionsList.map(_.executionId).toSet } From 0b19122d434e39eb117ccc3174a0688c9c874d48 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 12 Apr 2018 22:21:30 +0800 Subject: [PATCH 608/774] [SPARK-23762][SQL] UTF8StringBuffer uses MemoryBlock ## What changes were proposed in this pull request? This PR tries to use `MemoryBlock` in `UTF8StringBuffer`. In general, there are two advantages to use `MemoryBlock`. 1. Has clean API calls rather than using a Java array or `PlatformMemory` 2. Improve runtime performance of memory access instead of using `Object`. ## How was this patch tested? Added `UTF8StringBufferSuite` Author: Kazuaki Ishizaki Closes #20871 from kiszk/SPARK-23762. --- .../codegen/UTF8StringBuilder.java | 35 +++++++--------- .../codegen/UTF8StringBuilderSuite.scala | 42 +++++++++++++++++++ 2 files changed, 56 insertions(+), 21 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java index f0f66bae245fd..f8000d78cd1b6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -19,6 +19,8 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -29,43 +31,34 @@ public class UTF8StringBuilder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; - private byte[] buffer; - private int cursor = Platform.BYTE_ARRAY_OFFSET; + private ByteArrayMemoryBlock buffer; + private int length = 0; public UTF8StringBuilder() { // Since initial buffer size is 16 in `StringBuilder`, we set the same size here - this.buffer = new byte[16]; + this.buffer = new ByteArrayMemoryBlock(16); } // Grows the buffer by at least `neededSize` private void grow(int neededSize) { - if (neededSize > ARRAY_MAX - totalSize()) { + if (neededSize > ARRAY_MAX - length) { throw new UnsupportedOperationException( "Cannot grow internal buffer by size " + neededSize + " because the size after growing " + "exceeds size limitation " + ARRAY_MAX); } - final int length = totalSize() + neededSize; - if (buffer.length < length) { - int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX; - final byte[] tmp = new byte[newLength]; - Platform.copyMemory( - buffer, - Platform.BYTE_ARRAY_OFFSET, - tmp, - Platform.BYTE_ARRAY_OFFSET, - totalSize()); + final int requestedSize = length + neededSize; + if (buffer.size() < requestedSize) { + int newLength = requestedSize < ARRAY_MAX / 2 ? requestedSize * 2 : ARRAY_MAX; + final ByteArrayMemoryBlock tmp = new ByteArrayMemoryBlock(newLength); + MemoryBlock.copyMemory(buffer, tmp, length); buffer = tmp; } } - private int totalSize() { - return cursor - Platform.BYTE_ARRAY_OFFSET; - } - public void append(UTF8String value) { grow(value.numBytes()); - value.writeToMemory(buffer, cursor); - cursor += value.numBytes(); + value.writeToMemory(buffer.getByteArray(), length + Platform.BYTE_ARRAY_OFFSET); + length += value.numBytes(); } public void append(String value) { @@ -73,6 +66,6 @@ public void append(String value) { } public UTF8String build() { - return UTF8String.fromBytes(buffer, 0, totalSize()); + return UTF8String.fromBytes(buffer.getByteArray(), 0, length); } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala new file mode 100644 index 0000000000000..1b25a4b191f86 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.unsafe.types.UTF8String + +class UTF8StringBuilderSuite extends SparkFunSuite { + + test("basic test") { + val sb = new UTF8StringBuilder() + assert(sb.build() === UTF8String.EMPTY_UTF8) + + sb.append("") + assert(sb.build() === UTF8String.EMPTY_UTF8) + + sb.append("abcd") + assert(sb.build() === UTF8String.fromString("abcd")) + + sb.append(UTF8String.fromString("1234")) + assert(sb.build() === UTF8String.fromString("abcd1234")) + + // expect to grow an internal buffer + sb.append(UTF8String.fromString("efgijk567890")) + assert(sb.build() === UTF8String.fromString("abcd1234efgijk567890")) + } +} From 0f93b91a71444a1a938acfd8ea2191c54fb0187c Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 12 Apr 2018 15:47:42 -0600 Subject: [PATCH 609/774] [SPARK-23751][FOLLOW-UP] fix build for scala-2.12 ## What changes were proposed in this pull request? fix build for scala-2.12 ## How was this patch tested? Manual. Author: WeichenXu Closes #21051 from WeichenXu123/fix_build212. --- .../scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala index af8ff64d33ffe..adf8145726711 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -85,7 +85,7 @@ object KolmogorovSmirnovTest { dataset: Dataset[_], sampleCol: String, cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { - test(dataset, sampleCol, (x: Double) => cdf.call(x)) + test(dataset, sampleCol, (x: Double) => cdf.call(x).toDouble) } /** From 682002b6da844ed11324ee5ff4d00fc0294c0b31 Mon Sep 17 00:00:00 2001 From: Patrick Pisciuneri Date: Fri, 13 Apr 2018 09:45:27 +0800 Subject: [PATCH 610/774] [SPARK-23867][SCHEDULER] use droppedCount in logWarning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Get the count of dropped events for output in log message. ## How was this patch tested? The fix is pretty trivial, but `./dev/run-tests` were run and were successful. Please review http://spark.apache.org/contributing.html before opening a pull request. vanzin cloud-fan The contribution is my original work and I license the work to the project under the project’s open source license. Author: Patrick Pisciuneri Closes #20977 from phpisciuneri/fix-log-warning. --- .../main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala index 7e14938acd8e0..c1fedd63f6a90 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -166,7 +166,7 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi val prevLastReportTimestamp = lastReportTimestamp lastReportTimestamp = System.currentTimeMillis() val previous = new java.util.Date(prevLastReportTimestamp) - logWarning(s"Dropped $droppedEvents events from $name since $previous.") + logWarning(s"Dropped $droppedCount events from $name since $previous.") } } } From 14291b061b9b40eadbf4ed442f9a5021b8e09597 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 12 Apr 2018 20:00:25 -0700 Subject: [PATCH 611/774] [SPARK-23748][SS] Fix SS continuous process doesn't support SubqueryAlias issue ## What changes were proposed in this pull request? Current SS continuous doesn't support processing on temp table or `df.as("xxx")`, SS will throw an exception as LogicalPlan not supported, details described in [here](https://issues.apache.org/jira/browse/SPARK-23748). So here propose to add this support. ## How was this patch tested? new UT. Author: jerryshao Closes #21017 from jerryshao/SPARK-23748. --- .../UnsupportedOperationChecker.scala | 2 +- .../continuous/ContinuousSuite.scala | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index b55043c270644..ff9d6d7a7dded 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -345,7 +345,7 @@ object UnsupportedOperationChecker { plan.foreachUp { implicit subPlan => subPlan match { case (_: Project | _: Filter | _: MapElements | _: MapPartitions | - _: DeserializeToObject | _: SerializeFromObject) => + _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias) => case node if node.nodeName == "StreamingRelationV2" => case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index f5884b9c8de12..ef74efef156d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -171,6 +171,25 @@ class ContinuousSuite extends ContinuousSuiteBase { "Continuous processing does not support current time operations.")) } + test("subquery alias") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .createOrReplaceTempView("rate") + val test = spark.sql("select value from rate where value > 5") + + testStream(test, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + Execute(waitForRateSourceTriggers(_, 4)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + } + test("repeatedly restart") { val df = spark.readStream .format("rate") From ab7b961a4fe96ca02b8352d16b0fa80c972b67fc Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 13 Apr 2018 11:28:13 +0800 Subject: [PATCH 612/774] [SPARK-23942][PYTHON][SQL] Makes collect in PySpark as action for a query executor listener ## What changes were proposed in this pull request? This PR proposes to add `collect` to a query executor as an action. Seems `collect` / `collect` with Arrow are not recognised via `QueryExecutionListener` as an action. For example, if we have a custom listener as below: ```scala package org.apache.spark.sql import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.util.QueryExecutionListener class TestQueryExecutionListener extends QueryExecutionListener with Logging { override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { logError("Look at me! I'm 'onSuccess'") } override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { } } ``` and set `spark.sql.queryExecutionListeners` to `org.apache.spark.sql.TestQueryExecutionListener` Other operations in PySpark or Scala side seems fine: ```python >>> sql("SELECT * FROM range(1)").show() ``` ``` 18/04/09 17:02:04 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' +---+ | id| +---+ | 0| +---+ ``` ```scala scala> sql("SELECT * FROM range(1)").collect() ``` ``` 18/04/09 16:58:41 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' res1: Array[org.apache.spark.sql.Row] = Array([0]) ``` but .. **Before** ```python >>> sql("SELECT * FROM range(1)").collect() ``` ``` [Row(id=0)] ``` ```python >>> spark.conf.set("spark.sql.execution.arrow.enabled", "true") >>> sql("SELECT * FROM range(1)").toPandas() ``` ``` id 0 0 ``` **After** ```python >>> sql("SELECT * FROM range(1)").collect() ``` ``` 18/04/09 16:57:58 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' [Row(id=0)] ``` ```python >>> spark.conf.set("spark.sql.execution.arrow.enabled", "true") >>> sql("SELECT * FROM range(1)").toPandas() ``` ``` 18/04/09 17:53:26 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' id 0 0 ``` ## How was this patch tested? I have manually tested as described above and unit test was added. Author: hyukjinkwon Closes #21007 from HyukjinKwon/SPARK-23942. --- python/pyspark/sql/tests.py | 87 ++++++++++++++++--- .../scala/org/apache/spark/sql/Dataset.scala | 20 +++-- .../sql/TestQueryExecutionListener.scala | 44 ++++++++++ 3 files changed, 134 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 96c2a776a5049..4e99c8e3c6b10 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -186,16 +186,12 @@ def __init__(self, key, value): self.value = value -class ReusedSQLTestCase(ReusedPySparkTestCase): - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() +class SQLTestUtils(object): + """ + This util assumes the instance of this to have 'spark' attribute, having a spark session. + It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the + the implementation of this class has 'spark' attribute. + """ @contextmanager def sql_conf(self, pairs): @@ -204,6 +200,7 @@ def sql_conf(self, pairs): `value` to the configuration `key` and then restores it back when it exits. """ assert isinstance(pairs, dict), "pairs should be a dictionary." + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." keys = pairs.keys() new_values = pairs.values() @@ -219,6 +216,18 @@ def sql_conf(self, pairs): else: self.spark.conf.set(key, old_value) + +class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() + def assertPandasEqual(self, expected, result): msg = ("DataFrames are not equal: " + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + @@ -3066,6 +3075,64 @@ def test_sparksession_with_stopped_sparkcontext(self): sc.stop() +class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): + # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is + # static and immutable. This can't be set or unset, for example, via `spark.conf`. + + @classmethod + def setUpClass(cls): + import glob + from pyspark.find_spark_home import _find_spark_home + + SPARK_HOME = _find_spark_home() + filename_pattern = ( + "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" + "TestQueryExecutionListener.class") + if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): + raise unittest.SkipTest( + "'org.apache.spark.sql.TestQueryExecutionListener' is not " + "available. Will skip the related tests.") + + # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. + cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.queryExecutionListeners", + "org.apache.spark.sql.TestQueryExecutionListener") \ + .getOrCreate() + + @classmethod + def tearDownClass(cls): + cls.spark.stop() + + def tearDown(self): + self.spark._jvm.OnSuccessCall.clear() + + def test_query_execution_listener_on_collect(self): + self.assertFalse( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should not be called before 'collect'") + self.spark.sql("SELECT * FROM range(1)").collect() + self.assertTrue( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should be called after 'collect'") + + @unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) + def test_query_execution_listener_on_collect_with_arrow(self): + with self.sql_conf({"spark.sql.execution.arrow.enabled": True}): + self.assertFalse( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should not be " + "called before 'toPandas'") + self.spark.sql("SELECT * FROM range(1)").toPandas() + self.assertTrue( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should be called after 'toPandas'") + + class SparkSessionTests(PySparkTestCase): # This test is separate because it's closely related with session's start and stop. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0aee1d7be5788..917168162b236 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3189,10 +3189,10 @@ class Dataset[T] private[sql]( private[sql] def collectToPython(): Int = { EvaluatePython.registerPicklers() - withNewExecutionId { + withAction("collectToPython", queryExecution) { plan => val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) - val iter = new SerDeUtil.AutoBatchedPickler( - queryExecution.executedPlan.executeCollect().iterator.map(toJava)) + val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( + plan.executeCollect().iterator.map(toJava)) PythonRDD.serveIterator(iter, "serve-DataFrame") } } @@ -3201,8 +3201,9 @@ class Dataset[T] private[sql]( * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. */ private[sql] def collectAsArrowToPython(): Int = { - withNewExecutionId { - val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) + withAction("collectAsArrowToPython", queryExecution) { plan => + val iter: Iterator[Array[Byte]] = + toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) PythonRDD.serveIterator(iter, "serve-Arrow") } } @@ -3311,14 +3312,19 @@ class Dataset[T] private[sql]( } /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayload: RDD[ArrowPayload] = { + private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone - queryExecution.toRdd.mapPartitionsInternal { iter => + plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() ArrowConverters.toPayloadIterator( iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } + + // This is only used in tests, for now. + private[sql] def toArrowPayload: RDD[ArrowPayload] = { + toArrowPayload(queryExecution.executedPlan) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala new file mode 100644 index 0000000000000..d2a6358ee822b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.util.QueryExecutionListener + + +class TestQueryExecutionListener extends QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + OnSuccessCall.isOnSuccessCalled.set(true) + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { } +} + +/** + * This has a variable to check if `onSuccess` is actually called or not. Currently, this is for + * the test case in PySpark. See SPARK-23942. + */ +object OnSuccessCall { + val isOnSuccessCalled = new AtomicBoolean(false) + + def isCalled(): Boolean = isOnSuccessCalled.get() + + def clear(): Unit = isOnSuccessCalled.set(false) +} From 1018be44d6c52cf18e14d84160850063f0e60a1d Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 12 Apr 2018 22:30:59 -0700 Subject: [PATCH 613/774] [SPARK-23971] Should not leak Spark sessions across test suites ## What changes were proposed in this pull request? Many suites currently leak Spark sessions (sometimes with stopped SparkContexts) via the thread-local active Spark session and default Spark session. We should attempt to clean these up and detect when this happens to improve the reproducibility of tests. ## How was this patch tested? Existing tests Author: Eric Liang Closes #21058 from ericl/clear-session. --- .../org/apache/spark/SharedSparkSession.java | 9 ++++++-- .../org/apache/spark/sql/SparkSession.scala | 23 +++++++++++++++++-- .../apache/spark/sql/SessionStateSuite.scala | 2 ++ .../spark/sql/test/SharedSparkSession.scala | 22 ++++++++++++++---- 4 files changed, 47 insertions(+), 9 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java index 43779878890db..35a250955b282 100644 --- a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java +++ b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java @@ -42,7 +42,12 @@ public void setUp() throws IOException { @After public void tearDown() { - spark.stop(); - spark = null; + try { + spark.stop(); + spark = null; + } finally { + SparkSession.clearDefaultSession(); + SparkSession.clearActiveSession(); + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index b107492fbb330..c502e583a55c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager -import org.apache.spark.util.Utils +import org.apache.spark.util.{CallSite, Utils} /** @@ -81,6 +81,9 @@ class SparkSession private( @transient private[sql] val extensions: SparkSessionExtensions) extends Serializable with Closeable with Logging { self => + // The call site where this SparkSession was constructed. + private val creationSite: CallSite = Utils.getCallSite() + private[sql] def this(sc: SparkContext) { this(sc, None, None, new SparkSessionExtensions) } @@ -763,7 +766,7 @@ class SparkSession private( @InterfaceStability.Stable -object SparkSession { +object SparkSession extends Logging { /** * Builder for [[SparkSession]]. @@ -1090,4 +1093,20 @@ object SparkSession { } } + private[spark] def cleanupAnyExistingSession(): Unit = { + val session = getActiveSession.orElse(getDefaultSession) + if (session.isDefined) { + logWarning( + s"""An existing Spark session exists as the active or default session. + |This probably means another suite leaked it. Attempting to stop it before continuing. + |This existing Spark session was created at: + | + |${session.get.creationSite.longForm} + | + """.stripMargin) + session.get.stop() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 4efae4c46c2e1..7d1366092d1e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -44,6 +44,8 @@ class SessionStateSuite extends SparkFunSuite { if (activeSession != null) { activeSession.stop() activeSession = null + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } super.afterAll() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index e758c865b908f..8968dbf36d507 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -60,6 +60,7 @@ trait SharedSparkSession protected implicit def sqlContext: SQLContext = _spark.sqlContext protected def createSparkSession: TestSparkSession = { + SparkSession.cleanupAnyExistingSession() new TestSparkSession(sparkConf) } @@ -92,11 +93,22 @@ trait SharedSparkSession * Stop the underlying [[org.apache.spark.SparkContext]], if any. */ protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null + try { + super.afterAll() + } finally { + try { + if (_spark != null) { + try { + _spark.sessionState.catalog.reset() + } finally { + _spark.stop() + _spark = null + } + } + } finally { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } } } From 4b07036799b01894826b47c73142fe282c607a57 Mon Sep 17 00:00:00 2001 From: Fangshi Li Date: Fri, 13 Apr 2018 13:46:34 +0800 Subject: [PATCH 614/774] [SPARK-23815][CORE] Spark writer dynamic partition overwrite mode may fail to write output on multi level partition ## What changes were proposed in this pull request? Spark introduced new writer mode to overwrite only related partitions in SPARK-20236. While we are using this feature in our production cluster, we found a bug when writing multi-level partitions on HDFS. A simple test case to reproduce this issue: val df = Seq(("1","2","3")).toDF("col1", "col2","col3") df.write.partitionBy("col1","col2").mode("overwrite").save("/my/hdfs/location") If HDFS location "/my/hdfs/location" does not exist, there will be no output. This seems to be caused by the job commit change in SPARK-20236 in HadoopMapReduceCommitProtocol. In the commit job process, the output has been written into staging dir /my/hdfs/location/.spark-staging.xxx/col1=1/col2=2, and then the code calls fs.rename to rename /my/hdfs/location/.spark-staging.xxx/col1=1/col2=2 to /my/hdfs/location/col1=1/col2=2. However, in our case the operation will fail on HDFS because /my/hdfs/location/col1=1 does not exists. HDFS rename can not create directory for more than one level. This does not happen in the new unit test added with SPARK-20236 which uses local file system. We are proposing a fix. When cleaning current partition dir /my/hdfs/location/col1=1/col2=2 before the rename op, if the delete op fails (because /my/hdfs/location/col1=1/col2=2 may not exist), we call mkdirs op to create the parent dir /my/hdfs/location/col1=1 (if the parent dir does not exist) so the following rename op can succeed. Reference: in official HDFS document(https://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html), the rename command has precondition "dest must be root, or have a parent that exists" ## How was this patch tested? We have tested this patch on our production cluster and it fixed the problem Author: Fangshi Li Closes #20931 from fangshil/master. --- .../internal/io/HadoopMapReduceCommitProtocol.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 6d20ef1f98a3c..3e60c50ada59b 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -186,7 +186,17 @@ class HadoopMapReduceCommitProtocol( logDebug(s"Clean up default partition directories for overwriting: $partitionPaths") for (part <- partitionPaths) { val finalPartPath = new Path(path, part) - fs.delete(finalPartPath, true) + if (!fs.delete(finalPartPath, true) && !fs.exists(finalPartPath.getParent)) { + // According to the official hadoop FileSystem API spec, delete op should assume + // the destination is no longer present regardless of return value, thus we do not + // need to double check if finalPartPath exists before rename. + // Also in our case, based on the spec, delete returns false only when finalPartPath + // does not exist. When this happens, we need to take action if parent of finalPartPath + // also does not exist(e.g. the scenario described on SPARK-23815), because + // FileSystem API spec on rename op says the rename dest(finalPartPath) must have + // a parent that exists, otherwise we may get unexpected result on the rename. + fs.mkdirs(finalPartPath.getParent) + } fs.rename(new Path(stagingDir, part), finalPartPath) } } From 0323e61465ee747c9a57a70e9d6108876499546e Mon Sep 17 00:00:00 2001 From: yucai Date: Fri, 13 Apr 2018 00:00:04 -0700 Subject: [PATCH 615/774] [SPARK-23905][SQL] Add UDF weekday ## What changes were proposed in this pull request? Add UDF weekday ## How was this patch tested? A new test Author: yucai Closes #21009 from yucai/SPARK-23905. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/datetimeExpressions.scala | 55 +++++++++++++++---- .../expressions/DateExpressionsSuite.scala | 11 ++++ .../resources/sql-tests/inputs/datetime.sql | 2 + .../sql-tests/results/datetime.sql.out | 9 ++- 5 files changed, 67 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 747016beb06e7..131b958239e41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -395,6 +395,7 @@ object FunctionRegistry { expression[TruncTimestamp]("date_trunc"), expression[UnixTimestamp]("unix_timestamp"), expression[DayOfWeek]("dayofweek"), + expression[WeekDay]("weekday"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), expression[TimeWindow]("window"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 32fdb13afbbfa..b9b2cd5bdb9f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -426,36 +426,71 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa """, since = "2.3.0") // scalastyle:on line.size.limit -case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfWeek(child: Expression) extends DayWeek { - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - - override def dataType: DataType = IntegerType + override protected def nullSafeEval(date: Any): Any = { + cal.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + cal.get(Calendar.DAY_OF_WEEK) + } - @transient private lazy val c = { - Calendar.getInstance(DateTimeUtils.getTimeZone("UTC")) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, time => { + val cal = classOf[Calendar].getName + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val c = "calDayOfWeek" + ctx.addImmutableStateIfNotExists(cal, c, + v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") + s""" + $c.setTimeInMillis($time * 1000L * 3600L * 24L); + ${ev.value} = $c.get($cal.DAY_OF_WEEK); + """ + }) } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date) - Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday).", + examples = """ + Examples: + > SELECT _FUNC_('2009-07-30'); + 3 + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class WeekDay(child: Expression) extends DayWeek { override protected def nullSafeEval(date: Any): Any = { - c.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) - c.get(Calendar.DAY_OF_WEEK) + cal.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + (cal.get(Calendar.DAY_OF_WEEK) + 5 ) % 7 } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val c = "calDayOfWeek" + val c = "calWeekDay" ctx.addImmutableStateIfNotExists(cal, c, v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") s""" $c.setTimeInMillis($time * 1000L * 3600L * 24L); - ${ev.value} = $c.get($cal.DAY_OF_WEEK); + ${ev.value} = ($c.get($cal.DAY_OF_WEEK) + 5) % 7; """ }) } } +abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + @transient protected lazy val cal: Calendar = { + Calendar.getInstance(DateTimeUtils.getTimeZone("UTC")) + } +} + // scalastyle:off line.size.limit @ExpressionDescription( usage = "_FUNC_(date) - Returns the week of the year of the given date. A week is considered to start on a Monday and week 1 is the first week with >3 days.", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 786266a2c13c0..080ec487cfa6a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -211,6 +211,17 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(DayOfWeek, DateType) } + test("WeekDay") { + checkEvaluation(WeekDay(Literal.create(null, DateType)), null) + checkEvaluation(WeekDay(Literal(d)), 2) + checkEvaluation(WeekDay(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 2) + checkEvaluation(WeekDay(Cast(Literal(ts), DateType, gmtId)), 4) + checkEvaluation(WeekDay(Cast(Literal("2011-05-06"), DateType, gmtId)), 4) + checkEvaluation(WeekDay(Literal(new Date(sdf.parse("2017-05-27 13:10:15").getTime))), 5) + checkEvaluation(WeekDay(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 4) + checkConsistencyBetweenInterpretedAndCodegen(WeekDay, DateType) + } + test("WeekOfYear") { checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null) checkEvaluation(WeekOfYear(Literal(d)), 15) diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index adea2bfa82cd3..547c2bef02b24 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -25,3 +25,5 @@ create temporary view ttf2 as select * from values select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2; select a, b from ttf2 order by a, current_date; + +select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15'); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index bbb6851e69c7e..4e1cfa6e48c1c 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 9 +-- Number of queries: 10 -- !query 0 @@ -81,3 +81,10 @@ struct -- !query 8 output 1 2 2 3 + +-- !query 9 +select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') +-- !query 3 schema +struct +-- !query 3 output +5 3 5 NULL 4 From a83ae0d9bc1b8f4909b9338370efe4020079bea7 Mon Sep 17 00:00:00 2001 From: mcheah Date: Fri, 13 Apr 2018 08:43:58 -0700 Subject: [PATCH 616/774] [SPARK-22839][K8S] Refactor to unify driver and executor pod builder APIs ## What changes were proposed in this pull request? Breaks down the construction of driver pods and executor pods in a way that uses a common abstraction for both spark-submit creating the driver and KubernetesClusterSchedulerBackend creating the executor. Encourages more code reuse and is more legible than the older approach. The high-level design is discussed in more detail on the JIRA ticket. This pull request is the implementation of that design with some minor changes in the implementation details. No user-facing behavior should break as a result of this change. ## How was this patch tested? Migrated all unit tests from the old submission steps architecture to the new architecture. Integration tests should not have to change and pass given that this shouldn't change any outward behavior. Author: mcheah Closes #20910 from mccheah/spark-22839-incremental. --- .../org/apache/spark/deploy/k8s/Config.scala | 2 +- .../spark/deploy/k8s/KubernetesConf.scala | 184 +++++++++++++ .../deploy/k8s/KubernetesDriverSpec.scala} | 25 +- .../spark/deploy/k8s/KubernetesUtils.scala | 11 - .../deploy/k8s/MountSecretsBootstrap.scala | 72 ----- ...ConfigurationStep.scala => SparkPod.scala} | 24 +- .../k8s/features/BasicDriverFeatureStep.scala | 136 ++++++++++ .../features/BasicExecutorFeatureStep.scala | 179 +++++++++++++ ...iverKubernetesCredentialsFeatureStep.scala | 216 +++++++++++++++ .../features/DriverServiceFeatureStep.scala | 97 +++++++ .../KubernetesFeatureConfigStep.scala | 71 +++++ .../features/MountSecretsFeatureStep.scala | 62 +++++ .../k8s/submit/DriverConfigOrchestrator.scala | 145 ----------- .../submit/KubernetesClientApplication.scala | 80 +++--- .../k8s/submit/KubernetesDriverBuilder.scala | 56 ++++ .../k8s/submit/KubernetesDriverSpec.scala | 47 ---- .../steps/BasicDriverConfigurationStep.scala | 163 ------------ .../steps/DependencyResolutionStep.scala | 61 ----- .../DriverKubernetesCredentialsStep.scala | 245 ------------------ .../submit/steps/DriverMountSecretsStep.scala | 38 --- .../steps/DriverServiceBootstrapStep.scala | 104 -------- .../cluster/k8s/ExecutorPodFactory.scala | 227 ---------------- .../k8s/KubernetesClusterManager.scala | 12 +- .../KubernetesClusterSchedulerBackend.scala | 20 +- .../k8s/KubernetesExecutorBuilder.scala | 41 +++ .../deploy/k8s/KubernetesConfSuite.scala | 175 +++++++++++++ .../BasicDriverFeatureStepSuite.scala | 153 +++++++++++ .../BasicExecutorFeatureStepSuite.scala | 179 +++++++++++++ ...bernetesCredentialsFeatureStepSuite.scala} | 101 +++++--- .../DriverServiceFeatureStepSuite.scala | 227 ++++++++++++++++ .../KubernetesFeaturesTestUtils.scala | 61 +++++ .../MountSecretsFeatureStepSuite.scala} | 29 ++- .../spark/deploy/k8s/submit/ClientSuite.scala | 216 +++++++-------- .../DriverConfigOrchestratorSuite.scala | 131 ---------- .../submit/KubernetesDriverBuilderSuite.scala | 102 ++++++++ .../BasicDriverConfigurationStepSuite.scala | 122 --------- .../steps/DependencyResolutionStepSuite.scala | 69 ----- .../DriverServiceBootstrapStepSuite.scala | 180 ------------- .../cluster/k8s/ExecutorPodFactorySuite.scala | 195 -------------- ...bernetesClusterSchedulerBackendSuite.scala | 37 ++- .../k8s/KubernetesExecutorBuilderSuite.scala | 75 ++++++ 41 files changed, 2289 insertions(+), 2081 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala rename resource-managers/kubernetes/core/src/{test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala => main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala} (57%) delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/{submit/steps/DriverConfigurationStep.scala => SparkPod.scala} (64%) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/{submit/steps/DriverKubernetesCredentialsStepSuite.scala => features/DriverKubernetesCredentialsFeatureStepSuite.scala} (67%) create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/{submit/steps/DriverMountSecretsStepSuite.scala => features/MountSecretsFeatureStepSuite.scala} (64%) delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 82f6c714f3555..4086970ffb256 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -167,5 +167,5 @@ private[spark] object Config extends Logging { val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." - val KUBERNETES_DRIVER_ENV_KEY = "spark.kubernetes.driverEnv." + val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala new file mode 100644 index 0000000000000..77b634ddfabcc --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{LocalObjectReference, LocalObjectReferenceBuilder, Pod} + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.{JavaMainAppResource, MainAppResource} +import org.apache.spark.internal.config.ConfigEntry + +private[spark] sealed trait KubernetesRoleSpecificConf + +/* + * Structure containing metadata for Kubernetes logic that builds a Spark driver. + */ +private[spark] case class KubernetesDriverSpecificConf( + mainAppResource: Option[MainAppResource], + mainClass: String, + appName: String, + appArgs: Seq[String]) extends KubernetesRoleSpecificConf + +/* + * Structure containing metadata for Kubernetes logic that builds a Spark executor. + */ +private[spark] case class KubernetesExecutorSpecificConf( + executorId: String, + driverPod: Pod) + extends KubernetesRoleSpecificConf + +/** + * Structure containing metadata for Kubernetes logic to build Spark pods. + */ +private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( + sparkConf: SparkConf, + roleSpecificConf: T, + appResourceNamePrefix: String, + appId: String, + roleLabels: Map[String, String], + roleAnnotations: Map[String, String], + roleSecretNamesToMountPaths: Map[String, String], + roleEnvs: Map[String, String]) { + + def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) + + def sparkJars(): Seq[String] = sparkConf + .getOption("spark.jars") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) + + def sparkFiles(): Seq[String] = sparkConf + .getOption("spark.files") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) + + def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) + + def imagePullSecrets(): Seq[LocalObjectReference] = { + sparkConf + .get(IMAGE_PULL_SECRETS) + .map(_.split(",")) + .getOrElse(Array.empty[String]) + .map(_.trim) + .map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + } + + def nodeSelector(): Map[String, String] = + KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) + + def get[T](config: ConfigEntry[T]): T = sparkConf.get(config) + + def get(conf: String): String = sparkConf.get(conf) + + def get(conf: String, defaultValue: String): String = sparkConf.get(conf, defaultValue) + + def getOption(key: String): Option[String] = sparkConf.getOption(key) +} + +private[spark] object KubernetesConf { + def createDriverConf( + sparkConf: SparkConf, + appName: String, + appResourceNamePrefix: String, + appId: String, + mainAppResource: Option[MainAppResource], + mainClass: String, + appArgs: Array[String]): KubernetesConf[KubernetesDriverSpecificConf] = { + val sparkConfWithMainAppJar = sparkConf.clone() + mainAppResource.foreach { + case JavaMainAppResource(res) => + val previousJars = sparkConf + .getOption("spark.jars") + .map(_.split(",")) + .getOrElse(Array.empty) + if (!previousJars.contains(res)) { + sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) + } + } + + val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_LABEL_PREFIX) + require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + + s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + + "operations.") + require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " + + s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + + "operations.") + val driverLabels = driverCustomLabels ++ Map( + SPARK_APP_ID_LABEL -> appId, + SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) + val driverAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) + val driverSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) + val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + + KubernetesConf( + sparkConfWithMainAppJar, + KubernetesDriverSpecificConf(mainAppResource, mainClass, appName, appArgs), + appResourceNamePrefix, + appId, + driverLabels, + driverAnnotations, + driverSecretNamesToMountPaths, + driverEnvs) + } + + def createExecutorConf( + sparkConf: SparkConf, + executorId: String, + appId: String, + driverPod: Pod): KubernetesConf[KubernetesExecutorSpecificConf] = { + val executorCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX) + require( + !executorCustomLabels.contains(SPARK_APP_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") + require( + !executorCustomLabels.contains(SPARK_EXECUTOR_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + + " Spark.") + require( + !executorCustomLabels.contains(SPARK_ROLE_LABEL), + s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") + val executorLabels = Map( + SPARK_EXECUTOR_ID_LABEL -> executorId, + SPARK_APP_ID_LABEL -> appId, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ + executorCustomLabels + val executorAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) + val executorSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val executorEnv = sparkConf.getExecutorEnv.toMap + + KubernetesConf( + sparkConf.clone(), + KubernetesExecutorSpecificConf(executorId, driverPod), + sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX), + appId, + executorLabels, + executorAnnotations, + executorSecrets, + executorEnv) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala similarity index 57% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala index cf41b22e241af..0c5ae022f4070 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala @@ -16,21 +16,16 @@ */ package org.apache.spark.deploy.k8s -import io.fabric8.kubernetes.api.model.LocalObjectReference +import io.fabric8.kubernetes.api.model.HasMetadata -import org.apache.spark.SparkFunSuite - -class KubernetesUtilsTest extends SparkFunSuite { - - test("testParseImagePullSecrets") { - val noSecrets = KubernetesUtils.parseImagePullSecrets(None) - assert(noSecrets === Nil) - - val oneSecret = KubernetesUtils.parseImagePullSecrets(Some("imagePullSecret")) - assert(oneSecret === new LocalObjectReference("imagePullSecret") :: Nil) - - val commaSeparatedSecrets = KubernetesUtils.parseImagePullSecrets(Some("s1, s2 , s3,s4")) - assert(commaSeparatedSecrets.map(_.getName) === "s1" :: "s2" :: "s3" :: "s4" :: Nil) - } +private[spark] case class KubernetesDriverSpec( + pod: SparkPod, + driverKubernetesResources: Seq[HasMetadata], + systemProperties: Map[String, String]) +private[spark] object KubernetesDriverSpec { + def initialSpec(initialProps: Map[String, String]): KubernetesDriverSpec = KubernetesDriverSpec( + SparkPod.initialPod(), + Seq.empty, + initialProps) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 5b2bb819cdb14..ee629068ad90d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -37,17 +37,6 @@ private[spark] object KubernetesUtils { sparkConf.getAllWithPrefix(prefix).toMap } - /** - * Parses comma-separated list of imagePullSecrets into K8s-understandable format - */ - def parseImagePullSecrets(imagePullSecrets: Option[String]): List[LocalObjectReference] = { - imagePullSecrets match { - case Some(secretsCommaSeparated) => - secretsCommaSeparated.split(',').map(_.trim).map(new LocalObjectReference(_)).toList - case None => Nil - } - } - def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { opt1.foreach { _ => require(opt2.isEmpty, errMessage) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala deleted file mode 100644 index c35e7db51d407..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} - -/** - * Bootstraps a driver or executor container or an init-container with needed secrets mounted. - */ -private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { - - /** - * Add new secret volumes for the secrets specified in secretNamesToMountPaths into the given pod. - * - * @param pod the pod into which the secret volumes are being added. - * @return the updated pod with the secret volumes added. - */ - def addSecretVolumes(pod: Pod): Pod = { - var podBuilder = new PodBuilder(pod) - secretNamesToMountPaths.keys.foreach { name => - podBuilder = podBuilder - .editOrNewSpec() - .addNewVolume() - .withName(secretVolumeName(name)) - .withNewSecret() - .withSecretName(name) - .endSecret() - .endVolume() - .endSpec() - } - - podBuilder.build() - } - - /** - * Mounts Kubernetes secret volumes of the secrets specified in secretNamesToMountPaths into the - * given container. - * - * @param container the container into which the secret volumes are being mounted. - * @return the updated container with the secrets mounted. - */ - def mountSecrets(container: Container): Container = { - var containerBuilder = new ContainerBuilder(container) - secretNamesToMountPaths.foreach { case (name, path) => - containerBuilder = containerBuilder - .addNewVolumeMount() - .withName(secretVolumeName(name)) - .withMountPath(path) - .endVolumeMount() - } - - containerBuilder.build() - } - - private def secretVolumeName(secretName: String): String = { - secretName + "-volume" - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala similarity index 64% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala index 17614e040e587..345dd117fd35f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala @@ -14,17 +14,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} -/** - * Represents a step in configuring the Spark driver pod. - */ -private[spark] trait DriverConfigurationStep { +private[spark] case class SparkPod(pod: Pod, container: Container) - /** - * Apply some transformation to the previous state of the driver to add a new feature to it. - */ - def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec +private[spark] object SparkPod { + def initialPod(): SparkPod = { + SparkPod( + new PodBuilder() + .withNewMetadata() + .endMetadata() + .withNewSpec() + .endSpec() + .build(), + new ContainerBuilder().build()) + } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala new file mode 100644 index 0000000000000..07bdccbe0479d --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher + +private[spark] class BasicDriverFeatureStep( + conf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + + private val driverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(s"${conf.appResourceNamePrefix}-driver") + + private val driverContainerImage = conf + .get(DRIVER_CONTAINER_IMAGE) + .getOrElse(throw new SparkException("Must specify the driver container image")) + + // CPU settings + private val driverCpuCores = conf.get("spark.driver.cores", "1") + private val driverLimitCores = conf.get(KUBERNETES_DRIVER_LIMIT_CORES) + + // Memory settings + private val driverMemoryMiB = conf.get(DRIVER_MEMORY) + private val memoryOverheadMiB = conf + .get(DRIVER_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) + private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB + + override def configurePod(pod: SparkPod): SparkPod = { + val driverCustomEnvs = conf.roleEnvs + .toSeq + .map { env => + new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + } + + val driverCpuQuantity = new QuantityBuilder(false) + .withAmount(driverCpuCores) + .build() + val driverMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${driverMemoryWithOverheadMiB}Mi") + .build() + val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => + ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) + } + + val driverContainer = new ContainerBuilder(pod.container) + .withName(DRIVER_CONTAINER_NAME) + .withImage(driverContainerImage) + .withImagePullPolicy(conf.imagePullPolicy()) + .addAllToEnv(driverCustomEnvs.asJava) + .addNewEnv() + .withName(ENV_DRIVER_BIND_ADDRESS) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .endEnv() + .withNewResources() + .addToRequests("cpu", driverCpuQuantity) + .addToLimits(maybeCpuLimitQuantity.toMap.asJava) + .addToRequests("memory", driverMemoryQuantity) + .addToLimits("memory", driverMemoryQuantity) + .endResources() + .addToArgs("driver") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", conf.roleSpecificConf.mainClass) + // The user application jar is merged into the spark.jars list and managed through that + // property, so there is no need to reference it explicitly here. + .addToArgs(SparkLauncher.NO_RESOURCE) + .addToArgs(conf.roleSpecificConf.appArgs: _*) + .build() + + val driverPod = new PodBuilder(pod.pod) + .editOrNewMetadata() + .withName(driverPodName) + .addToLabels(conf.roleLabels.asJava) + .addToAnnotations(conf.roleAnnotations.asJava) + .endMetadata() + .withNewSpec() + .withRestartPolicy("Never") + .withNodeSelector(conf.nodeSelector().asJava) + .addToImagePullSecrets(conf.imagePullSecrets(): _*) + .endSpec() + .build() + SparkPod(driverPod, driverContainer) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val additionalProps = mutable.Map( + KUBERNETES_DRIVER_POD_NAME.key -> driverPodName, + "spark.app.id" -> conf.appId, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> conf.appResourceNamePrefix, + KUBERNETES_DRIVER_SUBMIT_CHECK.key -> "true") + + val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath( + conf.sparkJars()) + val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath( + conf.sparkFiles()) + if (resolvedSparkJars.nonEmpty) { + additionalProps.put("spark.jars", resolvedSparkJars.mkString(",")) + } + if (resolvedSparkFiles.nonEmpty) { + additionalProps.put("spark.files", resolvedSparkFiles.mkString(",")) + } + additionalProps.toMap + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala new file mode 100644 index 0000000000000..d22097587aafe --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} +import org.apache.spark.rpc.RpcEndpointAddress +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.Utils + +private[spark] class BasicExecutorFeatureStep( + kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]) + extends KubernetesFeatureConfigStep { + + // Consider moving some of these fields to KubernetesConf or KubernetesExecutorSpecificConf + private val executorExtraClasspath = kubernetesConf.get(EXECUTOR_CLASS_PATH) + private val executorContainerImage = kubernetesConf + .get(EXECUTOR_CONTAINER_IMAGE) + .getOrElse(throw new SparkException("Must specify the executor container image")) + private val blockManagerPort = kubernetesConf + .sparkConf + .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) + + private val executorPodNamePrefix = kubernetesConf.appResourceNamePrefix + + private val driverUrl = RpcEndpointAddress( + kubernetesConf.get("spark.driver.host"), + kubernetesConf.sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + private val executorMemoryMiB = kubernetesConf.get(EXECUTOR_MEMORY) + private val executorMemoryString = kubernetesConf.get( + EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString) + + private val memoryOverheadMiB = kubernetesConf + .get(EXECUTOR_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) + private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB + + private val executorCores = kubernetesConf.sparkConf.getInt("spark.executor.cores", 1) + private val executorCoresRequest = + if (kubernetesConf.sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { + kubernetesConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get + } else { + executorCores.toString + } + private val executorLimitCores = kubernetesConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) + + override def configurePod(pod: SparkPod): SparkPod = { + val name = s"$executorPodNamePrefix-exec-${kubernetesConf.roleSpecificConf.executorId}" + + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod + // name as the hostname. This preserves uniqueness since the end of name contains + // executorId + val hostname = name.substring(Math.max(0, name.length - 63)) + val executorMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryWithOverhead}Mi") + .build() + val executorCpuQuantity = new QuantityBuilder(false) + .withAmount(executorCoresRequest) + .build() + val executorExtraClasspathEnv = executorExtraClasspath.map { cp => + new EnvVarBuilder() + .withName(ENV_CLASSPATH) + .withValue(cp) + .build() + } + val executorExtraJavaOptionsEnv = kubernetesConf + .get(EXECUTOR_JAVA_OPTIONS) + .map { opts => + val delimitedOpts = Utils.splitCommandString(opts) + delimitedOpts.zipWithIndex.map { + case (opt, index) => + new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() + } + }.getOrElse(Seq.empty[EnvVar]) + val executorEnv = (Seq( + (ENV_DRIVER_URL, driverUrl), + (ENV_EXECUTOR_CORES, executorCores.toString), + (ENV_EXECUTOR_MEMORY, executorMemoryString), + (ENV_APPLICATION_ID, kubernetesConf.appId), + // This is to set the SPARK_CONF_DIR to be /opt/spark/conf + (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), + (ENV_EXECUTOR_ID, kubernetesConf.roleSpecificConf.executorId)) ++ + kubernetesConf.roleEnvs) + .map(env => new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + ) ++ Seq( + new EnvVarBuilder() + .withName(ENV_EXECUTOR_POD_IP) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .build() + ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq + val requiredPorts = Seq( + (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) + .map { case (name, port) => + new ContainerPortBuilder() + .withName(name) + .withContainerPort(port) + .build() + } + + val executorContainer = new ContainerBuilder(pod.container) + .withName("executor") + .withImage(executorContainerImage) + .withImagePullPolicy(kubernetesConf.imagePullPolicy()) + .withNewResources() + .addToRequests("memory", executorMemoryQuantity) + .addToLimits("memory", executorMemoryQuantity) + .addToRequests("cpu", executorCpuQuantity) + .endResources() + .addAllToEnv(executorEnv.asJava) + .withPorts(requiredPorts.asJava) + .addToArgs("executor") + .build() + val containerWithLimitCores = executorLimitCores.map { limitCores => + val executorCpuLimitQuantity = new QuantityBuilder(false) + .withAmount(limitCores) + .build() + new ContainerBuilder(executorContainer) + .editResources() + .addToLimits("cpu", executorCpuLimitQuantity) + .endResources() + .build() + }.getOrElse(executorContainer) + val driverPod = kubernetesConf.roleSpecificConf.driverPod + val executorPod = new PodBuilder(pod.pod) + .editOrNewMetadata() + .withName(name) + .withLabels(kubernetesConf.roleLabels.asJava) + .withAnnotations(kubernetesConf.roleAnnotations.asJava) + .withOwnerReferences() + .addNewOwnerReference() + .withController(true) + .withApiVersion(driverPod.getApiVersion) + .withKind(driverPod.getKind) + .withName(driverPod.getMetadata.getName) + .withUid(driverPod.getMetadata.getUid) + .endOwnerReference() + .endMetadata() + .editOrNewSpec() + .withHostname(hostname) + .withRestartPolicy("Never") + .withNodeSelector(kubernetesConf.nodeSelector().asJava) + .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*) + .endSpec() + .build() + SparkPod(executorPod, containerWithLimitCores) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala new file mode 100644 index 0000000000000..ff5ad6673b309 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.io.File +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import com.google.common.io.{BaseEncoding, Files} +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret, SecretBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +private[spark] class DriverKubernetesCredentialsFeatureStep(kubernetesConf: KubernetesConf[_]) + extends KubernetesFeatureConfigStep { + // TODO clean up this class, and credentials in general. See also SparkKubernetesClientFactory. + // We should use a struct to hold all creds-related fields. A lot of the code is very repetitive. + + private val maybeMountedOAuthTokenFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX") + private val maybeMountedClientKeyFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX") + private val maybeMountedClientCertFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX") + private val maybeMountedCaCertFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX") + private val driverServiceAccount = kubernetesConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME) + + private val oauthTokenBase64 = kubernetesConf + .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX") + .map { token => + BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8)) + } + + private val caCertDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", + "Driver CA cert file") + private val clientKeyDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", + "Driver client key file") + private val clientCertDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", + "Driver client cert file") + + // TODO decide whether or not to apply this step entirely in the caller, i.e. the builder. + private val shouldMountSecret = oauthTokenBase64.isDefined || + caCertDataBase64.isDefined || + clientKeyDataBase64.isDefined || + clientCertDataBase64.isDefined + + private val driverCredentialsSecretName = + s"${kubernetesConf.appResourceNamePrefix}-kubernetes-credentials" + + override def configurePod(pod: SparkPod): SparkPod = { + if (!shouldMountSecret) { + pod.copy( + pod = driverServiceAccount.map { account => + new PodBuilder(pod.pod) + .editOrNewSpec() + .withServiceAccount(account) + .withServiceAccountName(account) + .endSpec() + .build() + }.getOrElse(pod.pod)) + } else { + val driverPodWithMountedKubernetesCredentials = + new PodBuilder(pod.pod) + .editOrNewSpec() + .addNewVolume() + .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + .withNewSecret().withSecretName(driverCredentialsSecretName).endSecret() + .endVolume() + .endSpec() + .build() + + val driverContainerWithMountedSecretVolume = + new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + .withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR) + .endVolumeMount() + .build() + SparkPod(driverPodWithMountedKubernetesCredentials, driverContainerWithMountedSecretVolume) + } + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val resolvedMountedOAuthTokenFile = resolveSecretLocation( + maybeMountedOAuthTokenFile, + oauthTokenBase64, + DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH) + val resolvedMountedClientKeyFile = resolveSecretLocation( + maybeMountedClientKeyFile, + clientKeyDataBase64, + DRIVER_CREDENTIALS_CLIENT_KEY_PATH) + val resolvedMountedClientCertFile = resolveSecretLocation( + maybeMountedClientCertFile, + clientCertDataBase64, + DRIVER_CREDENTIALS_CLIENT_CERT_PATH) + val resolvedMountedCaCertFile = resolveSecretLocation( + maybeMountedCaCertFile, + caCertDataBase64, + DRIVER_CREDENTIALS_CA_CERT_PATH) + + val redactedTokens = kubernetesConf.sparkConf.getAll + .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)) + .toMap + .mapValues( _ => "") + redactedTokens ++ + resolvedMountedCaCertFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedClientKeyFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedClientCertFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedOAuthTokenFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + if (shouldMountSecret) { + Seq(createCredentialsSecret()) + } else { + Seq.empty + } + } + + private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = { + kubernetesConf.getOption(conf) + .map(new File(_)) + .map { file => + require(file.isFile, String.format("%s provided at %s does not exist or is not a file.", + fileType, file.getAbsolutePath)) + BaseEncoding.base64().encode(Files.toByteArray(file)) + } + } + + /** + * Resolve a Kubernetes secret data entry from an optional client credential used by the + * driver to talk to the Kubernetes API server. + * + * @param userSpecifiedCredential the optional user-specified client credential. + * @param secretName name of the Kubernetes secret storing the client credential. + * @return a secret data entry in the form of a map from the secret name to the secret data, + * which may be empty if the user-specified credential is empty. + */ + private def resolveSecretData( + userSpecifiedCredential: Option[String], + secretName: String): Map[String, String] = { + userSpecifiedCredential.map { valueBase64 => + Map(secretName -> valueBase64) + }.getOrElse(Map.empty[String, String]) + } + + private def resolveSecretLocation( + mountedUserSpecified: Option[String], + valueMountedFromSubmitter: Option[String], + mountedCanonicalLocation: String): Option[String] = { + mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ => + mountedCanonicalLocation + }) + } + + private def createCredentialsSecret(): Secret = { + val allSecretData = + resolveSecretData( + clientKeyDataBase64, + DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++ + resolveSecretData( + clientCertDataBase64, + DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++ + resolveSecretData( + caCertDataBase64, + DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++ + resolveSecretData( + oauthTokenBase64, + DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME) + + new SecretBuilder() + .withNewMetadata() + .withName(driverCredentialsSecretName) + .endMetadata() + .withData(allSecretData.asJava) + .build() + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala new file mode 100644 index 0000000000000..f2d7bbd08f305 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{HasMetadata, ServiceBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Clock, SystemClock} + +private[spark] class DriverServiceFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf], + clock: Clock = new SystemClock) + extends KubernetesFeatureConfigStep with Logging { + import DriverServiceFeatureStep._ + + require(kubernetesConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, + s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + + "address is managed and set to the driver pod's IP address.") + require(kubernetesConf.getOption(DRIVER_HOST_KEY).isEmpty, + s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + + "managed via a Kubernetes service.") + + private val preferredServiceName = s"${kubernetesConf.appResourceNamePrefix}$DRIVER_SVC_POSTFIX" + private val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { + preferredServiceName + } else { + val randomServiceId = clock.getTimeMillis() + val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX" + logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " + + s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). Falling back to use " + + s"$shorterServiceName as the driver service's name.") + shorterServiceName + } + + private val driverPort = kubernetesConf.sparkConf.getInt( + "spark.driver.port", DEFAULT_DRIVER_PORT) + private val driverBlockManagerPort = kubernetesConf.sparkConf.getInt( + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) + + override def configurePod(pod: SparkPod): SparkPod = pod + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val driverHostname = s"$resolvedServiceName.${kubernetesConf.namespace()}.svc" + Map(DRIVER_HOST_KEY -> driverHostname, + "spark.driver.port" -> driverPort.toString, + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key -> + driverBlockManagerPort.toString) + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + val driverService = new ServiceBuilder() + .withNewMetadata() + .withName(resolvedServiceName) + .endMetadata() + .withNewSpec() + .withClusterIP("None") + .withSelector(kubernetesConf.roleLabels.asJava) + .addNewPort() + .withName(DRIVER_PORT_NAME) + .withPort(driverPort) + .withNewTargetPort(driverPort) + .endPort() + .addNewPort() + .withName(BLOCK_MANAGER_PORT_NAME) + .withPort(driverBlockManagerPort) + .withNewTargetPort(driverBlockManagerPort) + .endPort() + .endSpec() + .build() + Seq(driverService) + } +} + +private[spark] object DriverServiceFeatureStep { + val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key + val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key + val DRIVER_SVC_POSTFIX = "-driver-svc" + val MAX_SERVICE_NAME_LENGTH = 63 +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala new file mode 100644 index 0000000000000..4c1be3bb13293 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.HasMetadata + +import org.apache.spark.deploy.k8s.SparkPod + +/** + * A collection of functions that together represent a "feature" in pods that are launched for + * Spark drivers and executors. + */ +private[spark] trait KubernetesFeatureConfigStep { + + /** + * Apply modifications on the given pod in accordance to this feature. This can include attaching + * volumes, adding environment variables, and adding labels/annotations. + *

    + * Note that we should return a SparkPod that keeps all of the properties of the passed SparkPod + * object. So this is correct: + *

    +   * {@code val configuredPod = new PodBuilder(pod.pod)
    +   *     .editSpec()
    +   *     ...
    +   *     .build()
    +   *   val configuredContainer = new ContainerBuilder(pod.container)
    +   *     ...
    +   *     .build()
    +   *   SparkPod(configuredPod, configuredContainer)
    +   *  }
    +   * 
    + * This is incorrect: + *
    +   * {@code val configuredPod = new PodBuilder() // Loses the original state
    +   *     .editSpec()
    +   *     ...
    +   *     .build()
    +   *   val configuredContainer = new ContainerBuilder() // Loses the original state
    +   *     ...
    +   *     .build()
    +   *   SparkPod(configuredPod, configuredContainer)
    +   *  }
    +   * 
    + */ + def configurePod(pod: SparkPod): SparkPod + + /** + * Return any system properties that should be set on the JVM in accordance to this feature. + */ + def getAdditionalPodSystemProperties(): Map[String, String] + + /** + * Return any additional Kubernetes resources that should be added to support this feature. Only + * applicable when creating the driver in cluster mode. + */ + def getAdditionalKubernetesResources(): Seq[HasMetadata] +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala new file mode 100644 index 0000000000000..97fa9499b2edb --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class MountSecretsFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val addedVolumes = kubernetesConf + .roleSecretNamesToMountPaths + .keys + .map(secretName => + new VolumeBuilder() + .withName(secretVolumeName(secretName)) + .withNewSecret() + .withSecretName(secretName) + .endSecret() + .build()) + val podWithVolumes = new PodBuilder(pod.pod) + .editOrNewSpec() + .addToVolumes(addedVolumes.toSeq: _*) + .endSpec() + .build() + val addedVolumeMounts = kubernetesConf + .roleSecretNamesToMountPaths + .map { + case (secretName, mountPath) => + new VolumeMountBuilder() + .withName(secretVolumeName(secretName)) + .withMountPath(mountPath) + .build() + } + val containerWithMounts = new ContainerBuilder(pod.container) + .addToVolumeMounts(addedVolumeMounts.toSeq: _*) + .build() + SparkPod(podWithVolumes, containerWithMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty + + private def secretVolumeName(secretName: String): String = s"$secretName-volume" +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala deleted file mode 100644 index b4d3f04a1bc32..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.steps._ -import org.apache.spark.launcher.SparkLauncher -import org.apache.spark.util.SystemClock -import org.apache.spark.util.Utils - -/** - * Figures out and returns the complete ordered list of needed DriverConfigurationSteps to - * configure the Spark driver pod. The returned steps will be applied one by one in the given - * order to produce a final KubernetesDriverSpec that is used in KubernetesClientApplication - * to construct and create the driver pod. - */ -private[spark] class DriverConfigOrchestrator( - kubernetesAppId: String, - kubernetesResourceNamePrefix: String, - mainAppResource: Option[MainAppResource], - appName: String, - mainClass: String, - appArgs: Array[String], - sparkConf: SparkConf) { - - // The resource name prefix is derived from the Spark application name, making it easy to connect - // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the - // application the user submitted. - - private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - - def getAllConfigurationSteps: Seq[DriverConfigurationStep] = { - val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_LABEL_PREFIX) - require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + - s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + - "operations.") - require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " + - s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + - "operations.") - - val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_SECRETS_PREFIX) - - val allDriverLabels = driverCustomLabels ++ Map( - SPARK_APP_ID_LABEL -> kubernetesAppId, - SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) - - val initialSubmissionStep = new BasicDriverConfigurationStep( - kubernetesAppId, - kubernetesResourceNamePrefix, - allDriverLabels, - imagePullPolicy, - appName, - mainClass, - appArgs, - sparkConf) - - val serviceBootstrapStep = new DriverServiceBootstrapStep( - kubernetesResourceNamePrefix, - allDriverLabels, - sparkConf, - new SystemClock) - - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - sparkConf, kubernetesResourceNamePrefix) - - val additionalMainAppJar = if (mainAppResource.nonEmpty) { - val mayBeResource = mainAppResource.get match { - case JavaMainAppResource(resource) if resource != SparkLauncher.NO_RESOURCE => - Some(resource) - case _ => None - } - mayBeResource - } else { - None - } - - val sparkJars = sparkConf.getOption("spark.jars") - .map(_.split(",")) - .getOrElse(Array.empty[String]) ++ - additionalMainAppJar.toSeq - val sparkFiles = sparkConf.getOption("spark.files") - .map(_.split(",")) - .getOrElse(Array.empty[String]) - - // TODO(SPARK-23153): remove once submission client local dependencies are supported. - if (existSubmissionLocalFiles(sparkJars) || existSubmissionLocalFiles(sparkFiles)) { - throw new SparkException("The Kubernetes mode does not yet support referencing application " + - "dependencies in the local file system.") - } - - val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { - Seq(new DependencyResolutionStep( - sparkJars, - sparkFiles)) - } else { - Nil - } - - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - - Seq( - initialSubmissionStep, - serviceBootstrapStep, - kubernetesCredentialsStep) ++ - dependencyResolutionStep ++ - mountSecretsStep - } - - private def existSubmissionLocalFiles(files: Seq[String]): Boolean = { - files.exists { uri => - Utils.resolveURI(uri).getScheme == "file" - } - } - - private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { - files.exists { uri => - Utils.resolveURI(uri).getScheme != "local" - } - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index e16d1add600b2..a97f5650fb869 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -27,12 +27,10 @@ import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkApplication +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory -import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.util.Utils /** @@ -43,9 +41,9 @@ import org.apache.spark.util.Utils * @param driverArgs arguments to the driver */ private[spark] case class ClientArguments( - mainAppResource: Option[MainAppResource], - mainClass: String, - driverArgs: Array[String]) + mainAppResource: Option[MainAppResource], + mainClass: String, + driverArgs: Array[String]) private[spark] object ClientArguments { @@ -80,8 +78,9 @@ private[spark] object ClientArguments { * watcher that monitors and logs the application status. Waits for the application to terminate if * spark.kubernetes.submission.waitAppCompletion is true. * - * @param submissionSteps steps that collectively configure the driver - * @param sparkConf the submission client Spark configuration + * @param builder Responsible for building the base driver pod based on a composition of + * implemented features. + * @param kubernetesConf application configuration * @param kubernetesClient the client to talk to the Kubernetes API server * @param waitForAppCompletion a flag indicating whether the client should wait for the application * to complete @@ -89,31 +88,21 @@ private[spark] object ClientArguments { * @param watcher a watcher that monitors and logs the application status */ private[spark] class Client( - submissionSteps: Seq[DriverConfigurationStep], - sparkConf: SparkConf, + builder: KubernetesDriverBuilder, + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf], kubernetesClient: KubernetesClient, waitForAppCompletion: Boolean, appName: String, watcher: LoggingPodStatusWatcher, kubernetesResourceNamePrefix: String) extends Logging { - /** - * Run command that initializes a DriverSpec that will be updated after each - * DriverConfigurationStep in the sequence that is passed in. The final KubernetesDriverSpec - * will be used to build the Driver Container, Driver Pod, and Kubernetes Resources - */ def run(): Unit = { - var currentDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf) - // submissionSteps contain steps necessary to take, to resolve varying - // client arguments that are passed in, created by orchestrator - for (nextStep <- submissionSteps) { - currentDriverSpec = nextStep.configureDriver(currentDriverSpec) - } + val resolvedDriverSpec = builder.buildFromFeatures(kubernetesConf) val configMapName = s"$kubernetesResourceNamePrefix-driver-conf-map" - val configMap = buildConfigMap(configMapName, currentDriverSpec.driverSparkConf) + val configMap = buildConfigMap(configMapName, resolvedDriverSpec.systemProperties) // The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the // Spark command builder to pickup on the Java Options present in the ConfigMap - val resolvedDriverContainer = new ContainerBuilder(currentDriverSpec.driverContainer) + val resolvedDriverContainer = new ContainerBuilder(resolvedDriverSpec.pod.container) .addNewEnv() .withName(ENV_SPARK_CONF_DIR) .withValue(SPARK_CONF_DIR_INTERNAL) @@ -123,7 +112,7 @@ private[spark] class Client( .withMountPath(SPARK_CONF_DIR_INTERNAL) .endVolumeMount() .build() - val resolvedDriverPod = new PodBuilder(currentDriverSpec.driverPod) + val resolvedDriverPod = new PodBuilder(resolvedDriverSpec.pod.pod) .editSpec() .addToContainers(resolvedDriverContainer) .addNewVolume() @@ -141,12 +130,10 @@ private[spark] class Client( .watch(watcher)) { _ => val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod) try { - if (currentDriverSpec.otherKubernetesResources.nonEmpty) { - val otherKubernetesResources = - currentDriverSpec.otherKubernetesResources ++ Seq(configMap) - addDriverOwnerReference(createdDriverPod, otherKubernetesResources) - kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() - } + val otherKubernetesResources = + resolvedDriverSpec.driverKubernetesResources ++ Seq(configMap) + addDriverOwnerReference(createdDriverPod, otherKubernetesResources) + kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() } catch { case NonFatal(e) => kubernetesClient.pods().delete(createdDriverPod) @@ -180,20 +167,17 @@ private[spark] class Client( } // Build a Config Map that will house spark conf properties in a single file for spark-submit - private def buildConfigMap(configMapName: String, conf: SparkConf): ConfigMap = { + private def buildConfigMap(configMapName: String, conf: Map[String, String]): ConfigMap = { val properties = new Properties() - conf.getAll.foreach { case (k, v) => + conf.foreach { case (k, v) => properties.setProperty(k, v) } val propertiesWriter = new StringWriter() properties.store(propertiesWriter, s"Java properties built from Kubernetes config map with name: $configMapName") - - val namespace = conf.get(KUBERNETES_NAMESPACE) new ConfigMapBuilder() .withNewMetadata() .withName(configMapName) - .withNamespace(namespace) .endMetadata() .addToData(SPARK_CONF_FILE_NAME, propertiesWriter.toString) .build() @@ -211,7 +195,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { } private def run(clientArguments: ClientArguments, sparkConf: SparkConf): Unit = { - val namespace = sparkConf.get(KUBERNETES_NAMESPACE) + val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") // For constructing the app ID, we can't use the Spark application name, as the app ID is going // to be added as a label to group resources belonging to the same application. Label values are // considerably restrictive, e.g. must be no longer than 63 characters in length. So we generate @@ -219,10 +203,19 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val kubernetesAppId = s"spark-${UUID.randomUUID().toString.replaceAll("-", "")}" val launchTime = System.currentTimeMillis() val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) - val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") val kubernetesResourceNamePrefix = { s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") } + val kubernetesConf = KubernetesConf.createDriverConf( + sparkConf, + appName, + kubernetesResourceNamePrefix, + kubernetesAppId, + clientArguments.mainAppResource, + clientArguments.mainClass, + clientArguments.driverArgs) + val builder = new KubernetesDriverBuilder + val namespace = kubernetesConf.namespace() // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. val master = sparkConf.get("spark.master").substring("k8s://".length) @@ -230,15 +223,6 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val watcher = new LoggingPodStatusWatcherImpl(kubernetesAppId, loggingInterval) - val orchestrator = new DriverConfigOrchestrator( - kubernetesAppId, - kubernetesResourceNamePrefix, - clientArguments.mainAppResource, - appName, - clientArguments.mainClass, - clientArguments.driverArgs, - sparkConf) - Utils.tryWithResource(SparkKubernetesClientFactory.createKubernetesClient( master, Some(namespace), @@ -247,8 +231,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { None, None)) { kubernetesClient => val client = new Client( - orchestrator.getAllConfigurationSteps, - sparkConf, + builder, + kubernetesConf, kubernetesClient, waitForAppCompletion, appName, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala new file mode 100644 index 0000000000000..c7579ed8cb689 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, MountSecretsFeatureStep} + +private[spark] class KubernetesDriverBuilder( + provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = + new BasicDriverFeatureStep(_), + provideCredentialsStep: (KubernetesConf[KubernetesDriverSpecificConf]) + => DriverKubernetesCredentialsFeatureStep = + new DriverKubernetesCredentialsFeatureStep(_), + provideServiceStep: (KubernetesConf[KubernetesDriverSpecificConf]) => DriverServiceFeatureStep = + new DriverServiceFeatureStep(_), + provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountSecretsFeatureStep) = + new MountSecretsFeatureStep(_)) { + + def buildFromFeatures( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { + val baseFeatures = Seq( + provideBasicStep(kubernetesConf), + provideCredentialsStep(kubernetesConf), + provideServiceStep(kubernetesConf)) + val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) + } else baseFeatures + + var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) + for (feature <- allFeatures) { + val configuredPod = feature.configurePod(spec.pod) + val addedSystemProperties = feature.getAdditionalPodSystemProperties() + val addedResources = feature.getAdditionalKubernetesResources() + spec = KubernetesDriverSpec( + configuredPod, + spec.driverKubernetesResources ++ addedResources, + spec.systemProperties ++ addedSystemProperties) + } + spec + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala deleted file mode 100644 index db13f09387ef9..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, HasMetadata, Pod, PodBuilder} - -import org.apache.spark.SparkConf - -/** - * Represents the components and characteristics of a Spark driver. The driver can be considered - * as being comprised of the driver pod itself, any other Kubernetes resources that the driver - * pod depends on, and the SparkConf that should be supplied to the Spark application. The driver - * container should be operated on via the specific field of this case class as opposed to trying - * to edit the container directly on the pod. The driver container should be attached at the - * end of executing all submission steps. - */ -private[spark] case class KubernetesDriverSpec( - driverPod: Pod, - driverContainer: Container, - otherKubernetesResources: Seq[HasMetadata], - driverSparkConf: SparkConf) - -private[spark] object KubernetesDriverSpec { - def initialSpec(initialSparkConf: SparkConf): KubernetesDriverSpec = { - KubernetesDriverSpec( - // Set new metadata and a new spec so that submission steps can use - // PodBuilder#editMetadata() and/or PodBuilder#editSpec() safely. - new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), - new ContainerBuilder().build(), - Seq.empty[HasMetadata], - initialSparkConf.clone()) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala deleted file mode 100644 index fcb1db8008053..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD} -import org.apache.spark.launcher.SparkLauncher - -/** - * Performs basic configuration for the driver pod. - */ -private[spark] class BasicDriverConfigurationStep( - kubernetesAppId: String, - resourceNamePrefix: String, - driverLabels: Map[String, String], - imagePullPolicy: String, - appName: String, - mainClass: String, - appArgs: Array[String], - sparkConf: SparkConf) extends DriverConfigurationStep { - - private val driverPodName = sparkConf - .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(s"$resourceNamePrefix-driver") - - private val driverExtraClasspath = sparkConf.get(DRIVER_CLASS_PATH) - - private val driverContainerImage = sparkConf - .get(DRIVER_CONTAINER_IMAGE) - .getOrElse(throw new SparkException("Must specify the driver container image")) - - private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) - - // CPU settings - private val driverCpuCores = sparkConf.getOption("spark.driver.cores").getOrElse("1") - private val driverLimitCores = sparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) - - // Memory settings - private val driverMemoryMiB = sparkConf.get(DRIVER_MEMORY) - private val memoryOverheadMiB = sparkConf - .get(DRIVER_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) - private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val driverExtraClasspathEnv = driverExtraClasspath.map { classPath => - new EnvVarBuilder() - .withName(ENV_CLASSPATH) - .withValue(classPath) - .build() - } - - val driverCustomAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) - require(!driverCustomAnnotations.contains(SPARK_APP_NAME_ANNOTATION), - s"Annotation with key $SPARK_APP_NAME_ANNOTATION is not allowed as it is reserved for" + - " Spark bookkeeping operations.") - - val driverCustomEnvs = sparkConf.getAllWithPrefix(KUBERNETES_DRIVER_ENV_KEY).toSeq - .map { env => - new EnvVarBuilder() - .withName(env._1) - .withValue(env._2) - .build() - } - - val driverAnnotations = driverCustomAnnotations ++ Map(SPARK_APP_NAME_ANNOTATION -> appName) - - val nodeSelector = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) - - val driverCpuQuantity = new QuantityBuilder(false) - .withAmount(driverCpuCores) - .build() - val driverMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${driverMemoryWithOverheadMiB}Mi") - .build() - val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => - ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) - } - - val driverContainerWithoutArgs = new ContainerBuilder(driverSpec.driverContainer) - .withName(DRIVER_CONTAINER_NAME) - .withImage(driverContainerImage) - .withImagePullPolicy(imagePullPolicy) - .addAllToEnv(driverCustomEnvs.asJava) - .addToEnv(driverExtraClasspathEnv.toSeq: _*) - .addNewEnv() - .withName(ENV_DRIVER_BIND_ADDRESS) - .withValueFrom(new EnvVarSourceBuilder() - .withNewFieldRef("v1", "status.podIP") - .build()) - .endEnv() - .withNewResources() - .addToRequests("cpu", driverCpuQuantity) - .addToRequests("memory", driverMemoryQuantity) - .addToLimits("memory", driverMemoryQuantity) - .addToLimits(maybeCpuLimitQuantity.toMap.asJava) - .endResources() - .addToArgs("driver") - .addToArgs("--properties-file", SPARK_CONF_PATH) - .addToArgs("--class", mainClass) - // The user application jar is merged into the spark.jars list and managed through that - // property, so there is no need to reference it explicitly here. - .addToArgs(SparkLauncher.NO_RESOURCE) - - val driverContainer = appArgs.toList match { - case "" :: Nil | Nil => driverContainerWithoutArgs.build() - case _ => driverContainerWithoutArgs.addToArgs(appArgs: _*).build() - } - - val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) - - val baseDriverPod = new PodBuilder(driverSpec.driverPod) - .editOrNewMetadata() - .withName(driverPodName) - .addToLabels(driverLabels.asJava) - .addToAnnotations(driverAnnotations.asJava) - .endMetadata() - .withNewSpec() - .withRestartPolicy("Never") - .withNodeSelector(nodeSelector.asJava) - .withImagePullSecrets(parsedImagePullSecrets.asJava) - .endSpec() - .build() - - val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .setIfMissing(KUBERNETES_DRIVER_POD_NAME, driverPodName) - .set("spark.app.id", kubernetesAppId) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, resourceNamePrefix) - // to set the config variables to allow client-mode spark-submit from driver - .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) - - driverSpec.copy( - driverPod = baseDriverPod, - driverSparkConf = resolvedSparkConf, - driverContainer = driverContainer) - } - -} - diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala deleted file mode 100644 index 43de329f239ad..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File - -import io.fabric8.kubernetes.api.model.ContainerBuilder - -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * Step that configures the classpath, spark.jars, and spark.files for the driver given that the - * user may provide remote files or files with local:// schemes. - */ -private[spark] class DependencyResolutionStep( - sparkJars: Seq[String], - sparkFiles: Seq[String]) extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath(sparkJars) - val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath(sparkFiles) - - val sparkConf = driverSpec.driverSparkConf.clone() - if (resolvedSparkJars.nonEmpty) { - sparkConf.set("spark.jars", resolvedSparkJars.mkString(",")) - } - if (resolvedSparkFiles.nonEmpty) { - sparkConf.set("spark.files", resolvedSparkFiles.mkString(",")) - } - val resolvedDriverContainer = if (resolvedSparkJars.nonEmpty) { - new ContainerBuilder(driverSpec.driverContainer) - .addNewEnv() - .withName(ENV_MOUNTED_CLASSPATH) - .withValue(resolvedSparkJars.mkString(File.pathSeparator)) - .endEnv() - .build() - } else { - driverSpec.driverContainer - } - - driverSpec.copy( - driverContainer = resolvedDriverContainer, - driverSparkConf = sparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala deleted file mode 100644 index 2424e63999a82..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala +++ /dev/null @@ -1,245 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File -import java.nio.charset.StandardCharsets - -import scala.collection.JavaConverters._ -import scala.language.implicitConversions - -import com.google.common.io.{BaseEncoding, Files} -import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder, Secret, SecretBuilder} - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * Mounts Kubernetes credentials into the driver pod. The driver will use such mounted credentials - * to request executors. - */ -private[spark] class DriverKubernetesCredentialsStep( - submissionSparkConf: SparkConf, - kubernetesResourceNamePrefix: String) extends DriverConfigurationStep { - - private val maybeMountedOAuthTokenFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX") - private val maybeMountedClientKeyFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX") - private val maybeMountedClientCertFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX") - private val maybeMountedCaCertFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX") - private val driverServiceAccount = submissionSparkConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME) - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val driverSparkConf = driverSpec.driverSparkConf.clone() - - val oauthTokenBase64 = submissionSparkConf - .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX") - .map { token => - BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8)) - } - val caCertDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", - "Driver CA cert file") - val clientKeyDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", - "Driver client key file") - val clientCertDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", - "Driver client cert file") - - val driverSparkConfWithCredentialsLocations = setDriverPodKubernetesCredentialLocations( - driverSparkConf, - oauthTokenBase64, - caCertDataBase64, - clientKeyDataBase64, - clientCertDataBase64) - - val kubernetesCredentialsSecret = createCredentialsSecret( - oauthTokenBase64, - caCertDataBase64, - clientKeyDataBase64, - clientCertDataBase64) - - val driverPodWithMountedKubernetesCredentials = kubernetesCredentialsSecret.map { secret => - new PodBuilder(driverSpec.driverPod) - .editOrNewSpec() - .addNewVolume() - .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) - .withNewSecret().withSecretName(secret.getMetadata.getName).endSecret() - .endVolume() - .endSpec() - .build() - }.getOrElse( - driverServiceAccount.map { account => - new PodBuilder(driverSpec.driverPod) - .editOrNewSpec() - .withServiceAccount(account) - .withServiceAccountName(account) - .endSpec() - .build() - }.getOrElse(driverSpec.driverPod) - ) - - val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { _ => - new ContainerBuilder(driverSpec.driverContainer) - .addNewVolumeMount() - .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) - .withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR) - .endVolumeMount() - .build() - }.getOrElse(driverSpec.driverContainer) - - driverSpec.copy( - driverPod = driverPodWithMountedKubernetesCredentials, - otherKubernetesResources = - driverSpec.otherKubernetesResources ++ kubernetesCredentialsSecret.toSeq, - driverSparkConf = driverSparkConfWithCredentialsLocations, - driverContainer = driverContainerWithMountedSecretVolume) - } - - private def createCredentialsSecret( - driverOAuthTokenBase64: Option[String], - driverCaCertDataBase64: Option[String], - driverClientKeyDataBase64: Option[String], - driverClientCertDataBase64: Option[String]): Option[Secret] = { - val allSecretData = - resolveSecretData( - driverClientKeyDataBase64, - DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++ - resolveSecretData( - driverClientCertDataBase64, - DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++ - resolveSecretData( - driverCaCertDataBase64, - DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++ - resolveSecretData( - driverOAuthTokenBase64, - DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME) - - if (allSecretData.isEmpty) { - None - } else { - Some(new SecretBuilder() - .withNewMetadata() - .withName(s"$kubernetesResourceNamePrefix-kubernetes-credentials") - .endMetadata() - .withData(allSecretData.asJava) - .build()) - } - } - - private def setDriverPodKubernetesCredentialLocations( - driverSparkConf: SparkConf, - driverOauthTokenBase64: Option[String], - driverCaCertDataBase64: Option[String], - driverClientKeyDataBase64: Option[String], - driverClientCertDataBase64: Option[String]): SparkConf = { - val resolvedMountedOAuthTokenFile = resolveSecretLocation( - maybeMountedOAuthTokenFile, - driverOauthTokenBase64, - DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH) - val resolvedMountedClientKeyFile = resolveSecretLocation( - maybeMountedClientKeyFile, - driverClientKeyDataBase64, - DRIVER_CREDENTIALS_CLIENT_KEY_PATH) - val resolvedMountedClientCertFile = resolveSecretLocation( - maybeMountedClientCertFile, - driverClientCertDataBase64, - DRIVER_CREDENTIALS_CLIENT_CERT_PATH) - val resolvedMountedCaCertFile = resolveSecretLocation( - maybeMountedCaCertFile, - driverCaCertDataBase64, - DRIVER_CREDENTIALS_CA_CERT_PATH) - - val sparkConfWithCredentialLocations = driverSparkConf - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", - resolvedMountedCaCertFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", - resolvedMountedClientKeyFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", - resolvedMountedClientCertFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX", - resolvedMountedOAuthTokenFile) - - // Redact all OAuth token values - sparkConfWithCredentialLocations - .getAll - .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)).map(_._1) - .foreach { - sparkConfWithCredentialLocations.set(_, "") - } - sparkConfWithCredentialLocations - } - - private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = { - submissionSparkConf.getOption(conf) - .map(new File(_)) - .map { file => - require(file.isFile, String.format("%s provided at %s does not exist or is not a file.", - fileType, file.getAbsolutePath)) - BaseEncoding.base64().encode(Files.toByteArray(file)) - } - } - - private def resolveSecretLocation( - mountedUserSpecified: Option[String], - valueMountedFromSubmitter: Option[String], - mountedCanonicalLocation: String): Option[String] = { - mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ => - mountedCanonicalLocation - }) - } - - /** - * Resolve a Kubernetes secret data entry from an optional client credential used by the - * driver to talk to the Kubernetes API server. - * - * @param userSpecifiedCredential the optional user-specified client credential. - * @param secretName name of the Kubernetes secret storing the client credential. - * @return a secret data entry in the form of a map from the secret name to the secret data, - * which may be empty if the user-specified credential is empty. - */ - private def resolveSecretData( - userSpecifiedCredential: Option[String], - secretName: String): Map[String, String] = { - userSpecifiedCredential.map { valueBase64 => - Map(secretName -> valueBase64) - }.getOrElse(Map.empty[String, String]) - } - - private implicit def augmentSparkConf(sparkConf: SparkConf): OptionSettableSparkConf = { - new OptionSettableSparkConf(sparkConf) - } -} - -private class OptionSettableSparkConf(sparkConf: SparkConf) { - def setOption(configEntry: String, option: Option[String]): SparkConf = { - option.foreach { opt => - sparkConf.set(configEntry, opt) - } - sparkConf - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala deleted file mode 100644 index 91e9a9f211335..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * A driver configuration step for mounting user-specified secrets onto user-specified paths. - * - * @param bootstrap a utility actually handling mounting of the secrets. - */ -private[spark] class DriverMountSecretsStep( - bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val pod = bootstrap.addSecretVolumes(driverSpec.driverPod) - val container = bootstrap.mountSecrets(driverSpec.driverContainer) - driverSpec.copy( - driverPod = pod, - driverContainer = container - ) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala deleted file mode 100644 index 34af7cde6c1a9..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.ServiceBuilder - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.internal.Logging -import org.apache.spark.util.Clock - -/** - * Allows the driver to be reachable by executor pods through a headless service. The service's - * ports should correspond to the ports that the executor will reach the pod at for RPC. - */ -private[spark] class DriverServiceBootstrapStep( - resourceNamePrefix: String, - driverLabels: Map[String, String], - sparkConf: SparkConf, - clock: Clock) extends DriverConfigurationStep with Logging { - - import DriverServiceBootstrapStep._ - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - require(sparkConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, - s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + - "address is managed and set to the driver pod's IP address.") - require(sparkConf.getOption(DRIVER_HOST_KEY).isEmpty, - s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + - "managed via a Kubernetes service.") - - val preferredServiceName = s"$resourceNamePrefix$DRIVER_SVC_POSTFIX" - val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { - preferredServiceName - } else { - val randomServiceId = clock.getTimeMillis() - val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX" - logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " + - s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). Falling back to use " + - s"$shorterServiceName as the driver service's name.") - shorterServiceName - } - - val driverPort = sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) - val driverBlockManagerPort = sparkConf.getInt( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) - val driverService = new ServiceBuilder() - .withNewMetadata() - .withName(resolvedServiceName) - .endMetadata() - .withNewSpec() - .withClusterIP("None") - .withSelector(driverLabels.asJava) - .addNewPort() - .withName(DRIVER_PORT_NAME) - .withPort(driverPort) - .withNewTargetPort(driverPort) - .endPort() - .addNewPort() - .withName(BLOCK_MANAGER_PORT_NAME) - .withPort(driverBlockManagerPort) - .withNewTargetPort(driverBlockManagerPort) - .endPort() - .endSpec() - .build() - - val namespace = sparkConf.get(KUBERNETES_NAMESPACE) - val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc" - val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .set(DRIVER_HOST_KEY, driverHostname) - .set("spark.driver.port", driverPort.toString) - .set( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, driverBlockManagerPort) - - driverSpec.copy( - driverSparkConf = resolvedSparkConf, - otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(driverService)) - } -} - -private[spark] object DriverServiceBootstrapStep { - val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key - val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key - val DRIVER_SVC_POSTFIX = "-driver-svc" - val MAX_SERVICE_NAME_LENGTH = 63 -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala deleted file mode 100644 index 8607d6fba3234..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.scheduler.cluster.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} -import org.apache.spark.util.Utils - -/** - * A factory class for bootstrapping and creating executor pods with the given bootstrapping - * components. - * - * @param sparkConf Spark configuration - * @param mountSecretsBootstrap an optional component for mounting user-specified secrets onto - * user-specified paths into the executor container - */ -private[spark] class ExecutorPodFactory( - sparkConf: SparkConf, - mountSecretsBootstrap: Option[MountSecretsBootstrap]) { - - private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH) - - private val executorLabels = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_EXECUTOR_LABEL_PREFIX) - require( - !executorLabels.contains(SPARK_APP_ID_LABEL), - s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") - require( - !executorLabels.contains(SPARK_EXECUTOR_ID_LABEL), - s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + - " Spark.") - require( - !executorLabels.contains(SPARK_ROLE_LABEL), - s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") - - private val executorAnnotations = - KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) - private val nodeSelector = - KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_NODE_SELECTOR_PREFIX) - - private val executorContainerImage = sparkConf - .get(EXECUTOR_CONTAINER_IMAGE) - .getOrElse(throw new SparkException("Must specify the executor container image")) - private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) - private val blockManagerPort = sparkConf - .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) - - private val executorPodNamePrefix = sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) - - private val executorMemoryMiB = sparkConf.get(EXECUTOR_MEMORY) - private val executorMemoryString = sparkConf.get( - EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString) - - private val memoryOverheadMiB = sparkConf - .get(EXECUTOR_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, - MEMORY_OVERHEAD_MIN_MIB)) - private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB - - private val executorCores = sparkConf.getInt("spark.executor.cores", 1) - private val executorCoresRequest = if (sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { - sparkConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get - } else { - executorCores.toString - } - private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) - - /** - * Configure and construct an executor pod with the given parameters. - */ - def createExecutorPod( - executorId: String, - applicationId: String, - driverUrl: String, - executorEnvs: Seq[(String, String)], - driverPod: Pod, - nodeToLocalTaskCount: Map[String, Int]): Pod = { - val name = s"$executorPodNamePrefix-exec-$executorId" - - val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) - - // hostname must be no longer than 63 characters, so take the last 63 characters of the pod - // name as the hostname. This preserves uniqueness since the end of name contains - // executorId - val hostname = name.substring(Math.max(0, name.length - 63)) - val resolvedExecutorLabels = Map( - SPARK_EXECUTOR_ID_LABEL -> executorId, - SPARK_APP_ID_LABEL -> applicationId, - SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ - executorLabels - val executorMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${executorMemoryWithOverhead}Mi") - .build() - val executorCpuQuantity = new QuantityBuilder(false) - .withAmount(executorCoresRequest) - .build() - val executorExtraClasspathEnv = executorExtraClasspath.map { cp => - new EnvVarBuilder() - .withName(ENV_CLASSPATH) - .withValue(cp) - .build() - } - val executorExtraJavaOptionsEnv = sparkConf - .get(EXECUTOR_JAVA_OPTIONS) - .map { opts => - val delimitedOpts = Utils.splitCommandString(opts) - delimitedOpts.zipWithIndex.map { - case (opt, index) => - new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() - } - }.getOrElse(Seq.empty[EnvVar]) - val executorEnv = (Seq( - (ENV_DRIVER_URL, driverUrl), - (ENV_EXECUTOR_CORES, executorCores.toString), - (ENV_EXECUTOR_MEMORY, executorMemoryString), - (ENV_APPLICATION_ID, applicationId), - // This is to set the SPARK_CONF_DIR to be /opt/spark/conf - (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), - (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) - .map(env => new EnvVarBuilder() - .withName(env._1) - .withValue(env._2) - .build() - ) ++ Seq( - new EnvVarBuilder() - .withName(ENV_EXECUTOR_POD_IP) - .withValueFrom(new EnvVarSourceBuilder() - .withNewFieldRef("v1", "status.podIP") - .build()) - .build() - ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq - val requiredPorts = Seq( - (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) - .map { case (name, port) => - new ContainerPortBuilder() - .withName(name) - .withContainerPort(port) - .build() - } - - val executorContainer = new ContainerBuilder() - .withName("executor") - .withImage(executorContainerImage) - .withImagePullPolicy(imagePullPolicy) - .withNewResources() - .addToRequests("memory", executorMemoryQuantity) - .addToLimits("memory", executorMemoryQuantity) - .addToRequests("cpu", executorCpuQuantity) - .endResources() - .addAllToEnv(executorEnv.asJava) - .withPorts(requiredPorts.asJava) - .addToArgs("executor") - .build() - - val executorPod = new PodBuilder() - .withNewMetadata() - .withName(name) - .withLabels(resolvedExecutorLabels.asJava) - .withAnnotations(executorAnnotations.asJava) - .withOwnerReferences() - .addNewOwnerReference() - .withController(true) - .withApiVersion(driverPod.getApiVersion) - .withKind(driverPod.getKind) - .withName(driverPod.getMetadata.getName) - .withUid(driverPod.getMetadata.getUid) - .endOwnerReference() - .endMetadata() - .withNewSpec() - .withHostname(hostname) - .withRestartPolicy("Never") - .withNodeSelector(nodeSelector.asJava) - .withImagePullSecrets(parsedImagePullSecrets.asJava) - .endSpec() - .build() - - val containerWithLimitCores = executorLimitCores.map { limitCores => - val executorCpuLimitQuantity = new QuantityBuilder(false) - .withAmount(limitCores) - .build() - new ContainerBuilder(executorContainer) - .editResources() - .addToLimits("cpu", executorCpuLimitQuantity) - .endResources() - .build() - }.getOrElse(executorContainer) - - val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = - mountSecretsBootstrap.map { bootstrap => - (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) - }.getOrElse((executorPod, containerWithLimitCores)) - - - new PodBuilder(maybeSecretsMountedPod) - .editSpec() - .addToContainers(maybeSecretsMountedContainer) - .endSpec() - .build() - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index ff5f6801da2a3..0ea80dfbc0d97 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -21,7 +21,7 @@ import java.io.File import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} +import org.apache.spark.deploy.k8s.{KubernetesUtils, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging @@ -48,12 +48,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit scheduler: TaskScheduler): SchedulerBackend = { val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) - val mountSecretBootstrap = if (executorSecretNamesToMountPaths.nonEmpty) { - Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) - } else { - None - } - val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, Some(sc.conf.get(KUBERNETES_NAMESPACE)), @@ -62,8 +56,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val executorPodFactory = new ExecutorPodFactory(sc.conf, mountSecretBootstrap) - val allocatorExecutor = ThreadUtils .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( @@ -71,7 +63,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit new KubernetesClusterSchedulerBackend( scheduler.asInstanceOf[TaskSchedulerImpl], sc.env.rpcEnv, - executorPodFactory, + new KubernetesExecutorBuilder, kubernetesClient, allocatorExecutor, requestExecutorsService) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 9de4b16c30d3c..d86664c81071b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -32,6 +32,7 @@ import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.SparkException import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesConf import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} @@ -40,7 +41,7 @@ import org.apache.spark.util.Utils private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, rpcEnv: RpcEnv, - executorPodFactory: ExecutorPodFactory, + executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, allocatorExecutor: ScheduledExecutorService, requestExecutorsService: ExecutorService) @@ -115,14 +116,19 @@ private[spark] class KubernetesClusterSchedulerBackend( for (_ <- 0 until math.min( currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString - val executorPod = executorPodFactory.createExecutorPod( + val executorConf = KubernetesConf.createExecutorConf( + conf, executorId, applicationId(), - driverUrl, - conf.getExecutorEnv, - driverPod, - currentNodeToLocalTaskCount) - executorsToAllocate(executorId) = executorPod + driverPod) + val executorPod = executorBuilder.buildFromFeatures(executorConf) + val podWithAttachedContainer = new PodBuilder(executorPod.pod) + .editOrNewSpec() + .addToContainers(executorPod.container) + .endSpec() + .build() + + executorsToAllocate(executorId) = podWithAttachedContainer logInfo( s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala new file mode 100644 index 0000000000000..22568fe7ea3be --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, MountSecretsFeatureStep} + +private[spark] class KubernetesExecutorBuilder( + provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = + new BasicExecutorFeatureStep(_), + provideSecretsStep: + (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = + new MountSecretsFeatureStep(_)) { + + def buildFromFeatures( + kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { + val baseFeatures = Seq(provideBasicStep(kubernetesConf)) + val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) + } else baseFeatures + var executorPod = SparkPod.initialPod() + for (feature <- allFeatures) { + executorPod = feature.configurePod(executorPod) + } + executorPod + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala new file mode 100644 index 0000000000000..f10202f7a3546 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{LocalObjectReferenceBuilder, PodBuilder} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource + +class KubernetesConfSuite extends SparkFunSuite { + + private val APP_NAME = "test-app" + private val RESOURCE_NAME_PREFIX = "prefix" + private val APP_ID = "test-id" + private val MAIN_CLASS = "test-class" + private val APP_ARGS = Array("arg1", "arg2") + private val CUSTOM_LABELS = Map( + "customLabel1Key" -> "customLabel1Value", + "customLabel2Key" -> "customLabel2Value") + private val CUSTOM_ANNOTATIONS = Map( + "customAnnotation1Key" -> "customAnnotation1Value", + "customAnnotation2Key" -> "customAnnotation2Value") + private val SECRET_NAMES_TO_MOUNT_PATHS = Map( + "secret1" -> "/mnt/secrets/secret1", + "secret2" -> "/mnt/secrets/secret2") + private val CUSTOM_ENVS = Map( + "customEnvKey1" -> "customEnvValue1", + "customEnvKey2" -> "customEnvValue2") + private val DRIVER_POD = new PodBuilder().build() + private val EXECUTOR_ID = "executor-id" + + test("Basic driver translated fields.") { + val sparkConf = new SparkConf(false) + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(conf.appId === APP_ID) + assert(conf.sparkConf.getAll.toMap === sparkConf.getAll.toMap) + assert(conf.appResourceNamePrefix === RESOURCE_NAME_PREFIX) + assert(conf.roleSpecificConf.appName === APP_NAME) + assert(conf.roleSpecificConf.mainAppResource.isEmpty) + assert(conf.roleSpecificConf.mainClass === MAIN_CLASS) + assert(conf.roleSpecificConf.appArgs === APP_ARGS) + } + + test("Creating driver conf with and without the main app jar influences spark.jars") { + val sparkConf = new SparkConf(false) + .setJars(Seq("local:///opt/spark/jar1.jar")) + val mainAppJar = Some(JavaMainAppResource("local:///opt/spark/main.jar")) + val kubernetesConfWithMainJar = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppJar, + MAIN_CLASS, + APP_ARGS) + assert(kubernetesConfWithMainJar.sparkConf.get("spark.jars") + .split(",") + === Array("local:///opt/spark/jar1.jar", "local:///opt/spark/main.jar")) + val kubernetesConfWithoutMainJar = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(kubernetesConfWithoutMainJar.sparkConf.get("spark.jars").split(",") + === Array("local:///opt/spark/jar1.jar")) + } + + test("Resolve driver labels, annotations, secret mount paths, and envs.") { + val sparkConf = new SparkConf(false) + CUSTOM_LABELS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$key", value) + } + CUSTOM_ANNOTATIONS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$key", value) + } + SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$key", value) + } + CUSTOM_ENVS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_ENV_PREFIX$key", value) + } + + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(conf.roleLabels === Map( + SPARK_APP_ID_LABEL -> APP_ID, + SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) ++ + CUSTOM_LABELS) + assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) + assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleEnvs === CUSTOM_ENVS) + } + + test("Basic executor translated fields.") { + val conf = KubernetesConf.createExecutorConf( + new SparkConf(false), + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.roleSpecificConf.executorId === EXECUTOR_ID) + assert(conf.roleSpecificConf.driverPod === DRIVER_POD) + } + + test("Image pull secrets.") { + val conf = KubernetesConf.createExecutorConf( + new SparkConf(false) + .set(IMAGE_PULL_SECRETS, "my-secret-1,my-secret-2 "), + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.imagePullSecrets() === + Seq( + new LocalObjectReferenceBuilder().withName("my-secret-1").build(), + new LocalObjectReferenceBuilder().withName("my-secret-2").build())) + } + + test("Set executor labels, annotations, and secrets") { + val sparkConf = new SparkConf(false) + CUSTOM_LABELS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_LABEL_PREFIX$key", value) + } + CUSTOM_ANNOTATIONS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_ANNOTATION_PREFIX$key", value) + } + SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRETS_PREFIX$key", value) + } + + val conf = KubernetesConf.createExecutorConf( + sparkConf, + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.roleLabels === Map( + SPARK_EXECUTOR_ID_LABEL -> EXECUTOR_ID, + SPARK_APP_ID_LABEL -> APP_ID, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ CUSTOM_LABELS) + assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) + assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..eee85b8baa730 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.LocalObjectReferenceBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +class BasicDriverFeatureStepSuite extends SparkFunSuite { + + private val APP_ID = "spark-app-id" + private val RESOURCE_NAME_PREFIX = "spark" + private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") + private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" + private val APP_NAME = "spark-test" + private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" + private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") + private val CUSTOM_ANNOTATION_KEY = "customAnnotation" + private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" + private val DRIVER_ANNOTATIONS = Map(CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE) + private val DRIVER_CUSTOM_ENV1 = "customDriverEnv1" + private val DRIVER_CUSTOM_ENV2 = "customDriverEnv2" + private val DRIVER_ENVS = Map( + DRIVER_CUSTOM_ENV1 -> DRIVER_CUSTOM_ENV1, + DRIVER_CUSTOM_ENV2 -> DRIVER_CUSTOM_ENV2) + private val TEST_IMAGE_PULL_SECRETS = Seq("my-secret-1", "my-secret-2") + private val TEST_IMAGE_PULL_SECRET_OBJECTS = + TEST_IMAGE_PULL_SECRETS.map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + + test("Check the pod respects all configurations from the user.") { + val sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") + .set("spark.driver.cores", "2") + .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") + .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) + .set(CONTAINER_IMAGE, "spark-driver:latest") + .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + DRIVER_ENVS) + + val featureStep = new BasicDriverFeatureStep(kubernetesConf) + val basePod = SparkPod.initialPod() + val configuredPod = featureStep.configurePod(basePod) + + assert(configuredPod.container.getName === DRIVER_CONTAINER_NAME) + assert(configuredPod.container.getImage === "spark-driver:latest") + assert(configuredPod.container.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) + + assert(configuredPod.container.getEnv.size === 3) + val envs = configuredPod.container + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(DRIVER_CUSTOM_ENV1) === DRIVER_ENVS(DRIVER_CUSTOM_ENV1)) + assert(envs(DRIVER_CUSTOM_ENV2) === DRIVER_ENVS(DRIVER_CUSTOM_ENV2)) + + assert(configuredPod.pod.getSpec().getImagePullSecrets.asScala === + TEST_IMAGE_PULL_SECRET_OBJECTS) + + assert(configuredPod.container.getEnv.asScala.exists(envVar => + envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) && + envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") && + envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP"))) + + val resourceRequirements = configuredPod.container.getResources + val requests = resourceRequirements.getRequests.asScala + assert(requests("cpu").getAmount === "2") + assert(requests("memory").getAmount === "456Mi") + val limits = resourceRequirements.getLimits.asScala + assert(limits("memory").getAmount === "456Mi") + assert(limits("cpu").getAmount === "4") + + val driverPodMetadata = configuredPod.pod.getMetadata + assert(driverPodMetadata.getName === "spark-driver-pod") + assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) + assert(driverPodMetadata.getAnnotations.asScala === DRIVER_ANNOTATIONS) + assert(configuredPod.pod.getSpec.getRestartPolicy === "Never") + + val expectedSparkConf = Map( + KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", + "spark.app.id" -> APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true") + assert(featureStep.getAdditionalPodSystemProperties() === expectedSparkConf) + } + + test("Additional system properties resolve jars and set cluster-mode confs.") { + val allJars = Seq("local:///opt/spark/jar1.jar", "hdfs:///opt/spark/jar2.jar") + val allFiles = Seq("https://localhost:9000/file1.txt", "local:///opt/spark/file2.txt") + val sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") + .setJars(allJars) + .set("spark.files", allFiles.mkString(",")) + .set(CONTAINER_IMAGE, "spark-driver:latest") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty) + val step = new BasicDriverFeatureStep(kubernetesConf) + val additionalProperties = step.getAdditionalPodSystemProperties() + val expectedSparkConf = Map( + KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", + "spark.app.id" -> APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true", + "spark.jars" -> "/opt/spark/jar1.jar,hdfs:///opt/spark/jar2.jar", + "spark.files" -> "https://localhost:9000/file1.txt,/opt/spark/file2.txt") + assert(additionalProperties === expectedSparkConf) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala new file mode 100644 index 0000000000000..a764f7630b5c8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model._ +import org.mockito.MockitoAnnotations +import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.rpc.RpcEndpointAddress +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend + +class BasicExecutorFeatureStepSuite + extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { + + private val APP_ID = "app-id" + private val DRIVER_HOSTNAME = "localhost" + private val DRIVER_PORT = 7098 + private val DRIVER_ADDRESS = RpcEndpointAddress( + DRIVER_HOSTNAME, + DRIVER_PORT.toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + private val DRIVER_POD_NAME = "driver-pod" + + private val DRIVER_POD_UID = "driver-uid" + private val RESOURCE_NAME_PREFIX = "base" + private val EXECUTOR_IMAGE = "executor-image" + private val LABELS = Map("label1key" -> "label1value") + private val ANNOTATIONS = Map("annotation1key" -> "annotation1value") + private val TEST_IMAGE_PULL_SECRETS = Seq("my-1secret-1", "my-secret-2") + private val TEST_IMAGE_PULL_SECRET_OBJECTS = + TEST_IMAGE_PULL_SECRETS.map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + private val DRIVER_POD = new PodBuilder() + .withNewMetadata() + .withName(DRIVER_POD_NAME) + .withUid(DRIVER_POD_UID) + .endMetadata() + .withNewSpec() + .withNodeName("some-node") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private var baseConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + baseConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, RESOURCE_NAME_PREFIX) + .set(CONTAINER_IMAGE, EXECUTOR_IMAGE) + .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) + .set("spark.driver.host", DRIVER_HOSTNAME) + .set("spark.driver.port", DRIVER_PORT.toString) + .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + } + + test("basic executor pod has reasonable defaults") { + val step = new BasicExecutorFeatureStep( + KubernetesConf( + baseConf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + RESOURCE_NAME_PREFIX, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map.empty)) + val executor = step.configurePod(SparkPod.initialPod()) + + // The executor pod name and default labels. + assert(executor.pod.getMetadata.getName === s"$RESOURCE_NAME_PREFIX-exec-1") + assert(executor.pod.getMetadata.getLabels.asScala === LABELS) + assert(executor.pod.getSpec.getImagePullSecrets.asScala === TEST_IMAGE_PULL_SECRET_OBJECTS) + + // There is exactly 1 container with no volume mounts and default memory limits. + // Default memory limit is 1024M + 384M (minimum overhead constant). + assert(executor.container.getImage === EXECUTOR_IMAGE) + assert(executor.container.getVolumeMounts.isEmpty) + assert(executor.container.getResources.getLimits.size() === 1) + assert(executor.container.getResources + .getLimits.get("memory").getAmount === "1408Mi") + + // The pod has no node selector, volumes. + assert(executor.pod.getSpec.getNodeSelector.isEmpty) + assert(executor.pod.getSpec.getVolumes.isEmpty) + + checkEnv(executor, Map()) + checkOwnerReferences(executor.pod, DRIVER_POD_UID) + } + + test("executor pod hostnames get truncated to 63 characters") { + val conf = baseConf.clone() + val longPodNamePrefix = "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple" + + val step = new BasicExecutorFeatureStep( + KubernetesConf( + conf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + longPodNamePrefix, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map.empty)) + assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) + } + + test("classpath and extra java options get translated into environment variables") { + val conf = baseConf.clone() + conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") + conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") + + val step = new BasicExecutorFeatureStep( + KubernetesConf( + conf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + RESOURCE_NAME_PREFIX, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map("qux" -> "quux"))) + val executor = step.configurePod(SparkPod.initialPod()) + + checkEnv(executor, + Map("SPARK_JAVA_OPT_0" -> "foo=bar", + ENV_CLASSPATH -> "bar=baz", + "qux" -> "quux")) + checkOwnerReferences(executor.pod, DRIVER_POD_UID) + } + + // There is always exactly one controller reference, and it points to the driver pod. + private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { + assert(executor.getMetadata.getOwnerReferences.size() === 1) + assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) + assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) + } + + // Check that the expected environment variables are present. + private def checkEnv(executorPod: SparkPod, additionalEnvVars: Map[String, String]): Unit = { + val defaultEnvs = Map( + ENV_EXECUTOR_ID -> "1", + ENV_DRIVER_URL -> DRIVER_ADDRESS.toString, + ENV_EXECUTOR_CORES -> "1", + ENV_EXECUTOR_MEMORY -> "1g", + ENV_APPLICATION_ID -> APP_ID, + ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, + ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + + assert(executorPod.container.getEnv.size() === defaultEnvs.size) + val mapEnvs = executorPod.container.getEnv.asScala.map { + x => (x.getName, x.getValue) + }.toMap + assert(defaultEnvs === mapEnvs) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala similarity index 67% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 64553d25883bb..9f817d3bfc79a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -14,34 +14,35 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s.features import java.io.File -import scala.collection.JavaConverters._ - import com.google.common.base.Charsets import com.google.common.io.{BaseEncoding, Files} import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret} +import org.mockito.{Mock, MockitoAnnotations} import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec import org.apache.spark.util.Utils -class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndAfter { +class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { private val KUBERNETES_RESOURCE_NAME_PREFIX = "spark" + private val APP_ID = "k8s-app" private var credentialsTempDirectory: File = _ - private val BASE_DRIVER_SPEC = new KubernetesDriverSpec( - driverPod = new PodBuilder().build(), - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) + private val BASE_DRIVER_POD = SparkPod.initialPod() + + @Mock + private var driverSpecificConf: KubernetesDriverSpecificConf = _ before { + MockitoAnnotations.initMocks(this) credentialsTempDirectory = Utils.createTempDir() } @@ -50,13 +51,19 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA } test("Don't set any credentials") { - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - new SparkConf(false), KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC) - assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod) - assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - assert(preparedDriverSpec.driverSparkConf.getAll.isEmpty) + val kubernetesConf = KubernetesConf( + new SparkConf(false), + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) + assert(kubernetesCredentialsStep.getAdditionalPodSystemProperties().isEmpty) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty) } test("Only set credentials that are manually mounted.") { @@ -73,14 +80,23 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA .set( s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", "/mnt/secrets/my-ca.pem") + val kubernetesConf = KubernetesConf( + submissionSparkConf, + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC) - assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod) - assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === submissionSparkConf.getAll.toMap) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty) + val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() + resolvedProperties.foreach { case (propKey, propValue) => + assert(submissionSparkConf.get(propKey) === propValue) + } } test("Mount credentials from the submission client as a secret.") { @@ -100,10 +116,17 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA .set( s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", caCertFile.getAbsolutePath) - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver( - BASE_DRIVER_SPEC.copy(driverSparkConf = submissionSparkConf)) + val kubernetesConf = KubernetesConf( + submissionSparkConf, + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() val expectedSparkConf = Map( s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX" -> "", s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" -> @@ -113,16 +136,13 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> DRIVER_CREDENTIALS_CLIENT_CERT_PATH, s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> - DRIVER_CREDENTIALS_CA_CERT_PATH, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" -> - clientKeyFile.getAbsolutePath, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> - clientCertFile.getAbsolutePath, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> - caCertFile.getAbsolutePath) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) - assert(preparedDriverSpec.otherKubernetesResources.size === 1) - val credentialsSecret = preparedDriverSpec.otherKubernetesResources.head.asInstanceOf[Secret] + DRIVER_CREDENTIALS_CA_CERT_PATH) + assert(resolvedProperties === expectedSparkConf) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().size === 1) + val credentialsSecret = kubernetesCredentialsStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Secret] assert(credentialsSecret.getMetadata.getName === s"$KUBERNETES_RESOURCE_NAME_PREFIX-kubernetes-credentials") val decodedSecretData = credentialsSecret.getData.asScala.map { data => @@ -134,12 +154,13 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME -> "key", DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME -> "cert") assert(decodedSecretData === expectedSecretData) - val driverPodVolumes = preparedDriverSpec.driverPod.getSpec.getVolumes.asScala + val driverPod = kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) + val driverPodVolumes = driverPod.pod.getSpec.getVolumes.asScala assert(driverPodVolumes.size === 1) assert(driverPodVolumes.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) assert(driverPodVolumes.head.getSecret != null) assert(driverPodVolumes.head.getSecret.getSecretName === credentialsSecret.getMetadata.getName) - val driverContainerVolumeMount = preparedDriverSpec.driverContainer.getVolumeMounts.asScala + val driverContainerVolumeMount = driverPod.container.getVolumeMounts.asScala assert(driverContainerVolumeMount.size === 1) assert(driverContainerVolumeMount.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) assert(driverContainerVolumeMount.head.getMountPath === DRIVER_CREDENTIALS_SECRETS_BASE_DIR) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala new file mode 100644 index 0000000000000..c299d56865ec0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.Service +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.util.Clock + +class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { + + private val SHORT_RESOURCE_NAME_PREFIX = + "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH - + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length) + + private val LONG_RESOURCE_NAME_PREFIX = + "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH - + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length + 1) + private val DRIVER_LABELS = Map( + "label1key" -> "label1value", + "label2key" -> "label2value") + + @Mock + private var clock: Clock = _ + + private var sparkConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf(false) + } + + test("Headless service has a port for the driver RPC and the block manager.") { + sparkConf = sparkConf + .set("spark.driver.port", "9000") + .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) + assert(configurationStep.getAdditionalKubernetesResources().size === 1) + assert(configurationStep.getAdditionalKubernetesResources().head.isInstanceOf[Service]) + val driverService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + verifyService( + 9000, + 8080, + s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", + driverService) + } + + test("Hostname and ports are set according to the service name.") { + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf + .set("spark.driver.port", "9000") + .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) + .set(KUBERNETES_NAMESPACE, "my-namespace"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX + val expectedHostName = s"$expectedServiceName.my-namespace.svc" + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + verifySparkConfHostNames(additionalProps, expectedHostName) + } + + test("Ports should resolve to defaults in SparkConf and in the service.") { + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + val resolvedService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + verifyService( + DEFAULT_DRIVER_PORT, + DEFAULT_BLOCKMANAGER_PORT, + s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", + resolvedService) + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + assert(additionalProps("spark.driver.port") === DEFAULT_DRIVER_PORT.toString) + assert(additionalProps(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key) + === DEFAULT_BLOCKMANAGER_PORT.toString) + } + + test("Long prefixes should switch to using a generated name.") { + when(clock.getTimeMillis()).thenReturn(10000) + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + val driverService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + val expectedServiceName = s"spark-10000${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}" + assert(driverService.getMetadata.getName === expectedServiceName) + val expectedHostName = s"$expectedServiceName.my-namespace.svc" + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + verifySparkConfHostNames(additionalProps, expectedHostName) + } + + test("Disallow bind address and driver host to be set explicitly.") { + try { + new DriverServiceFeatureStep( + KubernetesConf( + sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + fail("The driver bind address should not be allowed.") + } catch { + case e: Throwable => + assert(e.getMessage === + s"requirement failed: ${DriverServiceFeatureStep.DRIVER_BIND_ADDRESS_KEY} is" + + " not supported in Kubernetes mode, as the driver's bind address is managed" + + " and set to the driver pod's IP address.") + } + sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS) + sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host") + try { + new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + fail("The driver host address should not be allowed.") + } catch { + case e: Throwable => + assert(e.getMessage === + s"requirement failed: ${DriverServiceFeatureStep.DRIVER_HOST_KEY} is" + + " not supported in Kubernetes mode, as the driver's hostname will be managed via" + + " a Kubernetes service.") + } + } + + private def verifyService( + driverPort: Int, + blockManagerPort: Int, + expectedServiceName: String, + service: Service): Unit = { + assert(service.getMetadata.getName === expectedServiceName) + assert(service.getSpec.getClusterIP === "None") + assert(service.getSpec.getSelector.asScala === DRIVER_LABELS) + assert(service.getSpec.getPorts.size() === 2) + val driverServicePorts = service.getSpec.getPorts.asScala + assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) + assert(driverServicePorts.head.getPort.intValue() === driverPort) + assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort) + assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME) + assert(driverServicePorts(1).getPort.intValue() === blockManagerPort) + assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort) + } + + private def verifySparkConfHostNames( + driverSparkConf: Map[String, String], expectedHostName: String): Unit = { + assert(driverSparkConf( + org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key) === expectedHostName) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala new file mode 100644 index 0000000000000..27bff74ce38af --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{HasMetadata, PodBuilder, SecretBuilder} +import org.mockito.Matchers +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark.deploy.k8s.SparkPod + +object KubernetesFeaturesTestUtils { + + def getMockConfigStepForStepType[T <: KubernetesFeatureConfigStep]( + stepType: String, stepClass: Class[T]): T = { + val mockStep = mock(stepClass) + when(mockStep.getAdditionalKubernetesResources()).thenReturn( + getSecretsForStepType(stepType)) + + when(mockStep.getAdditionalPodSystemProperties()) + .thenReturn(Map(stepType -> stepType)) + when(mockStep.configurePod(Matchers.any(classOf[SparkPod]))) + .thenAnswer(new Answer[SparkPod]() { + override def answer(invocation: InvocationOnMock): SparkPod = { + val originalPod = invocation.getArgumentAt(0, classOf[SparkPod]) + val configuredPod = new PodBuilder(originalPod.pod) + .editOrNewMetadata() + .addToLabels(stepType, stepType) + .endMetadata() + .build() + SparkPod(configuredPod, originalPod.container) + } + }) + mockStep + } + + def getSecretsForStepType[T <: KubernetesFeatureConfigStep](stepType: String) + : Seq[HasMetadata] = { + Seq(new SecretBuilder() + .withNewMetadata() + .withName(stepType) + .endMetadata() + .build()) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala similarity index 64% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 960d0bda1d011..9d02f56cc206d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -14,29 +14,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SecretVolumeUtils, SparkPod} -class DriverMountSecretsStepSuite extends SparkFunSuite { +class MountSecretsFeatureStepSuite extends SparkFunSuite { private val SECRET_FOO = "foo" private val SECRET_BAR = "bar" private val SECRET_MOUNT_PATH = "/etc/secrets/driver" test("mounts all given secrets") { - val baseDriverSpec = KubernetesDriverSpec.initialSpec(new SparkConf(false)) + val baseDriverPod = SparkPod.initialPod() val secretNamesToMountPaths = Map( SECRET_FOO -> SECRET_MOUNT_PATH, SECRET_BAR -> SECRET_MOUNT_PATH) + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + "resource-name-prefix", + "app-id", + Map.empty, + Map.empty, + secretNamesToMountPaths, + Map.empty) - val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) - val mountSecretsStep = new DriverMountSecretsStep(mountSecretsBootstrap) - val configuredDriverSpec = mountSecretsStep.configureDriver(baseDriverSpec) - val driverPodWithSecretsMounted = configuredDriverSpec.driverPod - val driverContainerWithSecretsMounted = configuredDriverSpec.driverContainer + val step = new MountSecretsFeatureStep(kubernetesConf) + val driverPodWithSecretsMounted = step.configurePod(baseDriverPod).pod + val driverContainerWithSecretsMounted = step.configurePod(baseDriverPod).container Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach { volumeName => assert(SecretVolumeUtils.podHasVolume(driverPodWithSecretsMounted, volumeName)) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index 6a501592f42a3..c1b203e03a357 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -16,22 +16,17 @@ */ package org.apache.spark.deploy.k8s.submit -import scala.collection.JavaConverters._ - -import com.google.common.collect.Iterables import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.{KubernetesClient, Watch} import io.fabric8.kubernetes.client.dsl.{MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} import org.mockito.Mockito.{doReturn, verify, when} -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.scalatest.mockito.MockitoSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep class ClientSuite extends SparkFunSuite with BeforeAndAfter { @@ -39,6 +34,74 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val DRIVER_POD_API_VERSION = "v1" private val DRIVER_POD_KIND = "pod" private val KUBERNETES_RESOURCE_PREFIX = "resource-example" + private val POD_NAME = "driver" + private val CONTAINER_NAME = "container" + private val APP_ID = "app-id" + private val APP_NAME = "app" + private val MAIN_CLASS = "main" + private val APP_ARGS = Seq("arg1", "arg2") + private val RESOLVED_JAVA_OPTIONS = Map( + "conf1key" -> "conf1value", + "conf2key" -> "conf2value") + private val BUILT_DRIVER_POD = + new PodBuilder() + .withNewMetadata() + .withName(POD_NAME) + .endMetadata() + .withNewSpec() + .withHostname("localhost") + .endSpec() + .build() + private val BUILT_DRIVER_CONTAINER = new ContainerBuilder().withName(CONTAINER_NAME).build() + private val ADDITIONAL_RESOURCES = Seq( + new SecretBuilder().withNewMetadata().withName("secret").endMetadata().build()) + + private val BUILT_KUBERNETES_SPEC = KubernetesDriverSpec( + SparkPod(BUILT_DRIVER_POD, BUILT_DRIVER_CONTAINER), + ADDITIONAL_RESOURCES, + RESOLVED_JAVA_OPTIONS) + + private val FULL_EXPECTED_CONTAINER = new ContainerBuilder(BUILT_DRIVER_CONTAINER) + .addNewEnv() + .withName(ENV_SPARK_CONF_DIR) + .withValue(SPARK_CONF_DIR_INTERNAL) + .endEnv() + .addNewVolumeMount() + .withName(SPARK_CONF_VOLUME) + .withMountPath(SPARK_CONF_DIR_INTERNAL) + .endVolumeMount() + .build() + private val FULL_EXPECTED_POD = new PodBuilder(BUILT_DRIVER_POD) + .editSpec() + .addToContainers(FULL_EXPECTED_CONTAINER) + .addNewVolume() + .withName(SPARK_CONF_VOLUME) + .withNewConfigMap().withName(s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map").endConfigMap() + .endVolume() + .endSpec() + .build() + + private val POD_WITH_OWNER_REFERENCE = new PodBuilder(FULL_EXPECTED_POD) + .editMetadata() + .withUid(DRIVER_POD_UID) + .endMetadata() + .withApiVersion(DRIVER_POD_API_VERSION) + .withKind(DRIVER_POD_KIND) + .build() + + private val ADDITIONAL_RESOURCES_WITH_OWNER_REFERENCES = ADDITIONAL_RESOURCES.map { secret => + new SecretBuilder(secret) + .editMetadata() + .addNewOwnerReference() + .withName(POD_NAME) + .withApiVersion(DRIVER_POD_API_VERSION) + .withKind(DRIVER_POD_KIND) + .withController(true) + .withUid(DRIVER_POD_UID) + .endOwnerReference() + .endMetadata() + .build() + } private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ HasMetadata, Boolean] @@ -56,113 +119,86 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { @Mock private var loggingPodStatusWatcher: LoggingPodStatusWatcher = _ + @Mock + private var driverBuilder: KubernetesDriverBuilder = _ + @Mock private var resourceList: ResourceList = _ - private val submissionSteps = Seq(FirstTestConfigurationStep, SecondTestConfigurationStep) + private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _ + + private var sparkConf: SparkConf = _ private var createdPodArgumentCaptor: ArgumentCaptor[Pod] = _ private var createdResourcesArgumentCaptor: ArgumentCaptor[HasMetadata] = _ - private var createdContainerArgumentCaptor: ArgumentCaptor[Container] = _ before { MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf(false) + kubernetesConf = KubernetesConf[KubernetesDriverSpecificConf]( + sparkConf, + KubernetesDriverSpecificConf(None, MAIN_CLASS, APP_NAME, APP_ARGS), + KUBERNETES_RESOURCE_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) - when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods) + when(podOperations.withName(POD_NAME)).thenReturn(namedPods) createdPodArgumentCaptor = ArgumentCaptor.forClass(classOf[Pod]) createdResourcesArgumentCaptor = ArgumentCaptor.forClass(classOf[HasMetadata]) - when(podOperations.create(createdPodArgumentCaptor.capture())).thenAnswer(new Answer[Pod] { - override def answer(invocation: InvocationOnMock): Pod = { - new PodBuilder(invocation.getArgumentAt(0, classOf[Pod])) - .editMetadata() - .withUid(DRIVER_POD_UID) - .endMetadata() - .withApiVersion(DRIVER_POD_API_VERSION) - .withKind(DRIVER_POD_KIND) - .build() - } - }) - when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods) + when(podOperations.create(FULL_EXPECTED_POD)).thenReturn(POD_WITH_OWNER_REFERENCE) when(namedPods.watch(loggingPodStatusWatcher)).thenReturn(mock[Watch]) doReturn(resourceList) .when(kubernetesClient) .resourceList(createdResourcesArgumentCaptor.capture()) } - test("The client should configure the pod with the submission steps.") { + test("The client should configure the pod using the builder.") { val submissionClient = new Client( - submissionSteps, - new SparkConf(false), + driverBuilder, + kubernetesConf, kubernetesClient, false, "spark", loggingPodStatusWatcher, KUBERNETES_RESOURCE_PREFIX) submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue - assert(createdPod.getMetadata.getName === FirstTestConfigurationStep.podName) - assert(createdPod.getMetadata.getLabels.asScala === - Map(FirstTestConfigurationStep.labelKey -> FirstTestConfigurationStep.labelValue)) - assert(createdPod.getMetadata.getAnnotations.asScala === - Map(SecondTestConfigurationStep.annotationKey -> - SecondTestConfigurationStep.annotationValue)) - assert(createdPod.getSpec.getContainers.size() === 1) - assert(createdPod.getSpec.getContainers.get(0).getName === - SecondTestConfigurationStep.containerName) + verify(podOperations).create(FULL_EXPECTED_POD) } test("The client should create Kubernetes resources") { - val EXAMPLE_JAVA_OPTS = "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails" - val EXPECTED_JAVA_OPTS = "-XX\\:+HeapDumpOnOutOfMemoryError -XX\\:+PrintGCDetails" val submissionClient = new Client( - submissionSteps, - new SparkConf(false) - .set(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, EXAMPLE_JAVA_OPTS), + driverBuilder, + kubernetesConf, kubernetesClient, false, "spark", loggingPodStatusWatcher, KUBERNETES_RESOURCE_PREFIX) submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue val otherCreatedResources = createdResourcesArgumentCaptor.getAllValues assert(otherCreatedResources.size === 2) - val secrets = otherCreatedResources.toArray - .filter(_.isInstanceOf[Secret]).map(_.asInstanceOf[Secret]) + val secrets = otherCreatedResources.toArray.filter(_.isInstanceOf[Secret]).toSeq + assert(secrets === ADDITIONAL_RESOURCES_WITH_OWNER_REFERENCES) val configMaps = otherCreatedResources.toArray .filter(_.isInstanceOf[ConfigMap]).map(_.asInstanceOf[ConfigMap]) assert(secrets.nonEmpty) - val secret = secrets.head - assert(secret.getMetadata.getName === FirstTestConfigurationStep.secretName) - assert(secret.getData.asScala === - Map(FirstTestConfigurationStep.secretKey -> FirstTestConfigurationStep.secretData)) - val ownerReference = Iterables.getOnlyElement(secret.getMetadata.getOwnerReferences) - assert(ownerReference.getName === createdPod.getMetadata.getName) - assert(ownerReference.getKind === DRIVER_POD_KIND) - assert(ownerReference.getUid === DRIVER_POD_UID) - assert(ownerReference.getApiVersion === DRIVER_POD_API_VERSION) assert(configMaps.nonEmpty) val configMap = configMaps.head assert(configMap.getMetadata.getName === s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map") assert(configMap.getData.containsKey(SPARK_CONF_FILE_NAME)) - assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains(EXPECTED_JAVA_OPTS)) - assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains( - "spark.custom-conf=custom-conf-value")) - val driverContainer = Iterables.getOnlyElement(createdPod.getSpec.getContainers) - assert(driverContainer.getName === SecondTestConfigurationStep.containerName) - val driverEnv = driverContainer.getEnv.asScala.head - assert(driverEnv.getName === ENV_SPARK_CONF_DIR) - assert(driverEnv.getValue === SPARK_CONF_DIR_INTERNAL) - val driverMount = driverContainer.getVolumeMounts.asScala.head - assert(driverMount.getName === SPARK_CONF_VOLUME) - assert(driverMount.getMountPath === SPARK_CONF_DIR_INTERNAL) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains("conf1key=conf1value")) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains("conf2key=conf2value")) } test("Waiting for app completion should stall on the watcher") { val submissionClient = new Client( - submissionSteps, - new SparkConf(false), + driverBuilder, + kubernetesConf, kubernetesClient, true, "spark", @@ -171,56 +207,4 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { submissionClient.run() verify(loggingPodStatusWatcher).awaitCompletion() } - -} - -private object FirstTestConfigurationStep extends DriverConfigurationStep { - - val podName = "test-pod" - val secretName = "test-secret" - val labelKey = "first-submit" - val labelValue = "true" - val secretKey = "secretKey" - val secretData = "secretData" - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val modifiedPod = new PodBuilder(driverSpec.driverPod) - .editMetadata() - .withName(podName) - .addToLabels(labelKey, labelValue) - .endMetadata() - .build() - val additionalResource = new SecretBuilder() - .withNewMetadata() - .withName(secretName) - .endMetadata() - .addToData(secretKey, secretData) - .build() - driverSpec.copy( - driverPod = modifiedPod, - otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(additionalResource)) - } -} - -private object SecondTestConfigurationStep extends DriverConfigurationStep { - val annotationKey = "second-submit" - val annotationValue = "submitted" - val sparkConfKey = "spark.custom-conf" - val sparkConfValue = "custom-conf-value" - val containerName = "driverContainer" - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val modifiedPod = new PodBuilder(driverSpec.driverPod) - .editMetadata() - .addToAnnotations(annotationKey, annotationValue) - .endMetadata() - .build() - val resolvedSparkConf = driverSpec.driverSparkConf.clone().set(sparkConfKey, sparkConfValue) - val modifiedContainer = new ContainerBuilder(driverSpec.driverContainer) - .withName(containerName) - .build() - driverSpec.copy( - driverPod = modifiedPod, - driverSparkConf = resolvedSparkConf, - driverContainer = modifiedContainer) - } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala deleted file mode 100644 index df34d2dbcb5be..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.submit.steps._ - -class DriverConfigOrchestratorSuite extends SparkFunSuite { - - private val DRIVER_IMAGE = "driver-image" - private val IC_IMAGE = "init-container-image" - private val APP_ID = "spark-app-id" - private val KUBERNETES_RESOURCE_PREFIX = "example-prefix" - private val APP_NAME = "spark" - private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2") - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/driver" - - test("Base submission steps with a main app resource.") { - val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep]) - } - - test("Base submission steps without a main app resource.") { - val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Option.empty, - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep]) - } - - test("Submission steps with driver secrets to mount") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DRIVER_IMAGE) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep], - classOf[DriverMountSecretsStep]) - } - - test("Submission using client local dependencies") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DRIVER_IMAGE) - var orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(JavaMainAppResource("file:///var/apps/jars/main.jar")), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - assertThrows[SparkException] { - orchestrator.getAllConfigurationSteps - } - - sparkConf.set("spark.files", "/path/to/file1,/path/to/file2") - orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(JavaMainAppResource("local:///var/apps/jars/main.jar")), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - assertThrows[SparkException] { - orchestrator.getAllConfigurationSteps - } - } - - private def validateStepTypes( - orchestrator: DriverConfigOrchestrator, - types: Class[_ <: DriverConfigurationStep]*): Unit = { - val steps = orchestrator.getAllConfigurationSteps - assert(steps.size === types.size) - assert(steps.map(_.getClass) === types) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala new file mode 100644 index 0000000000000..161f9afe7bba9 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} + +class KubernetesDriverBuilderSuite extends SparkFunSuite { + + private val BASIC_STEP_TYPE = "basic" + private val CREDENTIALS_STEP_TYPE = "credentials" + private val SERVICE_STEP_TYPE = "service" + private val SECRETS_STEP_TYPE = "mount-secrets" + + private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) + + private val credentialsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + CREDENTIALS_STEP_TYPE, classOf[DriverKubernetesCredentialsFeatureStep]) + + private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep]) + + private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + + private val builderUnderTest: KubernetesDriverBuilder = + new KubernetesDriverBuilder( + _ => basicFeatureStep, + _ => credentialsStep, + _ => serviceStep, + _ => secretsStep) + + test("Apply fundamental steps all the time.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE) + } + + test("Apply secrets step if secrets are present.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map("secret" -> "secretMountPath"), + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + SECRETS_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) + : Unit = { + assert(resolvedSpec.systemProperties.size === stepTypes.size) + stepTypes.foreach { stepType => + assert(resolvedSpec.pod.pod.getMetadata.getLabels.get(stepType) === stepType) + assert(resolvedSpec.driverKubernetesResources.containsSlice( + KubernetesFeaturesTestUtils.getSecretsForStepType(stepType))) + assert(resolvedSpec.systemProperties(stepType) === stepType) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala deleted file mode 100644 index ee450fff8d376..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -class BasicDriverConfigurationStepSuite extends SparkFunSuite { - - private val APP_ID = "spark-app-id" - private val RESOURCE_NAME_PREFIX = "spark" - private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") - private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" - private val APP_NAME = "spark-test" - private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") - private val CUSTOM_ANNOTATION_KEY = "customAnnotation" - private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" - private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1" - private val DRIVER_CUSTOM_ENV_KEY2 = "customDriverEnv2" - - test("Set all possible configurations from the user.") { - val sparkConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") - .set(org.apache.spark.internal.config.DRIVER_CLASS_PATH, "/opt/spark/spark-examples.jar") - .set("spark.driver.cores", "2") - .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") - .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") - .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) - .set(CONTAINER_IMAGE, "spark-driver:latest") - .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) - .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") - .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") - .set(IMAGE_PULL_SECRETS, "imagePullSecret1, imagePullSecret2") - - val submissionStep = new BasicDriverConfigurationStep( - APP_ID, - RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - CONTAINER_IMAGE_PULL_POLICY, - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - val basePod = new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build() - val baseDriverSpec = KubernetesDriverSpec( - driverPod = basePod, - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val preparedDriverSpec = submissionStep.configureDriver(baseDriverSpec) - - assert(preparedDriverSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) - assert(preparedDriverSpec.driverContainer.getImage === "spark-driver:latest") - assert(preparedDriverSpec.driverContainer.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) - - assert(preparedDriverSpec.driverContainer.getEnv.size === 4) - val envs = preparedDriverSpec.driverContainer - .getEnv - .asScala - .map(env => (env.getName, env.getValue)) - .toMap - assert(envs(ENV_CLASSPATH) === "/opt/spark/spark-examples.jar") - assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") - assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") - - assert(preparedDriverSpec.driverContainer.getEnv.asScala.exists(envVar => - envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) && - envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") && - envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP"))) - - val resourceRequirements = preparedDriverSpec.driverContainer.getResources - val requests = resourceRequirements.getRequests.asScala - assert(requests("cpu").getAmount === "2") - assert(requests("memory").getAmount === "456Mi") - val limits = resourceRequirements.getLimits.asScala - assert(limits("memory").getAmount === "456Mi") - assert(limits("cpu").getAmount === "4") - - val driverPodMetadata = preparedDriverSpec.driverPod.getMetadata - assert(driverPodMetadata.getName === "spark-driver-pod") - assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) - val expectedAnnotations = Map( - CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE, - SPARK_APP_NAME_ANNOTATION -> APP_NAME) - assert(driverPodMetadata.getAnnotations.asScala === expectedAnnotations) - - val driverPodSpec = preparedDriverSpec.driverPod.getSpec - assert(driverPodSpec.getRestartPolicy === "Never") - assert(driverPodSpec.getImagePullSecrets.size() === 2) - assert(driverPodSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") - assert(driverPodSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") - - val resolvedSparkConf = preparedDriverSpec.driverSparkConf.getAll.toMap - val expectedSparkConf = Map( - KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", - "spark.app.id" -> APP_ID, - KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, - "spark.kubernetes.submitInDriver" -> "true") - assert(resolvedSparkConf === expectedSparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala deleted file mode 100644 index ca43fc97dc991..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -class DependencyResolutionStepSuite extends SparkFunSuite { - - private val SPARK_JARS = Seq( - "apps/jars/jar1.jar", - "local:///var/apps/jars/jar2.jar") - - private val SPARK_FILES = Seq( - "apps/files/file1.txt", - "local:///var/apps/files/file2.txt") - - test("Added dependencies should be resolved in Spark configuration and environment") { - val dependencyResolutionStep = new DependencyResolutionStep( - SPARK_JARS, - SPARK_FILES) - val driverPod = new PodBuilder().build() - val baseDriverSpec = KubernetesDriverSpec( - driverPod = driverPod, - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val preparedDriverSpec = dependencyResolutionStep.configureDriver(baseDriverSpec) - assert(preparedDriverSpec.driverPod === driverPod) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - val resolvedSparkJars = preparedDriverSpec.driverSparkConf.get("spark.jars").split(",").toSet - val expectedResolvedSparkJars = Set( - "apps/jars/jar1.jar", - "/var/apps/jars/jar2.jar") - assert(resolvedSparkJars === expectedResolvedSparkJars) - val resolvedSparkFiles = preparedDriverSpec.driverSparkConf.get("spark.files").split(",").toSet - val expectedResolvedSparkFiles = Set( - "apps/files/file1.txt", - "/var/apps/files/file2.txt") - assert(resolvedSparkFiles === expectedResolvedSparkFiles) - val driverEnv = preparedDriverSpec.driverContainer.getEnv.asScala - assert(driverEnv.size === 1) - assert(driverEnv.head.getName === ENV_MOUNTED_CLASSPATH) - val resolvedDriverClasspath = driverEnv.head.getValue.split(File.pathSeparator).toSet - val expectedResolvedDriverClasspath = expectedResolvedSparkJars - assert(resolvedDriverClasspath === expectedResolvedDriverClasspath) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala deleted file mode 100644 index 78c8c3ba1afbd..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.Service -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Mockito.when -import org.scalatest.BeforeAndAfter - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.util.Clock - -class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter { - - private val SHORT_RESOURCE_NAME_PREFIX = - "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH - - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length) - - private val LONG_RESOURCE_NAME_PREFIX = - "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH - - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length + 1) - private val DRIVER_LABELS = Map( - "label1key" -> "label1value", - "label2key" -> "label2value") - - @Mock - private var clock: Clock = _ - - private var sparkConf: SparkConf = _ - - before { - MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf(false) - } - - test("Headless service has a port for the driver RPC and the block manager.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf - .set("spark.driver.port", "9000") - .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080), - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - assert(resolvedDriverSpec.otherKubernetesResources.size === 1) - assert(resolvedDriverSpec.otherKubernetesResources.head.isInstanceOf[Service]) - val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] - verifyService( - 9000, - 8080, - s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}", - driverService) - } - - test("Hostname and ports are set according to the service name.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf - .set("spark.driver.port", "9000") - .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) - .set(KUBERNETES_NAMESPACE, "my-namespace"), - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX - val expectedHostName = s"$expectedServiceName.my-namespace.svc" - verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) - } - - test("Ports should resolve to defaults in SparkConf and in the service.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf, - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - verifyService( - DEFAULT_DRIVER_PORT, - DEFAULT_BLOCKMANAGER_PORT, - s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}", - resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service]) - assert(resolvedDriverSpec.driverSparkConf.get("spark.driver.port") === - DEFAULT_DRIVER_PORT.toString) - assert(resolvedDriverSpec.driverSparkConf.get( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT) === DEFAULT_BLOCKMANAGER_PORT) - } - - test("Long prefixes should switch to using a generated name.") { - val configurationStep = new DriverServiceBootstrapStep( - LONG_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), - clock) - when(clock.getTimeMillis()).thenReturn(10000) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] - val expectedServiceName = s"spark-10000${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}" - assert(driverService.getMetadata.getName === expectedServiceName) - val expectedHostName = s"$expectedServiceName.my-namespace.svc" - verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) - } - - test("Disallow bind address and driver host to be set explicitly.") { - val configurationStep = new DriverServiceBootstrapStep( - LONG_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), - clock) - try { - configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf)) - fail("The driver bind address should not be allowed.") - } catch { - case e: Throwable => - assert(e.getMessage === - s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_BIND_ADDRESS_KEY} is" + - " not supported in Kubernetes mode, as the driver's bind address is managed" + - " and set to the driver pod's IP address.") - } - sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS) - sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host") - try { - configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf)) - fail("The driver host address should not be allowed.") - } catch { - case e: Throwable => - assert(e.getMessage === - s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_HOST_KEY} is" + - " not supported in Kubernetes mode, as the driver's hostname will be managed via" + - " a Kubernetes service.") - } - } - - private def verifyService( - driverPort: Int, - blockManagerPort: Int, - expectedServiceName: String, - service: Service): Unit = { - assert(service.getMetadata.getName === expectedServiceName) - assert(service.getSpec.getClusterIP === "None") - assert(service.getSpec.getSelector.asScala === DRIVER_LABELS) - assert(service.getSpec.getPorts.size() === 2) - val driverServicePorts = service.getSpec.getPorts.asScala - assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) - assert(driverServicePorts.head.getPort.intValue() === driverPort) - assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort) - assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME) - assert(driverServicePorts(1).getPort.intValue() === blockManagerPort) - assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort) - } - - private def verifySparkConfHostNames( - driverSparkConf: SparkConf, expectedHostName: String): Unit = { - assert(driverSparkConf.get( - org.apache.spark.internal.config.DRIVER_HOST_ADDRESS) === expectedHostName) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala deleted file mode 100644 index d73df20f0f956..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.scheduler.cluster.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ -import org.mockito.MockitoAnnotations -import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.MountSecretsBootstrap - -class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { - - private val driverPodName: String = "driver-pod" - private val driverPodUid: String = "driver-uid" - private val executorPrefix: String = "base" - private val executorImage: String = "executor-image" - private val imagePullSecrets: String = "imagePullSecret1, imagePullSecret2" - private val driverPod = new PodBuilder() - .withNewMetadata() - .withName(driverPodName) - .withUid(driverPodUid) - .endMetadata() - .withNewSpec() - .withNodeName("some-node") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.100") - .endStatus() - .build() - private var baseConf: SparkConf = _ - - before { - MockitoAnnotations.initMocks(this) - baseConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) - .set(CONTAINER_IMAGE, executorImage) - .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) - .set(IMAGE_PULL_SECRETS, imagePullSecrets) - } - - test("basic executor pod has reasonable defaults") { - val factory = new ExecutorPodFactory(baseConf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - // The executor pod name and default labels. - assert(executor.getMetadata.getName === s"$executorPrefix-exec-1") - assert(executor.getMetadata.getLabels.size() === 3) - assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1") - - // There is exactly 1 container with no volume mounts and default memory limits and requests. - // Default memory limit/request is 1024M + 384M (minimum overhead constant). - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getImage === executorImage) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty) - assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources - .getRequests.get("memory").getAmount === "1408Mi") - assert(executor.getSpec.getContainers.get(0).getResources - .getLimits.get("memory").getAmount === "1408Mi") - assert(executor.getSpec.getImagePullSecrets.size() === 2) - assert(executor.getSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") - assert(executor.getSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") - - // The pod has no node selector, volumes. - assert(executor.getSpec.getNodeSelector.isEmpty) - assert(executor.getSpec.getVolumes.isEmpty) - - checkEnv(executor, Map()) - checkOwnerReferences(executor, driverPodUid) - } - - test("executor core request specification") { - var factory = new ExecutorPodFactory(baseConf, None) - var executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "1") - - val conf = baseConf.clone() - - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "0.1") - factory = new ExecutorPodFactory(conf, None) - executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "0.1") - - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") - factory = new ExecutorPodFactory(conf, None) - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") - executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "100m") - } - - test("executor pod hostnames get truncated to 63 characters") { - val conf = baseConf.clone() - conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, - "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") - - val factory = new ExecutorPodFactory(conf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getHostname.length === 63) - } - - test("classpath and extra java options get translated into environment variables") { - val conf = baseConf.clone() - conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") - conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") - - val factory = new ExecutorPodFactory(conf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) - - checkEnv(executor, - Map("SPARK_JAVA_OPT_0" -> "foo=bar", - ENV_CLASSPATH -> "bar=baz", - "qux" -> "quux")) - checkOwnerReferences(executor, driverPodUid) - } - - test("executor secrets get mounted") { - val conf = baseConf.clone() - - val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - val factory = new ExecutorPodFactory(conf, Some(secretsBootstrap)) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.size() === 1) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0).getName - === "secret1-volume") - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0) - .getMountPath === "/var/secret1") - - // check volume mounted. - assert(executor.getSpec.getVolumes.size() === 1) - assert(executor.getSpec.getVolumes.get(0).getSecret.getSecretName === "secret1") - - checkOwnerReferences(executor, driverPodUid) - } - - // There is always exactly one controller reference, and it points to the driver pod. - private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { - assert(executor.getMetadata.getOwnerReferences.size() === 1) - assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) - assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) - } - - // Check that the expected environment variables are present. - private def checkEnv(executor: Pod, additionalEnvVars: Map[String, String]): Unit = { - val defaultEnvs = Map( - ENV_EXECUTOR_ID -> "1", - ENV_DRIVER_URL -> "dummy", - ENV_EXECUTOR_CORES -> "1", - ENV_EXECUTOR_MEMORY -> "1g", - ENV_APPLICATION_ID -> "dummy", - ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, - ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars - - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) - val mapEnvs = executor.getSpec.getContainers.get(0).getEnv.asScala.map { - x => (x.getName, x.getValue) - }.toMap - assert(defaultEnvs === mapEnvs) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index b2f26f205a329..96065e83f069c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.scheduler.cluster.k8s import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} -import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList} +import io.fabric8.kubernetes.api.model.{ContainerBuilder, DoneablePod, Pod, PodBuilder, PodList} import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} import io.fabric8.kubernetes.client.Watcher.Action import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} -import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations} +import org.hamcrest.{BaseMatcher, Description, Matcher} +import org.mockito.{AdditionalAnswers, ArgumentCaptor, Matchers, Mock, MockitoAnnotations} import org.mockito.Matchers.{any, eq => mockitoEq} import org.mockito.Mockito.{doNothing, never, times, verify, when} import org.scalatest.BeforeAndAfter @@ -31,6 +32,7 @@ import scala.collection.JavaConverters._ import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.rpc._ @@ -47,8 +49,6 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private val SPARK_DRIVER_HOST = "localhost" private val SPARK_DRIVER_PORT = 7077 private val POD_ALLOCATION_INTERVAL = "1m" - private val DRIVER_URL = RpcEndpointAddress( - SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString private val FIRST_EXECUTOR_POD = new PodBuilder() .withNewMetadata() .withName("pod1") @@ -94,7 +94,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var requestExecutorsService: ExecutorService = _ @Mock - private var executorPodFactory: ExecutorPodFactory = _ + private var executorBuilder: KubernetesExecutorBuilder = _ @Mock private var kubernetesClient: KubernetesClient = _ @@ -399,7 +399,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn new KubernetesClusterSchedulerBackend( taskSchedulerImpl, rpcEnv, - executorPodFactory, + executorBuilder, kubernetesClient, allocatorExecutor, requestExecutorsService) { @@ -428,13 +428,22 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) .endMetadata() .build() - when(executorPodFactory.createExecutorPod( - executorId.toString, - APP_ID, - DRIVER_URL, - sparkConf.getExecutorEnv, - driverPod, - Map.empty)).thenReturn(resolvedPod) - resolvedPod + val resolvedContainer = new ContainerBuilder().build() + when(executorBuilder.buildFromFeatures(Matchers.argThat( + new BaseMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { + override def matches(argument: scala.Any) + : Boolean = { + argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] && + argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] + .roleSpecificConf.executorId == executorId.toString + } + + override def describeTo(description: Description): Unit = {} + }))).thenReturn(SparkPod(resolvedPod, resolvedContainer)) + new PodBuilder(resolvedPod) + .editSpec() + .addToContainers(resolvedContainer) + .endSpec() + .build() } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala new file mode 100644 index 0000000000000..f5270623f8acc --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.PodBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} + +class KubernetesExecutorBuilderSuite extends SparkFunSuite { + private val BASIC_STEP_TYPE = "basic" + private val SECRETS_STEP_TYPE = "mount-secrets" + + private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) + private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + + private val builderUnderTest = new KubernetesExecutorBuilder( + _ => basicFeatureStep, + _ => mountSecretsStep) + + test("Basic steps are consistently applied.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + validateStepTypesApplied(builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE) + } + + test("Apply secrets step if secrets are present.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map("secret" -> "secretMountPath"), + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + SECRETS_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { + assert(resolvedPod.pod.getMetadata.getLabels.size === stepTypes.size) + stepTypes.foreach { stepType => + assert(resolvedPod.pod.getMetadata.getLabels.get(stepType) === stepType) + } + } +} From 4dfd746de3f4346ed0c2191f8523a7e6cc9f064d Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 14 Apr 2018 00:22:38 +0800 Subject: [PATCH 617/774] [SPARK-23896][SQL] Improve PartitioningAwareFileIndex ## What changes were proposed in this pull request? Currently `PartitioningAwareFileIndex` accepts an optional parameter `userPartitionSchema`. If provided, it will combine the inferred partition schema with the parameter. However, 1. to get `userPartitionSchema`, we need to combine inferred partition schema with `userSpecifiedSchema` 2. to get the inferred partition schema, we have to create a temporary file index. Only after that, a final version of `PartitioningAwareFileIndex` can be created. This can be improved by passing `userSpecifiedSchema` to `PartitioningAwareFileIndex`. With the improvement, we can reduce redundant code and avoid parsing the file partition twice. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21004 from gengliangwang/PartitioningAwareFileIndex. --- .../datasources/CatalogFileIndex.scala | 2 +- .../execution/datasources/DataSource.scala | 133 ++++++++---------- .../datasources/InMemoryFileIndex.scala | 8 +- .../PartitioningAwareFileIndex.scala | 54 ++++--- .../streaming/MetadataLogFileIndex.scala | 10 +- .../datasources/FileSourceStrategySuite.scala | 2 +- .../hive/PartitionedTablePerfStatsSuite.scala | 2 +- 7 files changed, 103 insertions(+), 108 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 4046396d0e614..a66a07673e25f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -85,7 +85,7 @@ class CatalogFileIndex( sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec, Option(timeNs)) } else { new InMemoryFileIndex( - sparkSession, rootPaths, table.storage.properties, partitionSchema = None) + sparkSession, rootPaths, table.storage.properties, userSpecifiedSchema = None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b84ea769808f9..f16d824201e77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Failure, Success, Try} -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil @@ -103,24 +102,6 @@ case class DataSource( bucket.sortColumnNames, "in the sort definition", equality) } - /** - * In the read path, only managed tables by Hive provide the partition columns properly when - * initializing this class. All other file based data sources will try to infer the partitioning, - * and then cast the inferred types to user specified dataTypes if the partition columns exist - * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or - * inconsistent data types as reported in SPARK-21463. - * @param fileIndex A FileIndex that will perform partition inference - * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema` - */ - private def combineInferredAndUserSpecifiedPartitionSchema(fileIndex: FileIndex): StructType = { - val resolved = fileIndex.partitionSchema.map { partitionField => - // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred - userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( - partitionField) - } - StructType(resolved) - } - /** * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer * it. In the read path, only managed tables by Hive provide the partition columns properly when @@ -140,31 +121,26 @@ case class DataSource( * be any further inference in any triggers. * * @param format the file format object for this DataSource - * @param fileStatusCache the shared cache for file statuses to speed up listing + * @param fileIndex optional [[InMemoryFileIndex]] for getting partition schema and file list * @return A pair of the data schema (excluding partition columns) and the schema of the partition * columns. */ private def getOrInferFileFormatSchema( format: FileFormat, - fileStatusCache: FileStatusCache = NoopCache): (StructType, StructType) = { - // the operations below are expensive therefore try not to do them if we don't need to, e.g., + fileIndex: Option[InMemoryFileIndex] = None): (StructType, StructType) = { + // The operations below are expensive therefore try not to do them if we don't need to, e.g., // in streaming mode, we have already inferred and registered partition columns, we will // never have to materialize the lazy val below - lazy val tempFileIndex = { - val allPaths = caseInsensitiveOptions.get("path") ++ paths - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val globbedPaths = allPaths.toSeq.flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) - }.toArray - new InMemoryFileIndex(sparkSession, globbedPaths, options, None, fileStatusCache) + lazy val tempFileIndex = fileIndex.getOrElse { + val globbedPaths = + checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false) + createInMemoryFileIndex(globbedPaths) } + val partitionSchema = if (partitionColumns.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning // columns properly unless it is a Hive DataSource - combineInferredAndUserSpecifiedPartitionSchema(tempFileIndex) + tempFileIndex.partitionSchema } else { // maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred // partitioning @@ -356,13 +332,7 @@ case class DataSource( caseInsensitiveOptions.get("path").toSeq ++ paths, sparkSession.sessionState.newHadoopConf()) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) - val tempFileCatalog = new MetadataLogFileIndex(sparkSession, basePath, None) - val fileCatalog = if (userSpecifiedSchema.nonEmpty) { - val partitionSchema = combineInferredAndUserSpecifiedPartitionSchema(tempFileCatalog) - new MetadataLogFileIndex(sparkSession, basePath, Option(partitionSchema)) - } else { - tempFileCatalog - } + val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath, userSpecifiedSchema) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sparkSession, @@ -384,24 +354,23 @@ case class DataSource( // This is a non-streaming file based datasource. case (format: FileFormat, _) => - val allPaths = caseInsensitiveOptions.get("path") ++ paths - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val globbedPaths = allPaths.flatMap( - DataSource.checkAndGlobPathIfNecessary(hadoopConf, _, checkFilesExist)).toArray - - val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) - val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, fileStatusCache) - - val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && - catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { + val globbedPaths = + checkAndGlobPathIfNecessary(checkEmptyGlobPath = true, checkFilesExist = checkFilesExist) + val useCatalogFileIndex = sparkSession.sqlContext.conf.manageFilesourcePartitions && + catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog && + catalogTable.get.partitionColumnNames.nonEmpty + val (fileCatalog, dataSchema, partitionSchema) = if (useCatalogFileIndex) { val defaultTableSize = sparkSession.sessionState.conf.defaultSizeInBytes - new CatalogFileIndex( + val index = new CatalogFileIndex( sparkSession, catalogTable.get, catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(defaultTableSize)) + (index, catalogTable.get.dataSchema, catalogTable.get.partitionSchema) } else { - new InMemoryFileIndex( - sparkSession, globbedPaths, options, Some(partitionSchema), fileStatusCache) + val index = createInMemoryFileIndex(globbedPaths) + val (resultDataSchema, resultPartitionSchema) = + getOrInferFileFormatSchema(format, Some(index)) + (index, resultDataSchema, resultPartitionSchema) } HadoopFsRelation( @@ -552,6 +521,40 @@ case class DataSource( sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } } + + /** Returns an [[InMemoryFileIndex]] that can be used to get partition schema and file list. */ + private def createInMemoryFileIndex(globbedPaths: Seq[Path]): InMemoryFileIndex = { + val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + new InMemoryFileIndex( + sparkSession, globbedPaths, options, userSpecifiedSchema, fileStatusCache) + } + + /** + * Checks and returns files in all the paths. + */ + private def checkAndGlobPathIfNecessary( + checkEmptyGlobPath: Boolean, + checkFilesExist: Boolean): Seq[Path] = { + val allPaths = caseInsensitiveOptions.get("path") ++ paths + val hadoopConf = sparkSession.sessionState.newHadoopConf() + allPaths.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(hadoopConf) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) + + if (checkEmptyGlobPath && globPath.isEmpty) { + throw new AnalysisException(s"Path does not exist: $qualified") + } + + // Sufficient to check head of the globPath seq for non-glob scenario + // Don't need to check once again if files exist in streaming mode + if (checkFilesExist && !fs.exists(globPath.head)) { + throw new AnalysisException(s"Path does not exist: ${globPath.head}") + } + globPath + }.toSeq + } } object DataSource extends Logging { @@ -699,30 +702,6 @@ object DataSource extends Logging { locationUri = path.map(CatalogUtils.stringToURI), properties = optionsWithoutPath) } - /** - * If `path` is a file pattern, return all the files that match it. Otherwise, return itself. - * If `checkFilesExist` is `true`, also check the file existence. - */ - private def checkAndGlobPathIfNecessary( - hadoopConf: Configuration, - path: String, - checkFilesExist: Boolean): Seq[Path] = { - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) - - if (globPath.isEmpty) { - throw new AnalysisException(s"Path does not exist: $qualified") - } - // Sufficient to check head of the globPath seq for non-glob scenario - // Don't need to check once again if files exist in streaming mode - if (checkFilesExist && !fs.exists(globPath.head)) { - throw new AnalysisException(s"Path does not exist: ${globPath.head}") - } - globPath - } - /** * Called before writing into a FileFormat based data source to make sure the * supplied schema is not empty. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 318ada0ceefc5..739d1f456e3ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -41,17 +41,17 @@ import org.apache.spark.util.SerializableConfiguration * @param rootPathsSpecified the list of root table paths to scan (some of which might be * filtered out later) * @param parameters as set of options to control discovery - * @param partitionSchema an optional partition schema that will be use to provide types for the - * discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ class InMemoryFileIndex( sparkSession: SparkSession, rootPathsSpecified: Seq[Path], parameters: Map[String, String], - partitionSchema: Option[StructType], + userSpecifiedSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends PartitioningAwareFileIndex( - sparkSession, parameters, partitionSchema, fileStatusCache) { + sparkSession, parameters, userSpecifiedSchema, fileStatusCache) { // Filter out streaming metadata dirs or files such as "/.../_spark_metadata" (the metadata dir) // or "/.../_spark_metadata/0" (a file in the metadata dir). `rootPathsSpecified` might contain diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 6b6f6388d54e8..cc8af7b92c454 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -34,13 +34,13 @@ import org.apache.spark.sql.types.{StringType, StructType} * It provides the necessary methods to parse partition data based on a set of files. * * @param parameters as set of options to control partition discovery - * @param userPartitionSchema an optional partition schema that will be use to provide types for - * the discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ abstract class PartitioningAwareFileIndex( sparkSession: SparkSession, parameters: Map[String, String], - userPartitionSchema: Option[StructType], + userSpecifiedSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends FileIndex with Logging { import PartitioningAwareFileIndex.BASE_PATH_PARAM @@ -126,35 +126,32 @@ abstract class PartitioningAwareFileIndex( val caseInsensitiveOptions = CaseInsensitiveMap(parameters) val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) - - userPartitionSchema match { + val inferredPartitionSpec = PartitioningUtils.parsePartitions( + leafDirs, + typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, + basePaths = basePaths, + timeZoneId = timeZoneId) + userSpecifiedSchema match { case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => - val spec = PartitioningUtils.parsePartitions( - leafDirs, - typeInference = false, - basePaths = basePaths, - timeZoneId = timeZoneId) + val userPartitionSchema = + combineInferredAndUserSpecifiedPartitionSchema(inferredPartitionSpec) - // Without auto inference, all of value in the `row` should be null or in StringType, // we need to cast into the data type that user specified. def castPartitionValuesToUserSchema(row: InternalRow) = { InternalRow((0 until row.numFields).map { i => + val dt = inferredPartitionSpec.partitionColumns.fields(i).dataType Cast( - Literal.create(row.getUTF8String(i), StringType), - userProvidedSchema.fields(i).dataType, + Literal.create(row.get(i, dt), dt), + userPartitionSchema.fields(i).dataType, Option(timeZoneId)).eval() }: _*) } - PartitionSpec(userProvidedSchema, spec.partitions.map { part => + PartitionSpec(userPartitionSchema, inferredPartitionSpec.partitions.map { part => part.copy(values = castPartitionValuesToUserSchema(part.values)) }) case _ => - PartitioningUtils.parsePartitions( - leafDirs, - typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, - basePaths = basePaths, - timeZoneId = timeZoneId) + inferredPartitionSpec } } @@ -236,6 +233,25 @@ abstract class PartitioningAwareFileIndex( val name = path.getName !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) } + + /** + * In the read path, only managed tables by Hive provide the partition columns properly when + * initializing this class. All other file based data sources will try to infer the partitioning, + * and then cast the inferred types to user specified dataTypes if the partition columns exist + * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or + * inconsistent data types as reported in SPARK-21463. + * @param spec A partition inference result + * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema` + */ + private def combineInferredAndUserSpecifiedPartitionSchema(spec: PartitionSpec): StructType = { + val equality = sparkSession.sessionState.conf.resolver + val resolved = spec.partitionColumns.map { partitionField => + // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred + userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( + partitionField) + } + StructType(resolved) + } } object PartitioningAwareFileIndex { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala index 1da703cefd8ea..5cacdd070b735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala @@ -30,14 +30,14 @@ import org.apache.spark.sql.types.StructType * A [[FileIndex]] that generates the list of files to processing by reading them from the * metadata log files generated by the [[FileStreamSink]]. * - * @param userPartitionSchema an optional partition schema that will be use to provide types for - * the discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ class MetadataLogFileIndex( sparkSession: SparkSession, path: Path, - userPartitionSchema: Option[StructType]) - extends PartitioningAwareFileIndex(sparkSession, Map.empty, userPartitionSchema) { + userSpecifiedSchema: Option[StructType]) + extends PartitioningAwareFileIndex(sparkSession, Map.empty, userSpecifiedSchema) { private val metadataDirectory = new Path(path, FileStreamSink.metadataDir) logInfo(s"Reading streaming file log from $metadataDirectory") @@ -51,7 +51,7 @@ class MetadataLogFileIndex( } override protected val leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = { - allFilesFromLog.toArray.groupBy(_.getPath.getParent) + allFilesFromLog.groupBy(_.getPath.getParent) } override def rootPaths: Seq[Path] = path :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c1d61b843d899..8764f0c42cf9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -401,7 +401,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi sparkSession = spark, rootPathsSpecified = Seq(new Path(tempDir)), parameters = Map.empty[String, String], - partitionSchema = None) + userSpecifiedSchema = None) // This should not fail. fileCatalog.listLeafFiles(Seq(new Path(tempDir))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index 1a86c604d5da3..3af163af0968c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -419,7 +419,7 @@ class PartitionedTablePerfStatsSuite HiveCatalogMetrics.reset() spark.read.load(dir.getAbsolutePath) assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1) - assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 1) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 0) } } } From 25892f3cc9dcb938220be8020a5b9a17c92dbdbe Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 14 Apr 2018 01:01:00 +0800 Subject: [PATCH 618/774] [SPARK-23375][SQL] Eliminate unneeded Sort in Optimizer ## What changes were proposed in this pull request? Added a new rule to remove Sort operation when its child is already sorted. For instance, this simple code: ``` spark.sparkContext.parallelize(Seq(("a", "b"))).toDF("a", "b").registerTempTable("table1") val df = sql(s"""SELECT b | FROM ( | SELECT a, b | FROM table1 | ORDER BY a | ) t | ORDER BY a""".stripMargin) df.explain(true) ``` before the PR produces this plan: ``` == Parsed Logical Plan == 'Sort ['a ASC NULLS FIRST], true +- 'Project ['b] +- 'SubqueryAlias t +- 'Sort ['a ASC NULLS FIRST], true +- 'Project ['a, 'b] +- 'UnresolvedRelation `table1` == Analyzed Logical Plan == b: string Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [b#7, a#6] +- SubqueryAlias t +- Sort [a#6 ASC NULLS FIRST], true +- Project [a#6, b#7] +- SubqueryAlias table1 +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Optimized Logical Plan == Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [b#7, a#6] +- Sort [a#6 ASC NULLS FIRST], true +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Physical Plan == *(3) Project [b#7] +- *(3) Sort [a#6 ASC NULLS FIRST], true, 0 +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 200) +- *(2) Project [b#7, a#6] +- *(2) Sort [a#6 ASC NULLS FIRST], true, 0 +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 200) +- *(1) Project [_1#3 AS a#6, _2#4 AS b#7] +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- Scan ExternalRDDScan[obj#2] ``` while after the PR produces: ``` == Parsed Logical Plan == 'Sort ['a ASC NULLS FIRST], true +- 'Project ['b] +- 'SubqueryAlias t +- 'Sort ['a ASC NULLS FIRST], true +- 'Project ['a, 'b] +- 'UnresolvedRelation `table1` == Analyzed Logical Plan == b: string Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [b#7, a#6] +- SubqueryAlias t +- Sort [a#6 ASC NULLS FIRST], true +- Project [a#6, b#7] +- SubqueryAlias table1 +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Optimized Logical Plan == Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Physical Plan == *(2) Project [b#7] +- *(2) Sort [a#6 ASC NULLS FIRST], true, 0 +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 5) +- *(1) Project [_1#3 AS a#6, _2#4 AS b#7] +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- Scan ExternalRDDScan[obj#2] ``` this means that an unnecessary sort operation is not performed after the PR. ## How was this patch tested? added UT Author: Marco Gaido Closes #20560 from mgaido91/SPARK-23375. --- .../sql/catalyst/optimizer/Optimizer.scala | 12 +++ .../catalyst/plans/logical/LogicalPlan.scala | 9 ++ .../plans/logical/basicLogicalOperators.scala | 23 ++-- .../optimizer/RemoveRedundantSortsSuite.scala | 101 ++++++++++++++++++ .../spark/sql/execution/CacheManager.scala | 4 +- .../spark/sql/execution/ExistingRDD.scala | 2 +- .../execution/columnar/InMemoryRelation.scala | 17 +-- .../spark/sql/ConfigBehaviorSuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 15 ++- .../columnar/InMemoryColumnarQuerySuite.scala | 14 +-- 10 files changed, 175 insertions(+), 24 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9a1bbc675e397..5fb59ef350b8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -138,6 +138,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) operatorOptimizationBatch) :+ Batch("Join Reorder", Once, CostBasedJoinReorder) :+ + Batch("Remove Redundant Sorts", Once, + RemoveRedundantSorts) :+ Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :+ Batch("Object Expressions Optimization", fixedPoint, @@ -733,6 +735,16 @@ object EliminateSorts extends Rule[LogicalPlan] { } } +/** + * Removes Sort operation if the child is already sorted + */ +object RemoveRedundantSorts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => + child + } +} + /** * Removes filters that can be evaluated trivially. This can be done through the following ways: * 1) by eliding the filter for cases where it will always evaluate to `true`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index c8ccd9bd03994..42034403d6d03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -219,6 +219,11 @@ abstract class LogicalPlan * Refreshes (or invalidates) any metadata/data cached in the plan recursively. */ def refresh(): Unit = children.foreach(_.refresh()) + + /** + * Returns the output ordering that this plan generates. + */ + def outputOrdering: Seq[SortOrder] = Nil } /** @@ -274,3 +279,7 @@ abstract class BinaryNode extends LogicalPlan { override final def children: Seq[LogicalPlan] = Seq(left, right) } + +abstract class OrderPreservingUnaryNode extends UnaryNode { + override final def outputOrdering: Seq[SortOrder] = child.outputOrdering +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a4fca790dd086..10df504795430 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -43,11 +43,12 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { * This node is inserted at the top of a subquery when it is optimized. This makes sure we can * recognize a subquery as such, and it allows us to write subquery aware transformations. */ -case class Subquery(child: LogicalPlan) extends UnaryNode { +case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output } -case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { +case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) + extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows @@ -125,7 +126,7 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) - extends UnaryNode with PredicateHelper { + extends OrderPreservingUnaryNode with PredicateHelper { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows @@ -469,6 +470,7 @@ case class Sort( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows + override def outputOrdering: Seq[SortOrder] = order } /** Factory for constructing new `Range` nodes. */ @@ -522,6 +524,15 @@ case class Range( override def computeStats(): Statistics = { Statistics(sizeInBytes = LongType.defaultSize * numElements) } + + override def outputOrdering: Seq[SortOrder] = { + val order = if (step > 0) { + Ascending + } else { + Descending + } + output.map(a => SortOrder(a, order)) + } } case class Aggregate( @@ -728,7 +739,7 @@ object Limit { * * See [[Limit]] for more information. */ -case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = { limitExpr match { @@ -744,7 +755,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN * * See [[Limit]] for more information. */ -case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output override def maxRowsPerPartition: Option[Long] = { @@ -764,7 +775,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case class SubqueryAlias( alias: String, child: LogicalPlan) - extends UnaryNode { + extends OrderPreservingUnaryNode { override def doCanonicalize(): LogicalPlan = child.canonicalized diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala new file mode 100644 index 0000000000000..2319ab8046e56 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} + +class RemoveRedundantSortsSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Remove Redundant Sorts", Once, + RemoveRedundantSorts) :: + Batch("Collapse Project", Once, + CollapseProject) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("remove redundant order by") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val optimized = Optimize.execute(unnecessaryReordered.analyze) + val correctAnswer = orderedPlan.select('a).analyze + comparePlans(Optimize.execute(optimized), correctAnswer) + } + + test("do not remove sort if the order is different") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(reorderedDifferently.analyze) + val correctAnswer = reorderedDifferently.analyze + comparePlans(optimized, correctAnswer) + } + + test("filters don't affect order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = orderedPlan.where('a > Literal(10)).analyze + comparePlans(optimized, correctAnswer) + } + + test("limits don't affect order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = orderedPlan.limit(Literal(10)).analyze + comparePlans(optimized, correctAnswer) + } + + test("range is already sorted") { + val inputPlan = Range(1L, 1000L, 1, 10) + val orderedPlan = inputPlan.orderBy('id.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = inputPlan.analyze + comparePlans(optimized, correctAnswer) + + val reversedPlan = inputPlan.orderBy('id.desc) + val reversedOptimized = Optimize.execute(reversedPlan.analyze) + val reversedCorrectAnswer = reversedPlan.analyze + comparePlans(reversedOptimized, reversedCorrectAnswer) + + val negativeStepInputPlan = Range(10L, 1L, -1, 10) + val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc) + val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze) + val negativeStepCorrectAnswer = negativeStepInputPlan.analyze + comparePlans(negativeStepOptimized, negativeStepCorrectAnswer) + } + + test("sort should not be removed when there is a node which doesn't guarantee any order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc) + val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc) + val optimized = Optimize.execute(groupedAndResorted.analyze) + val correctAnswer = groupedAndResorted.analyze + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index d68aeb275afda..a8794be7280c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -99,7 +99,7 @@ class CacheManager extends Logging { sparkSession.sessionState.conf.columnBatchSize, storageLevel, sparkSession.sessionState.executePlan(planToCache).executedPlan, tableName, - planToCache.stats) + planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) } } @@ -148,7 +148,7 @@ class CacheManager extends Logging { storageLevel = cd.cachedRepresentation.storageLevel, child = spark.sessionState.executePlan(cd.plan).executedPlan, tableName = cd.cachedRepresentation.tableName, - statsOfPlanToCache = cd.plan.stats) + logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index f3555508185fe..be50a1571a2ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -125,7 +125,7 @@ case class LogicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow], outputPartitioning: Partitioning = UnknownPartitioning(0), - outputOrdering: Seq[SortOrder] = Nil, + override val outputOrdering: Seq[SortOrder] = Nil, override val isStreaming: Boolean = false)(session: SparkSession) extends LeafNode with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 2579046e30708..a7ba9b86a176f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -39,9 +39,9 @@ object InMemoryRelation { storageLevel: StorageLevel, child: SparkPlan, tableName: Option[String], - statsOfPlanToCache: Statistics): InMemoryRelation = + logicalPlan: LogicalPlan): InMemoryRelation = new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)( - statsOfPlanToCache = statsOfPlanToCache) + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) } @@ -64,7 +64,8 @@ case class InMemoryRelation( tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics) + statsOfPlanToCache: Statistics, + override val outputOrdering: Seq[SortOrder]) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -76,7 +77,8 @@ case class InMemoryRelation( tableName = None)( _cachedColumnBuffers, sizeInBytesStats, - statsOfPlanToCache) + statsOfPlanToCache, + outputOrdering) override def producedAttributes: AttributeSet = outputSet @@ -159,7 +161,7 @@ case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) + _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering) } override def newInstance(): this.type = { @@ -172,7 +174,8 @@ case class InMemoryRelation( tableName)( _cachedColumnBuffers, sizeInBytesStats, - statsOfPlanToCache).asInstanceOf[this.type] + statsOfPlanToCache, + outputOrdering).asInstanceOf[this.type] } def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index cee85ec8af04d..949505e449fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -39,7 +39,7 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { def computeChiSquareTest(): Double = { val n = 10000 // Trigger a sort - val data = spark.range(0, n, 1, 1).sort('id) + val data = spark.range(0, n, 1, 1).sort('id.desc) .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() // Compute histogram for the number of records per partition post sort diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f8b26f5b28cc7..40915a102bab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} @@ -197,6 +197,19 @@ class PlannerSuite extends SharedSQLContext { assert(planned.child.isInstanceOf[CollectLimitExec]) } + test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { + val query = testData.select('key, 'value).sort('key.desc).cache() + assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation]) + val resorted = query.sort('key.desc) + assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty) + assert(resorted.select('key).collect().map(_.getInt(0)).toSeq == + (1 to 100).reverse) + // with a different order, the sort is needed + val sortedAsc = query.sort('key) + assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1) + assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100)) + } + test("PartitioningCollection") { withTempView("normal", "small", "tiny") { testData.createOrReplaceTempView("normal") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 26b63e8e8490f..9b7b316211d30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec} import org.apache.spark.sql.functions._ @@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val storageLevel = MEMORY_ONLY val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None, - data.logicalPlan.stats) + data.logicalPlan) assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) inMemoryRelation.cachedColumnBuffers.collect().head match { @@ -119,7 +120,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("simple columnar query") { val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - testData.logicalPlan.stats) + testData.logicalPlan) checkAnswer(scan, testData.collect().toSeq) } @@ -138,7 +139,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val logicalPlan = testData.select('value, 'key).logicalPlan val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - logicalPlan.stats) + logicalPlan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -155,7 +156,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - testData.logicalPlan.stats) + testData.logicalPlan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) @@ -329,7 +330,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-17549: cached table size should be correctly calculated") { val data = spark.sparkContext.parallelize(1 to 10, 5).toDF() val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan - val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan.stats) + val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan) // Materialize the data. val expectedAnswer = data.collect() @@ -455,7 +456,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-22249: buildFilter should not throw exception when In contains an empty list") { val attribute = AttributeReference("a", IntegerType)() val localTableScanExec = LocalTableScanExec(Seq(attribute), Nil) - val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, null) + val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, + LocalRelation(Seq(attribute), Nil)) val tableScanExec = InMemoryTableScanExec(Seq(attribute), Seq(In(attribute, Nil)), testRelation) assert(tableScanExec.partitionFilters.isEmpty) From 558f31b31c73b7e9f26f56498b54cf53997b59b8 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 13 Apr 2018 14:05:04 -0700 Subject: [PATCH 619/774] [SPARK-23963][SQL] Properly handle large number of columns in query on text-based Hive table ## What changes were proposed in this pull request? TableReader would get disproportionately slower as the number of columns in the query increased. I fixed the way TableReader was looking up metadata for each column in the row. Previously, it had been looking up this data in linked lists, accessing each linked list by an index (column number). Now it looks up this data in arrays, where indexing by column number works better. ## How was this patch tested? Manual testing All sbt unit tests python sql tests Author: Bruce Robbins Closes #21043 from bersprockets/tabreadfix. --- .../src/main/scala/org/apache/spark/sql/hive/TableReader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index cc8907a0bbc93..b5444a4217924 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -381,7 +381,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => soi.getStructFieldRef(attr.name) -> ordinal - }.unzip + }.toArray.unzip /** * Builds specific unwrappers ahead of time according to object inspector From cbb41a0c5b01579c85f06ef42cc0585fbef216c5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 13 Apr 2018 16:31:39 -0700 Subject: [PATCH 620/774] [SPARK-23966][SS] Refactoring all checkpoint file writing logic in a common CheckpointFileManager interface ## What changes were proposed in this pull request? Checkpoint files (offset log files, state store files) in Structured Streaming must be written atomically such that no partial files are generated (would break fault-tolerance guarantees). Currently, there are 3 locations which try to do this individually, and in some cases, incorrectly. 1. HDFSOffsetMetadataLog - This uses a FileManager interface to use any implementation of `FileSystem` or `FileContext` APIs. It preferably loads `FileContext` implementation as FileContext of HDFS has atomic renames. 1. HDFSBackedStateStore (aka in-memory state store) - Writing a version.delta file - This uses FileSystem APIs only to perform a rename. This is incorrect as rename is not atomic in HDFS FileSystem implementation. - Writing a snapshot file - Same as above. #### Current problems: 1. State Store behavior is incorrect - HDFS FileSystem implementation does not have atomic rename. 1. Inflexible - Some file systems provide mechanisms other than write-to-temp-file-and-rename for writing atomically and more efficiently. For example, with S3 you can write directly to the final file and it will be made visible only when the entire file is written and closed correctly. Any failure can be made to terminate the writing without making any partial files visible in S3. The current code does not abstract out this mechanism enough that it can be customized. #### Solution: 1. Introduce a common interface that all 3 cases above can use to write checkpoint files atomically. 2. This interface must provide the necessary interfaces that allow customization of the write-and-rename mechanism. This PR does that by introducing the interface `CheckpointFileManager` and modifying `HDFSMetadataLog` and `HDFSBackedStateStore` to use the interface. Similar to earlier `FileManager`, there are implementations based on `FileSystem` and `FileContext` APIs, and the latter implementation is preferred to make it work correctly with HDFS. The key method this interface has is `createAtomic(path, overwrite)` which returns a `CancellableFSDataOutputStream` that has the method `cancel()`. All users of this method need to either call `close()` to successfully write the file, or `cancel()` in case of an error. ## How was this patch tested? New tests in `CheckpointFileManagerSuite` and slightly modified existing tests. Author: Tathagata Das Closes #21048 from tdas/SPARK-23966. --- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../streaming/CheckpointFileManager.scala | 349 ++++++++++++++++++ .../execution/streaming/HDFSMetadataLog.scala | 229 +----------- .../state/HDFSBackedStateStoreProvider.scala | 120 +++--- .../streaming/state/StateStore.scala | 4 +- .../CheckpointFileManagerSuite.scala | 192 ++++++++++ .../CompactibleFileStreamLogSuite.scala | 5 - .../streaming/HDFSMetadataLogSuite.scala | 116 +----- .../streaming/state/StateStoreSuite.scala | 58 ++- 9 files changed, 678 insertions(+), 402 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1c8ab9c62623e..0dc47bfe075d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -930,6 +930,13 @@ object SQLConf { .intConf .createWithDefault(100) + val STREAMING_CHECKPOINT_FILE_MANAGER_CLASS = + buildConf("spark.sql.streaming.checkpointFileManagerClass") + .doc("The class used to write checkpoint files atomically. This class must be a subclass " + + "of the interface CheckpointFileManager.") + .internal() + .stringConf + val NDV_MAX_ERROR = buildConf("spark.sql.statistics.ndv.maxError") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala new file mode 100644 index 0000000000000..606ba250ad9d2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io.{FileNotFoundException, IOException, OutputStream} +import java.util.{EnumSet, UUID} + +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.apache.hadoop.fs.local.{LocalFs, RawLocalFs} +import org.apache.hadoop.fs.permission.FsPermission + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.RenameHelperMethods +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * An interface to abstract out all operation related to streaming checkpoints. Most importantly, + * the key operation this interface provides is `createAtomic(path, overwrite)` which returns a + * `CancellableFSDataOutputStream`. This method is used by [[HDFSMetadataLog]] and + * [[org.apache.spark.sql.execution.streaming.state.StateStore StateStore]] implementations + * to write a complete checkpoint file atomically (i.e. no partial file will be visible), with or + * without overwrite. + * + * This higher-level interface above the Hadoop FileSystem is necessary because + * different implementation of FileSystem/FileContext may have different combination of operations + * to provide the desired atomic guarantees (e.g. write-to-temp-file-and-rename, + * direct-write-and-cancel-on-failure) and this abstraction allow different implementations while + * keeping the usage simple (`createAtomic` -> `close` or `cancel`). + */ +trait CheckpointFileManager { + + import org.apache.spark.sql.execution.streaming.CheckpointFileManager._ + + /** + * Create a file and make its contents available atomically after the output stream is closed. + * + * @param path Path to create + * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to + * overwrite the file if it already exists. It should not throw + * any exception if the file exists. However, if false, then the + * implementation must not overwrite if the file alraedy exists and + * must throw `FileAlreadyExistsException` in that case. + */ + def createAtomic(path: Path, overwriteIfPossible: Boolean): CancellableFSDataOutputStream + + /** Open a file for reading, or throw exception if it does not exist. */ + def open(path: Path): FSDataInputStream + + /** List the files in a path that match a filter. */ + def list(path: Path, filter: PathFilter): Array[FileStatus] + + /** List all the files in a path. */ + def list(path: Path): Array[FileStatus] = { + list(path, new PathFilter { override def accept(path: Path): Boolean = true }) + } + + /** Make directory at the give path and all its parent directories as needed. */ + def mkdirs(path: Path): Unit + + /** Whether path exists */ + def exists(path: Path): Boolean + + /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ + def delete(path: Path): Unit + + /** Is the default file system this implementation is operating on the local file system. */ + def isLocal: Boolean +} + +object CheckpointFileManager extends Logging { + + /** + * Additional methods in CheckpointFileManager implementations that allows + * [[RenameBasedFSDataOutputStream]] get atomicity by write-to-temp-file-and-rename + */ + sealed trait RenameHelperMethods { self => CheckpointFileManager + /** Create a file with overwrite. */ + def createTempFile(path: Path): FSDataOutputStream + + /** + * Rename a file. + * + * @param srcPath Source path to rename + * @param dstPath Destination path to rename to + * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to + * overwrite the file if it already exists. It should not throw + * any exception if the file exists. However, if false, then the + * implementation must not overwrite if the file alraedy exists and + * must throw `FileAlreadyExistsException` in that case. + */ + def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit + } + + /** + * An interface to add the cancel() operation to [[FSDataOutputStream]]. This is used + * mainly by `CheckpointFileManager.createAtomic` to write a file atomically. + * + * @see [[CheckpointFileManager]]. + */ + abstract class CancellableFSDataOutputStream(protected val underlyingStream: OutputStream) + extends FSDataOutputStream(underlyingStream, null) { + /** Cancel the `underlyingStream` and ensure that the output file is not generated. */ + def cancel(): Unit + } + + /** + * An implementation of [[CancellableFSDataOutputStream]] that writes a file atomically by writing + * to a temporary file and then renames. + */ + sealed class RenameBasedFSDataOutputStream( + fm: CheckpointFileManager with RenameHelperMethods, + finalPath: Path, + tempPath: Path, + overwriteIfPossible: Boolean) + extends CancellableFSDataOutputStream(fm.createTempFile(tempPath)) { + + def this(fm: CheckpointFileManager with RenameHelperMethods, path: Path, overwrite: Boolean) = { + this(fm, path, generateTempPath(path), overwrite) + } + + logInfo(s"Writing atomically to $finalPath using temp file $tempPath") + @volatile private var terminated = false + + override def close(): Unit = synchronized { + try { + if (terminated) return + underlyingStream.close() + try { + fm.renameTempFile(tempPath, finalPath, overwriteIfPossible) + } catch { + case fe: FileAlreadyExistsException => + logWarning( + s"Failed to rename temp file $tempPath to $finalPath because file exists", fe) + if (!overwriteIfPossible) throw fe + } + logInfo(s"Renamed temp file $tempPath to $finalPath") + } finally { + terminated = true + } + } + + override def cancel(): Unit = synchronized { + try { + if (terminated) return + underlyingStream.close() + fm.delete(tempPath) + } catch { + case NonFatal(e) => + logWarning(s"Error cancelling write to $finalPath", e) + } finally { + terminated = true + } + } + } + + + /** Create an instance of [[CheckpointFileManager]] based on the path and configuration. */ + def create(path: Path, hadoopConf: Configuration): CheckpointFileManager = { + val fileManagerClass = hadoopConf.get( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key) + if (fileManagerClass != null) { + return Utils.classForName(fileManagerClass) + .getConstructor(classOf[Path], classOf[Configuration]) + .newInstance(path, hadoopConf) + .asInstanceOf[CheckpointFileManager] + } + try { + // Try to create a manager based on `FileContext` because HDFS's `FileContext.rename() + // gives atomic renames, which is what we rely on for the default implementation + // `CheckpointFileManager.createAtomic`. + new FileContextBasedCheckpointFileManager(path, hadoopConf) + } catch { + case e: UnsupportedFileSystemException => + logWarning( + "Could not use FileContext API for managing Structured Streaming checkpoint files at " + + s"$path. Using FileSystem API instead for managing log files. If the implementation " + + s"of FileSystem.rename() is not atomic, then the correctness and fault-tolerance of" + + s"your Structured Streaming is not guaranteed.") + new FileSystemBasedCheckpointFileManager(path, hadoopConf) + } + } + + private def generateTempPath(path: Path): Path = { + val tc = org.apache.spark.TaskContext.get + val tid = if (tc != null) ".TID" + tc.taskAttemptId else "" + new Path(path.getParent, s".${path.getName}.${UUID.randomUUID}${tid}.tmp") + } +} + + +/** An implementation of [[CheckpointFileManager]] using Hadoop's [[FileSystem]] API. */ +class FileSystemBasedCheckpointFileManager(path: Path, hadoopConf: Configuration) + extends CheckpointFileManager with RenameHelperMethods with Logging { + + import CheckpointFileManager._ + + protected val fs = path.getFileSystem(hadoopConf) + + override def list(path: Path, filter: PathFilter): Array[FileStatus] = { + fs.listStatus(path, filter) + } + + override def mkdirs(path: Path): Unit = { + fs.mkdirs(path, FsPermission.getDirDefault) + } + + override def createTempFile(path: Path): FSDataOutputStream = { + fs.create(path, true) + } + + override def createAtomic( + path: Path, + overwriteIfPossible: Boolean): CancellableFSDataOutputStream = { + new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible) + } + + override def open(path: Path): FSDataInputStream = { + fs.open(path) + } + + override def exists(path: Path): Boolean = { + try + return fs.getFileStatus(path) != null + catch { + case e: FileNotFoundException => + return false + } + } + + override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = { + if (!overwriteIfPossible && fs.exists(dstPath)) { + throw new FileAlreadyExistsException( + s"Failed to rename $srcPath to $dstPath as destination already exists") + } + + if (!fs.rename(srcPath, dstPath)) { + // FileSystem.rename() returning false is very ambiguous as it can be for many reasons. + // This tries to make a best effort attempt to return the most appropriate exception. + if (fs.exists(dstPath)) { + if (!overwriteIfPossible) { + throw new FileAlreadyExistsException(s"Failed to rename as $dstPath already exists") + } + } else if (!fs.exists(srcPath)) { + throw new FileNotFoundException(s"Failed to rename as $srcPath was not found") + } else { + val msg = s"Failed to rename temp file $srcPath to $dstPath as rename returned false" + logWarning(msg) + throw new IOException(msg) + } + } + } + + override def delete(path: Path): Unit = { + try { + fs.delete(path, true) + } catch { + case e: FileNotFoundException => + logInfo(s"Failed to delete $path as it does not exist") + // ignore if file has already been deleted + } + } + + override def isLocal: Boolean = fs match { + case _: LocalFileSystem | _: RawLocalFileSystem => true + case _ => false + } +} + + +/** An implementation of [[CheckpointFileManager]] using Hadoop's [[FileContext]] API. */ +class FileContextBasedCheckpointFileManager(path: Path, hadoopConf: Configuration) + extends CheckpointFileManager with RenameHelperMethods with Logging { + + import CheckpointFileManager._ + + private val fc = if (path.toUri.getScheme == null) { + FileContext.getFileContext(hadoopConf) + } else { + FileContext.getFileContext(path.toUri, hadoopConf) + } + + override def list(path: Path, filter: PathFilter): Array[FileStatus] = { + fc.util.listStatus(path, filter) + } + + override def mkdirs(path: Path): Unit = { + fc.mkdir(path, FsPermission.getDirDefault, true) + } + + override def createTempFile(path: Path): FSDataOutputStream = { + import CreateFlag._ + import Options._ + fc.create( + path, EnumSet.of(CREATE, OVERWRITE), CreateOpts.checksumParam(ChecksumOpt.createDisabled())) + } + + override def createAtomic( + path: Path, + overwriteIfPossible: Boolean): CancellableFSDataOutputStream = { + new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible) + } + + override def open(path: Path): FSDataInputStream = { + fc.open(path) + } + + override def exists(path: Path): Boolean = { + fc.util.exists(path) + } + + override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = { + import Options.Rename._ + fc.rename(srcPath, dstPath, if (overwriteIfPossible) OVERWRITE else NONE) + } + + + override def delete(path: Path): Unit = { + try { + fc.delete(path, true) + } catch { + case e: FileNotFoundException => + // ignore if file has already been deleted + } + } + + override def isLocal: Boolean = fc.getDefaultFileSystem match { + case _: LocalFs | _: RawLocalFs => true // LocalFs = RawLocalFs + ChecksumFs + case _ => false + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 00bc215a5dc8c..bd0a46115ceb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -57,10 +57,10 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: require(implicitly[ClassTag[T]].runtimeClass != classOf[Seq[_]], "Should not create a log with type Seq, use Arrays instead - see SPARK-17372") - import HDFSMetadataLog._ - val metadataPath = new Path(path) - protected val fileManager = createFileManager() + + protected val fileManager = + CheckpointFileManager.create(metadataPath, sparkSession.sessionState.newHadoopConf) if (!fileManager.exists(metadataPath)) { fileManager.mkdirs(metadataPath) @@ -109,84 +109,31 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: require(metadata != null, "'null' metadata cannot written to a metadata log") get(batchId).map(_ => false).getOrElse { // Only write metadata when the batch has not yet been written - writeBatch(batchId, metadata) + writeBatchToFile(metadata, batchIdToPath(batchId)) true } } - private def writeTempBatch(metadata: T): Option[Path] = { - while (true) { - val tempPath = new Path(metadataPath, s".${UUID.randomUUID.toString}.tmp") - try { - val output = fileManager.create(tempPath) - try { - serialize(metadata, output) - return Some(tempPath) - } finally { - output.close() - } - } catch { - case e: FileAlreadyExistsException => - // Failed to create "tempPath". There are two cases: - // 1. Someone is creating "tempPath" too. - // 2. This is a restart. "tempPath" has already been created but not moved to the final - // batch file (not committed). - // - // For both cases, the batch has not yet been committed. So we can retry it. - // - // Note: there is a potential risk here: if HDFSMetadataLog A is running, people can use - // the same metadata path to create "HDFSMetadataLog" and fail A. However, this is not a - // big problem because it requires the attacker must have the permission to write the - // metadata path. In addition, the old Streaming also have this issue, people can create - // malicious checkpoint files to crash a Streaming application too. - } - } - None - } - - /** - * Write a batch to a temp file then rename it to the batch file. + /** Write a batch to a temp file then rename it to the batch file. * * There may be multiple [[HDFSMetadataLog]] using the same metadata path. Although it is not a * valid behavior, we still need to prevent it from destroying the files. */ - private def writeBatch(batchId: Long, metadata: T): Unit = { - val tempPath = writeTempBatch(metadata).getOrElse( - throw new IllegalStateException(s"Unable to create temp batch file $batchId")) + private def writeBatchToFile(metadata: T, path: Path): Unit = { + val output = fileManager.createAtomic(path, overwriteIfPossible = false) try { - // Try to commit the batch - // It will fail if there is an existing file (someone has committed the batch) - logDebug(s"Attempting to write log #${batchIdToPath(batchId)}") - fileManager.rename(tempPath, batchIdToPath(batchId)) - - // SPARK-17475: HDFSMetadataLog should not leak CRC files - // If the underlying filesystem didn't rename the CRC file, delete it. - val crcPath = new Path(tempPath.getParent(), s".${tempPath.getName()}.crc") - if (fileManager.exists(crcPath)) fileManager.delete(crcPath) + serialize(metadata, output) + output.close() } catch { case e: FileAlreadyExistsException => - // If "rename" fails, it means some other "HDFSMetadataLog" has committed the batch. - // So throw an exception to tell the user this is not a valid behavior. + output.cancel() + // If next batch file already exists, then another concurrently running query has + // written it. throw new ConcurrentModificationException( - s"Multiple HDFSMetadataLog are using $path", e) - } finally { - fileManager.delete(tempPath) - } - } - - /** - * @return the deserialized metadata in a batch file, or None if file not exist. - * @throws IllegalArgumentException when path does not point to a batch file. - */ - def get(batchFile: Path): Option[T] = { - if (fileManager.exists(batchFile)) { - if (isBatchFile(batchFile)) { - get(pathToBatchId(batchFile)) - } else { - throw new IllegalArgumentException(s"File ${batchFile} is not a batch file!") - } - } else { - None + s"Multiple streaming queries are concurrently using $path", e) + case e: Throwable => + output.cancel() + throw e } } @@ -219,7 +166,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: (endId.isEmpty || batchId <= endId.get) && (startId.isEmpty || batchId >= startId.get) }.sorted - verifyBatchIds(batchIds, startId, endId) + HDFSMetadataLog.verifyBatchIds(batchIds, startId, endId) batchIds.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map { case (batchId, metadataOption) => @@ -280,19 +227,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: } } - private def createFileManager(): FileManager = { - val hadoopConf = sparkSession.sessionState.newHadoopConf() - try { - new FileContextManager(metadataPath, hadoopConf) - } catch { - case e: UnsupportedFileSystemException => - logWarning("Could not use FileContext API for managing metadata log files at path " + - s"$metadataPath. Using FileSystem API instead for managing log files. The log may be " + - s"inconsistent under failures.") - new FileSystemManager(metadataPath, hadoopConf) - } - } - /** * Parse the log version from the given `text` -- will throw exception when the parsed version * exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1", @@ -327,135 +261,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: object HDFSMetadataLog { - /** A simple trait to abstract out the file management operations needed by HDFSMetadataLog. */ - trait FileManager { - - /** List the files in a path that match a filter. */ - def list(path: Path, filter: PathFilter): Array[FileStatus] - - /** Make directory at the give path and all its parent directories as needed. */ - def mkdirs(path: Path): Unit - - /** Whether path exists */ - def exists(path: Path): Boolean - - /** Open a file for reading, or throw exception if it does not exist. */ - def open(path: Path): FSDataInputStream - - /** Create path, or throw exception if it already exists */ - def create(path: Path): FSDataOutputStream - - /** - * Atomically rename path, or throw exception if it cannot be done. - * Should throw FileNotFoundException if srcPath does not exist. - * Should throw FileAlreadyExistsException if destPath already exists. - */ - def rename(srcPath: Path, destPath: Path): Unit - - /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ - def delete(path: Path): Unit - } - - /** - * Default implementation of FileManager using newer FileContext API. - */ - class FileContextManager(path: Path, hadoopConf: Configuration) extends FileManager { - private val fc = if (path.toUri.getScheme == null) { - FileContext.getFileContext(hadoopConf) - } else { - FileContext.getFileContext(path.toUri, hadoopConf) - } - - override def list(path: Path, filter: PathFilter): Array[FileStatus] = { - fc.util.listStatus(path, filter) - } - - override def rename(srcPath: Path, destPath: Path): Unit = { - fc.rename(srcPath, destPath) - } - - override def mkdirs(path: Path): Unit = { - fc.mkdir(path, FsPermission.getDirDefault, true) - } - - override def open(path: Path): FSDataInputStream = { - fc.open(path) - } - - override def create(path: Path): FSDataOutputStream = { - fc.create(path, EnumSet.of(CreateFlag.CREATE)) - } - - override def exists(path: Path): Boolean = { - fc.util().exists(path) - } - - override def delete(path: Path): Unit = { - try { - fc.delete(path, true) - } catch { - case e: FileNotFoundException => - // ignore if file has already been deleted - } - } - } - - /** - * Implementation of FileManager using older FileSystem API. Note that this implementation - * cannot provide atomic renaming of paths, hence can lead to consistency issues. This - * should be used only as a backup option, when FileContextManager cannot be used. - */ - class FileSystemManager(path: Path, hadoopConf: Configuration) extends FileManager { - private val fs = path.getFileSystem(hadoopConf) - - override def list(path: Path, filter: PathFilter): Array[FileStatus] = { - fs.listStatus(path, filter) - } - - /** - * Rename a path. Note that this implementation is not atomic. - * @throws FileNotFoundException if source path does not exist. - * @throws FileAlreadyExistsException if destination path already exists. - * @throws IOException if renaming fails for some unknown reason. - */ - override def rename(srcPath: Path, destPath: Path): Unit = { - if (!fs.exists(srcPath)) { - throw new FileNotFoundException(s"Source path does not exist: $srcPath") - } - if (fs.exists(destPath)) { - throw new FileAlreadyExistsException(s"Destination path already exists: $destPath") - } - if (!fs.rename(srcPath, destPath)) { - throw new IOException(s"Failed to rename $srcPath to $destPath") - } - } - - override def mkdirs(path: Path): Unit = { - fs.mkdirs(path, FsPermission.getDirDefault) - } - - override def open(path: Path): FSDataInputStream = { - fs.open(path) - } - - override def create(path: Path): FSDataOutputStream = { - fs.create(path, false) - } - - override def exists(path: Path): Boolean = { - fs.exists(path) - } - - override def delete(path: Path): Unit = { - try { - fs.delete(path, true) - } catch { - case e: FileNotFoundException => - // ignore if file has already been deleted - } - } - } - /** * Verify if batchIds are continuous and between `startId` and `endId`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 3f5002a4e6937..df722b953228b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming.state -import java.io.{DataInputStream, DataOutputStream, FileNotFoundException, IOException} +import java.io._ import java.nio.channels.ClosedChannelException import java.util.Locale @@ -27,13 +27,16 @@ import scala.util.Random import scala.util.control.NonFatal import com.google.common.io.ByteStreams +import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.io.LZ4CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream import org.apache.spark.sql.types.StructType import org.apache.spark.util.{SizeEstimator, Utils} @@ -87,10 +90,10 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit case object ABORTED extends STATE private val newVersion = version + 1 - private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") - private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) @volatile private var state: STATE = UPDATING - @volatile private var finalDeltaFile: Path = null + private val finalDeltaFile: Path = deltaFile(newVersion) + private lazy val deltaFileStream = fm.createAtomic(finalDeltaFile, overwriteIfPossible = true) + private lazy val compressedStream = compressStream(deltaFileStream) override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId @@ -103,14 +106,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val keyCopy = key.copy() val valueCopy = value.copy() mapToUpdate.put(keyCopy, valueCopy) - writeUpdateToDeltaFile(tempDeltaFileStream, keyCopy, valueCopy) + writeUpdateToDeltaFile(compressedStream, keyCopy, valueCopy) } override def remove(key: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") val prevValue = mapToUpdate.remove(key) if (prevValue != null) { - writeRemoveToDeltaFile(tempDeltaFileStream, key) + writeRemoveToDeltaFile(compressedStream, key) } } @@ -126,8 +129,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit verify(state == UPDATING, "Cannot commit after already committed or aborted") try { - finalizeDeltaFile(tempDeltaFileStream) - finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) + commitUpdates(newVersion, mapToUpdate, compressedStream) state = COMMITTED logInfo(s"Committed version $newVersion for $this to file $finalDeltaFile") newVersion @@ -140,23 +142,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Abort all the updates made on this store. This store will not be usable any more. */ override def abort(): Unit = { - verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed") - try { + // This if statement is to ensure that files are deleted only if there are changes to the + // StateStore. We have two StateStores for each task, one which is used only for reading, and + // the other used for read+write. We don't want the read-only to delete state files. + if (state == UPDATING) { + state = ABORTED + cancelDeltaFile(compressedStream, deltaFileStream) + } else { state = ABORTED - if (tempDeltaFileStream != null) { - tempDeltaFileStream.close() - } - if (tempDeltaFile != null) { - fs.delete(tempDeltaFile, true) - } - } catch { - case c: ClosedChannelException => - // This can happen when underlying file output stream has been closed before the - // compression stream. - logDebug(s"Error aborting version $newVersion into $this", c) - - case e: Exception => - logWarning(s"Error aborting version $newVersion into $this", e) } logInfo(s"Aborted version $newVersion for $this") } @@ -212,7 +205,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit this.valueSchema = valueSchema this.storeConf = storeConf this.hadoopConf = hadoopConf - fs.mkdirs(baseDir) + fm.mkdirs(baseDir) } override def stateStoreId: StateStoreId = stateStoreId_ @@ -251,31 +244,15 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit private lazy val loadedMaps = new mutable.HashMap[Long, MapType] private lazy val baseDir = stateStoreId.storeCheckpointLocation() - private lazy val fs = baseDir.getFileSystem(hadoopConf) + private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) - /** Commit a set of updates to the store with the given new version */ - private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = { + private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = { synchronized { - val finalDeltaFile = deltaFile(newVersion) - - // scalastyle:off - // Renaming a file atop an existing one fails on HDFS - // (http://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html). - // Hence we should either skip the rename step or delete the target file. Because deleting the - // target file will break speculation, skipping the rename step is the only choice. It's still - // semantically correct because Structured Streaming requires rerunning a batch should - // generate the same output. (SPARK-19677) - // scalastyle:on - if (fs.exists(finalDeltaFile)) { - fs.delete(tempDeltaFile, true) - } else if (!fs.rename(tempDeltaFile, finalDeltaFile)) { - throw new IOException(s"Failed to rename $tempDeltaFile to $finalDeltaFile") - } + finalizeDeltaFile(output) loadedMaps.put(newVersion, map) - finalDeltaFile } } @@ -365,7 +342,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val fileToRead = deltaFile(version) var input: DataInputStream = null val sourceStream = try { - fs.open(fileToRead) + fm.open(fileToRead) } catch { case f: FileNotFoundException => throw new IllegalStateException( @@ -412,12 +389,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } private def writeSnapshotFile(version: Long, map: MapType): Unit = { - val fileToWrite = snapshotFile(version) - val tempFile = - new Path(fileToWrite.getParent, s"${fileToWrite.getName}.temp-${Random.nextLong}") + val targetFile = snapshotFile(version) + var rawOutput: CancellableFSDataOutputStream = null var output: DataOutputStream = null - Utils.tryWithSafeFinally { - output = compressStream(fs.create(tempFile, false)) + try { + rawOutput = fm.createAtomic(targetFile, overwriteIfPossible = true) + output = compressStream(rawOutput) val iter = map.entrySet().iterator() while(iter.hasNext) { val entry = iter.next() @@ -429,16 +406,34 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit output.write(valueBytes) } output.writeInt(-1) - } { - if (output != null) output.close() + output.close() + } catch { + case e: Throwable => + cancelDeltaFile(compressedStream = output, rawStream = rawOutput) + throw e } - if (fs.exists(fileToWrite)) { - // Skip rename if the file is alreayd created. - fs.delete(tempFile, true) - } else if (!fs.rename(tempFile, fileToWrite)) { - throw new IOException(s"Failed to rename $tempFile to $fileToWrite") + logInfo(s"Written snapshot file for version $version of $this at $targetFile") + } + + /** + * Try to cancel the underlying stream and safely close the compressed stream. + * + * @param compressedStream the compressed stream. + * @param rawStream the underlying stream which needs to be cancelled. + */ + private def cancelDeltaFile( + compressedStream: DataOutputStream, + rawStream: CancellableFSDataOutputStream): Unit = { + try { + if (rawStream != null) rawStream.cancel() + IOUtils.closeQuietly(compressedStream) + } catch { + case e: FSError if e.getCause.isInstanceOf[IOException] => + // Closing the compressedStream causes the stream to write/flush flush data into the + // rawStream. Since the rawStream is already closed, there may be errors. + // Usually its an IOException. However, Hadoop's RawLocalFileSystem wraps + // IOException into FSError. } - logInfo(s"Written snapshot file for version $version of $this at $fileToWrite") } private def readSnapshotFile(version: Long): Option[MapType] = { @@ -447,7 +442,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit var input: DataInputStream = null try { - input = decompressStream(fs.open(fileToRead)) + input = decompressStream(fm.open(fileToRead)) var eof = false while (!eof) { @@ -508,7 +503,6 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit case None => // The last map is not loaded, probably some other instance is in charge } - } } catch { case NonFatal(e) => @@ -534,7 +528,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } val filesToDelete = files.filter(_.version < earliestFileToRetain.version) filesToDelete.foreach { f => - fs.delete(f.path, true) + fm.delete(f.path) } logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " + filesToDelete.mkString(", ")) @@ -576,7 +570,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Fetch all the files that back the store */ private def fetchFiles(): Seq[StoreFile] = { val files: Seq[FileStatus] = try { - fs.listStatus(baseDir) + fm.list(baseDir) } catch { case _: java.io.FileNotFoundException => Seq.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index d1d9f95cb0977..7eb68c21569ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -459,7 +459,6 @@ object StateStore extends Logging { private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { - logInfo("Env is not null") val isDriver = env.executorId == SparkContext.DRIVER_IDENTIFIER || env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER @@ -467,13 +466,12 @@ object StateStore extends Logging { // as SparkContext + SparkEnv may have been restarted. Hence, when running in driver, // always recreate the reference. if (isDriver || _coordRef == null) { - logInfo("Getting StateStoreCoordinatorRef") + logDebug("Getting StateStoreCoordinatorRef") _coordRef = StateStoreCoordinatorRef.forExecutor(env) } logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") Some(_coordRef) } else { - logInfo("Env is null") _coordRef = null None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala new file mode 100644 index 0000000000000..fe59cb25d5005 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io._ +import java.net.URI + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils + +abstract class CheckpointFileManagerTests extends SparkFunSuite { + + def createManager(path: Path): CheckpointFileManager + + test("mkdirs, list, createAtomic, open, delete, exists") { + withTempPath { p => + val basePath = new Path(p.getAbsolutePath) + val fm = createManager(basePath) + // Mkdirs + val dir = new Path(s"$basePath/dir/subdir/subsubdir") + assert(!fm.exists(dir)) + fm.mkdirs(dir) + assert(fm.exists(dir)) + fm.mkdirs(dir) + + // List + val acceptAllFilter = new PathFilter { + override def accept(path: Path): Boolean = true + } + val rejectAllFilter = new PathFilter { + override def accept(path: Path): Boolean = false + } + assert(fm.list(basePath, acceptAllFilter).exists(_.getPath.getName == "dir")) + assert(fm.list(basePath, rejectAllFilter).length === 0) + + // Create atomic without overwrite + var path = new Path(s"$dir/file") + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = false).cancel() + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = false).close() + assert(fm.exists(path)) + quietly { + intercept[IOException] { + // should throw exception since file exists and overwrite is false + fm.createAtomic(path, overwriteIfPossible = false).close() + } + } + + // Create atomic with overwrite if possible + path = new Path(s"$dir/file2") + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).cancel() + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).close() + assert(fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).close() // should not throw exception + + // Open and delete + fm.open(path).close() + fm.delete(path) + assert(!fm.exists(path)) + intercept[IOException] { + fm.open(path) + } + fm.delete(path) // should not throw exception + } + } + + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } +} + +class CheckpointFileManagerSuite extends SparkFunSuite with SharedSparkSession { + + test("CheckpointFileManager.create() should pick up user-specified class from conf") { + withSQLConf( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key -> + classOf[CreateAtomicTestManager].getName) { + val fileManager = + CheckpointFileManager.create(new Path("/"), spark.sessionState.newHadoopConf) + assert(fileManager.isInstanceOf[CreateAtomicTestManager]) + } + } + + test("CheckpointFileManager.create() should fallback from FileContext to FileSystem") { + import CheckpointFileManagerSuiteFileSystem.scheme + spark.conf.set(s"fs.$scheme.impl", classOf[CheckpointFileManagerSuiteFileSystem].getName) + quietly { + withTempDir { temp => + val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") + assert(metadataLog.add(0, "batch0")) + assert(metadataLog.getLatest() === Some(0 -> "batch0")) + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) + + val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") + assert(metadataLog2.get(0) === Some("batch0")) + assert(metadataLog2.getLatest() === Some(0 -> "batch0")) + assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) + } + } + } +} + +class FileContextBasedCheckpointFileManagerSuite extends CheckpointFileManagerTests { + override def createManager(path: Path): CheckpointFileManager = { + new FileContextBasedCheckpointFileManager(path, new Configuration()) + } +} + +class FileSystemBasedCheckpointFileManagerSuite extends CheckpointFileManagerTests { + override def createManager(path: Path): CheckpointFileManager = { + new FileSystemBasedCheckpointFileManager(path, new Configuration()) + } +} + + +/** A fake implementation to test different characteristics of CheckpointFileManager interface */ +class CreateAtomicTestManager(path: Path, hadoopConf: Configuration) + extends FileSystemBasedCheckpointFileManager(path, hadoopConf) { + + import CheckpointFileManager._ + + override def createAtomic(path: Path, overwrite: Boolean): CancellableFSDataOutputStream = { + if (CreateAtomicTestManager.shouldFailInCreateAtomic) { + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + } + val originalOut = super.createAtomic(path, overwrite) + + new CancellableFSDataOutputStream(originalOut) { + override def close(): Unit = { + if (CreateAtomicTestManager.shouldFailInCreateAtomic) { + throw new IOException("Copy failed intentionally") + } + super.close() + } + + override def cancel(): Unit = { + CreateAtomicTestManager.cancelCalledInCreateAtomic = true + originalOut.cancel() + } + } + } +} + +object CreateAtomicTestManager { + @volatile var shouldFailInCreateAtomic = false + @volatile var cancelCalledInCreateAtomic = false +} + + +/** + * CheckpointFileManagerSuiteFileSystem to test fallback of the CheckpointFileManager + * from FileContext to FileSystem API. + */ +private class CheckpointFileManagerSuiteFileSystem extends RawLocalFileSystem { + import CheckpointFileManagerSuiteFileSystem.scheme + + override def getUri: URI = { + URI.create(s"$scheme:///") + } +} + +private object CheckpointFileManagerSuiteFileSystem { + val scheme = s"CheckpointFileManagerSuiteFileSystem${math.abs(Random.nextInt)}" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 12eaf63415081..ec961a9ecb592 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -22,15 +22,10 @@ import java.nio.charset.StandardCharsets._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.streaming.FakeFileSystem._ import org.apache.spark.sql.test.SharedSQLContext class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext { - /** To avoid caching of FS objects */ - override protected def sparkConf = - super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") - import CompactibleFileStreamLog._ /** -- testing of `object CompactibleFileStreamLog` begins -- */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 4677769c12a35..9268306ce4275 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -17,46 +17,22 @@ package org.apache.spark.sql.execution.streaming -import java.io.{File, FileNotFoundException, IOException} -import java.net.URI +import java.io.File import java.util.ConcurrentModificationException import scala.language.implicitConversions -import scala.util.Random -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs._ import org.scalatest.concurrent.Waiters._ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.streaming.FakeFileSystem._ -import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.{FileContextManager, FileManager, FileSystemManager} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.UninterruptibleThread class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { - /** To avoid caching of FS objects */ - override protected def sparkConf = - super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") - private implicit def toOption[A](a: A): Option[A] = Option(a) - test("FileManager: FileContextManager") { - withTempDir { temp => - val path = new Path(temp.getAbsolutePath) - testFileManager(path, new FileContextManager(path, new Configuration)) - } - } - - test("FileManager: FileSystemManager") { - withTempDir { temp => - val path = new Path(temp.getAbsolutePath) - testFileManager(path, new FileSystemManager(path, new Configuration)) - } - } - test("HDFSMetadataLog: basic") { withTempDir { temp => val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir @@ -82,26 +58,6 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") { - spark.conf.set( - s"fs.$scheme.impl", - classOf[FakeFileSystem].getName) - withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") - assert(metadataLog.add(0, "batch0")) - assert(metadataLog.getLatest() === Some(0 -> "batch0")) - assert(metadataLog.get(0) === Some("batch0")) - assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) - - - val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") - assert(metadataLog2.get(0) === Some("batch0")) - assert(metadataLog2.getLatest() === Some(0 -> "batch0")) - assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) - - } - } - test("HDFSMetadataLog: purge") { withTempDir { temp => val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) @@ -121,7 +77,8 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { // There should be exactly one file, called "2", in the metadata directory. // This check also tests for regressions of SPARK-17475 - val allFiles = new File(metadataLog.metadataPath.toString).listFiles().toSeq + val allFiles = new File(metadataLog.metadataPath.toString).listFiles() + .filter(!_.getName.startsWith(".")).toSeq assert(allFiles.size == 1) assert(allFiles(0).getName() == "2") } @@ -172,7 +129,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - test("HDFSMetadataLog: metadata directory collision") { + testQuietly("HDFSMetadataLog: metadata directory collision") { withTempDir { temp => val waiter = new Waiter val maxBatchId = 100 @@ -206,60 +163,6 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - /** Basic test case for [[FileManager]] implementation. */ - private def testFileManager(basePath: Path, fm: FileManager): Unit = { - // Mkdirs - val dir = new Path(s"$basePath/dir/subdir/subsubdir") - assert(!fm.exists(dir)) - fm.mkdirs(dir) - assert(fm.exists(dir)) - fm.mkdirs(dir) - - // List - val acceptAllFilter = new PathFilter { - override def accept(path: Path): Boolean = true - } - val rejectAllFilter = new PathFilter { - override def accept(path: Path): Boolean = false - } - assert(fm.list(basePath, acceptAllFilter).exists(_.getPath.getName == "dir")) - assert(fm.list(basePath, rejectAllFilter).length === 0) - - // Create - val path = new Path(s"$dir/file") - assert(!fm.exists(path)) - fm.create(path).close() - assert(fm.exists(path)) - intercept[IOException] { - fm.create(path) - } - - // Open and delete - fm.open(path).close() - fm.delete(path) - assert(!fm.exists(path)) - intercept[IOException] { - fm.open(path) - } - fm.delete(path) // should not throw exception - - // Rename - val path1 = new Path(s"$dir/file1") - val path2 = new Path(s"$dir/file2") - fm.create(path1).close() - assert(fm.exists(path1)) - fm.rename(path1, path2) - intercept[FileNotFoundException] { - fm.rename(path1, path2) - } - val path3 = new Path(s"$dir/file3") - fm.create(path3).close() - assert(fm.exists(path3)) - intercept[FileAlreadyExistsException] { - fm.rename(path2, path3) - } - } - test("verifyBatchIds") { import HDFSMetadataLog.verifyBatchIds verifyBatchIds(Seq(1L, 2L, 3L), Some(1L), Some(3L)) @@ -277,14 +180,3 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { intercept[IllegalStateException](verifyBatchIds(Seq(1, 2, 4, 5), Some(1L), Some(5L))) } } - -/** FakeFileSystem to test fallback of the HDFSMetadataLog from FileContext to FileSystem API */ -class FakeFileSystem extends RawLocalFileSystem { - override def getUri: URI = { - URI.create(s"$scheme:///") - } -} - -object FakeFileSystem { - val scheme = s"HDFSMetadataLogSuite${math.abs(Random.nextInt)}" -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index c843b65020d8c..73f8705060402 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI import java.util.UUID -import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -28,17 +27,17 @@ import scala.util.Random import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} +import org.apache.hadoop.fs._ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.LocalSparkContext._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.count import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -138,7 +137,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(getData(provider, 19) === Set("a" -> 19)) } - test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { + testQuietly("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { val conf = new Configuration() conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) conf.set("fs.defaultFS", "fake:///") @@ -344,7 +343,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } - test("SPARK-18342: commit fails when rename fails") { + testQuietly("SPARK-18342: commit fails when rename fails") { import RenameReturnsFalseFileSystem._ val dir = scheme + "://" + newDir() val conf = new Configuration() @@ -366,7 +365,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] def numTempFiles: Int = { if (deltaFileDir.exists) { - deltaFileDir.listFiles.map(_.getName).count(n => n.contains("temp") && !n.startsWith(".")) + deltaFileDir.listFiles.map(_.getName).count(n => n.endsWith(".tmp")) } else 0 } @@ -471,6 +470,43 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } + test("error writing [version].delta cancels the output stream") { + + val hadoopConf = new Configuration() + hadoopConf.set( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key, + classOf[CreateAtomicTestManager].getName) + val remoteDir = Utils.createTempDir().getAbsolutePath + + val provider = newStoreProvider( + opId = Random.nextInt, partition = 0, dir = remoteDir, hadoopConf = hadoopConf) + + // Disable failure of output stream and generate versions + CreateAtomicTestManager.shouldFailInCreateAtomic = false + for (version <- 1 to 10) { + val store = provider.getStore(version - 1) + put(store, version.toString, version) // update "1" -> 1, "2" -> 2, ... + store.commit() + } + val version10Data = (1L to 10).map(_.toString).map(x => x -> x).toSet + + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + val store = provider.getStore(10) + // Fail commit for next version and verify that reloading resets the files + CreateAtomicTestManager.shouldFailInCreateAtomic = true + put(store, "11", 11) + val e = intercept[IllegalStateException] { quietly { store.commit() } } + assert(e.getCause.isInstanceOf[IOException]) + CreateAtomicTestManager.shouldFailInCreateAtomic = false + + // Abort commit for next version and verify that reloading resets the files + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + val store2 = provider.getStore(10) + put(store2, "11", 11) + store2.abort() + assert(CreateAtomicTestManager.cancelCalledInCreateAtomic) + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) } @@ -720,6 +756,14 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] * this provider */ def getData(storeProvider: ProviderClass, version: Int): Set[(String, Int)] + + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } } object StateStoreTestsHelper { From 73f28530d6f6dd8aba758ea818c456cf911a5f41 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 14 Apr 2018 08:59:04 +0800 Subject: [PATCH 621/774] [SPARK-23979][SQL] MultiAlias should not be a CodegenFallback ## What changes were proposed in this pull request? Just found `MultiAlias` is a `CodegenFallback`. It should not be as looks like `MultiAlias` won't be evaluated. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #21065 from viirya/multialias-without-codegenfallback. --- .../org/apache/spark/sql/catalyst/analysis/unresolved.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a65f58fa61ff4..71e23175168e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.parser.ParserUtils import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.trees.TreeNode @@ -335,7 +335,7 @@ case class UnresolvedRegex(regexPattern: String, table: Option[String], caseSens * @param names the names to be associated with each output of computing [[child]]. */ case class MultiAlias(child: Expression, names: Seq[String]) - extends UnaryExpression with NamedExpression with CodegenFallback { + extends UnaryExpression with NamedExpression with Unevaluable { override def name: String = throw new UnresolvedException(this, "name") From c0964935d614bf345535439bce01cbd0e60c86aa Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Mon, 16 Apr 2018 12:01:42 +0800 Subject: [PATCH 622/774] [SPARK-23956][YARN] Use effective RPC port in AM registration ## What changes were proposed in this pull request? We propose not to hard-code the RPC port in the AM registration. ## How was this patch tested? Tested application reports from a pseudo-distributed cluster ``` 18/04/10 14:56:21 INFO Client: client token: N/A diagnostics: N/A ApplicationMaster host: localhost ApplicationMaster RPC port: 58338 queue: default start time: 1523397373659 final status: UNDEFINED tracking URL: http://localhost:8088/proxy/application_1523370127531_0016/ ``` Author: Gera Shegalov Closes #21047 from gerashegalov/gera/am-to-rm-nmhost. --- .../scala/org/apache/spark/deploy/yarn/YarnRMClient.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index c1ae12aabb8cc..17234b120ae13 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -29,7 +29,6 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.util.Utils /** * Handles registering and unregistering the application with the YARN ResourceManager. @@ -71,7 +70,8 @@ private[spark] class YarnRMClient extends Logging { logInfo("Registering the ApplicationMaster") synchronized { - amClient.registerApplicationMaster(Utils.localHostName(), 0, trackingUrl) + amClient.registerApplicationMaster(driverRef.address.host, driverRef.address.port, + trackingUrl) registered = true } new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, From 69310220319163bac18c9ee69d7da6d92227253b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 15 Apr 2018 21:45:55 -0700 Subject: [PATCH 623/774] [SPARK-23917][SQL] Add array_max function ## What changes were proposed in this pull request? The PR adds the SQL function `array_max`. It takes an array as argument and returns the maximum value in it. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21024 from mgaido91/SPARK-23917. --- python/pyspark/sql/functions.py | 15 ++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 17 +++++ .../expressions/collectionOperations.scala | 68 ++++++++++++++++++- .../CollectionExpressionsSuite.scala | 10 +++ .../org/apache/spark/sql/functions.scala | 8 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 14 ++++ 8 files changed, 133 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1b192680f0795..f3492ae42639c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2080,6 +2080,21 @@ def size(col): return Column(sc._jvm.functions.size(_to_java_column(col))) +@since(2.4) +def array_max(col): + """ + Collection function: returns the maximum value of the array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_max(df.data).alias('max')).collect() + [Row(max=3), Row(max=10)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_max(_to_java_column(col))) + + @since(1.5) def sort_array(col, asc=True): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 131b958239e41..05bfa2dd45340 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -409,6 +409,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayMax]("array_max"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9212c3de1f814..942dfd4292610 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -674,11 +674,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && (${ev.isNull} || - | ${ctx.genGreater(dataType, eval.value, ev.value)})) { - | ${ev.isNull} = false; - | ${ev.value} = ${eval.value}; - |} + |${ctx.reassignIfGreater(dataType, ev, eval)} """.stripMargin ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 0abfc9fa4c465..c86c5beded9d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -699,6 +699,23 @@ class CodegenContext { case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code for updating `partialResult` if `item` is greater than it. + * + * @param dataType data type of the expressions + * @param partialResult `ExprCode` representing the partial result which has to be updated + * @param item `ExprCode` representing the new expression to evaluate for the result + */ + def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { + s""" + |if (!${item.isNull} && (${partialResult.isNull} || + | ${genGreater(dataType, item.value, partialResult.value)})) { + | ${partialResult.isNull} = false; + | ${partialResult.value} = ${item.value}; + |} + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 91188da8b0bd3..e2614a179aad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -21,7 +21,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ /** @@ -287,3 +287,69 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Returns the maximum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 20 + """, since = "2.4.0") +case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode("", + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + s""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfGreater(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var max: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (max == null || ordering.gt(item, max))) { + max = item + } + ) + max + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_max" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 020687e4b3a27..a2384019533b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -105,4 +105,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Array max") { + checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10) + checkEvaluation( + ArrayMax(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "abc") + checkEvaluation(ArrayMax(Literal.create(Seq(null), ArrayType(LongType))), null) + checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null) + checkEvaluation( + ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c658f25ced053..daf407926dca4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3300,6 +3300,14 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns the maximum value in the array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 50e475984f458..5d5d92c84df6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_max function") { + val df = Seq( + Seq[Option[Int]](Some(1), Some(3), Some(2)), + Seq.empty[Option[Int]], + Seq[Option[Int]](None), + Seq[Option[Int]](None, Some(1), Some(-100)) + ).toDF("a") + + val answer = Seq(Row(3), Row(null), Row(null), Row(1)) + + checkAnswer(df.select(array_max(df("a"))), answer) + checkAnswer(df.selectExpr("array_max(a)"), answer) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 083cf223569b7896e35ff1d53a73498a4971b28d Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 16 Apr 2018 23:50:50 +0800 Subject: [PATCH 624/774] [SPARK-21033][CORE][FOLLOW-UP] Update Spillable ## What changes were proposed in this pull request? Update ```scala SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) ``` to ```scala SparkEnv.get.conf.get(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) ``` because of `SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD`'s default value is `Integer.MAX_VALUE`: https://github.com/apache/spark/blob/c99fc9ad9b600095baba003053dbf84304ca392b/core/src/main/scala/org/apache/spark/internal/config/package.scala#L503-L511 ## How was this patch tested? N/A Author: Yuming Wang Closes #21077 from wangyum/SPARK-21033. --- .../org/apache/spark/util/collection/Spillable.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 8183f825592c0..81457b53cd814 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -19,6 +19,7 @@ package org.apache.spark.util.collection import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager} /** @@ -41,7 +42,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) protected def forceSpill(): Boolean // Number of elements read from input since last spill - protected def elementsRead: Long = _elementsRead + protected def elementsRead: Int = _elementsRead // Called by subclasses every time a record is read // It's used for checking spilling frequency @@ -54,15 +55,15 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) // Force this collection to spill when there are this many elements in memory // For testing only - private[this] val numElementsForceSpillThreshold: Long = - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) + private[this] val numElementsForceSpillThreshold: Int = + SparkEnv.get.conf.get(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) // Threshold for this collection's size in bytes before we start tracking its memory usage // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 @volatile private[this] var myMemoryThreshold = initialMemoryThreshold // Number of elements read from input since last spill - private[this] var _elementsRead = 0L + private[this] var _elementsRead = 0 // Number of bytes spilled in total @volatile private[this] var _memoryBytesSpilled = 0L From 5003736ad60c3231bb18264c9561646c08379170 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 16 Apr 2018 11:27:30 -0500 Subject: [PATCH 625/774] [SPARK-9312][ML] Add RawPrediction, numClasses, and numFeatures for OneVsRestModel add RawPrediction as output column add numClasses and numFeatures to OneVsRestModel ## What changes were proposed in this pull request? - Add two val numClasses and numFeatures in OneVsRestModel so that we can inherit from Classifier in the future - Add rawPrediction output column in transform, the prediction label in calculated by the rawPrediciton like raw2prediction ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21044 from ludatabricks/SPARK-9312. --- .../spark/ml/classification/OneVsRest.scala | 56 +++++++++++++++---- .../ml/classification/OneVsRestSuite.scala | 7 ++- 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index f04fde2cbbca1..5348d882cfd67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -32,7 +32,7 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol} import org.apache.spark.ml.util._ @@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait { /** * Params for [[OneVsRest]]. */ -private[ml] trait OneVsRestParams extends PredictorParams +private[ml] trait OneVsRestParams extends ClassifierParams with ClassifierTypeTrait with HasWeightCol { /** @@ -138,6 +138,14 @@ final class OneVsRestModel private[ml] ( @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams with MLWritable { + require(models.nonEmpty, "OneVsRestModel requires at least one model for one class") + + @Since("2.4.0") + val numClasses: Int = models.length + + @Since("2.4.0") + val numFeatures: Int = models.head.numFeatures + /** @group setParam */ @Since("2.1.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) @@ -146,6 +154,10 @@ final class OneVsRestModel private[ml] ( @Since("2.1.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("2.4.0") + def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) @@ -181,6 +193,7 @@ final class OneVsRestModel private[ml] ( val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => predictions + ((index, prediction(1))) } + model.setFeaturesCol($(featuresCol)) val transformedDataset = model.transform(df).select(columns: _*) val updatedDataset = transformedDataset @@ -195,15 +208,34 @@ final class OneVsRestModel private[ml] ( newDataset.unpersist() } - // output the index of the classifier with highest confidence as prediction - val labelUDF = udf { (predictions: Map[Int, Double]) => - predictions.maxBy(_._2)._1.toDouble - } + if (getRawPredictionCol != "") { + val numClass = models.length - // output label and label metadata as prediction - aggregatedDataset - .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata) - .drop(accColName) + // output the RawPrediction as vector + val rawPredictionUDF = udf { (predictions: Map[Int, Double]) => + val predArray = Array.fill[Double](numClass)(0.0) + predictions.foreach { case (idx, value) => predArray(idx) = value } + Vectors.dense(predArray) + } + + // output the index of the classifier with highest confidence as prediction + val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble } + + // output confidence as raw prediction, label and label metadata as prediction + aggregatedDataset + .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName))) + .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata) + .drop(accColName) + } else { + // output the index of the classifier with highest confidence as prediction + val labelUDF = udf { (predictions: Map[Int, Double]) => + predictions.maxBy(_._2)._1.toDouble + } + // output label and label metadata as prediction + aggregatedDataset + .withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata) + .drop(accColName) + } } @Since("1.4.1") @@ -297,6 +329,10 @@ final class OneVsRest @Since("1.4.0") ( @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("2.4.0") + def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) + /** * The implementation of parallel one vs. rest runs the classification for * each class in a separate threads. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 11e88367108b4..2c3417c7e4028 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -72,11 +72,12 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { .setClassifier(new LogisticRegression) assert(ova.getLabelCol === "label") assert(ova.getPredictionCol === "prediction") + assert(ova.getRawPredictionCol === "rawPrediction") val ovaModel = ova.fit(dataset) MLTestingUtils.checkCopyAndUids(ova, ovaModel) - assert(ovaModel.models.length === numClasses) + assert(ovaModel.numClasses === numClasses) val transformedDataset = ovaModel.transform(dataset) // check for label metadata in prediction col @@ -179,6 +180,7 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea")) ovaModel.setFeaturesCol("fea") ovaModel.setPredictionCol("pred") + ovaModel.setRawPredictionCol("") val transformedDataset = ovaModel.transform(dataset2) val outputFields = transformedDataset.schema.fieldNames.toSet assert(outputFields === Set("y", "fea", "pred")) @@ -190,7 +192,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { val ovr = new OneVsRest() .setClassifier(logReg) val output = ovr.fit(dataset).transform(dataset) - assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) + assert(output.schema.fieldNames.toSet + === Set("label", "features", "prediction", "rawPrediction")) } test("SPARK-21306: OneVsRest should support setWeightCol") { From 04614820e103feeae91299dc90dba1dd628fd485 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 16 Apr 2018 11:31:24 -0500 Subject: [PATCH 626/774] [SPARK-21088][ML] CrossValidator, TrainValidationSplit support collect all models when fitting: Python API ## What changes were proposed in this pull request? Add python API for collecting sub-models during CrossValidator/TrainValidationSplit fitting. ## How was this patch tested? UT added. Author: WeichenXu Closes #19627 from WeichenXu123/expose-model-list-py. --- .../spark/ml/tuning/CrossValidator.scala | 11 ++ .../ml/tuning/TrainValidationSplit.scala | 11 ++ .../ml/param/_shared_params_code_gen.py | 5 + python/pyspark/ml/param/shared.py | 24 ++++ python/pyspark/ml/tests.py | 78 +++++++++++++ python/pyspark/ml/tuning.py | 107 +++++++++++++----- python/pyspark/ml/util.py | 4 + 7 files changed, 211 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index a0b507d2e718c..c2826dcc08634 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -270,6 +270,17 @@ class CrossValidatorModel private[ml] ( this } + // A Python-friendly auxiliary method + private[tuning] def setSubModels(subModels: JList[JList[Model[_]]]) + : CrossValidatorModel = { + _subModels = if (subModels != null) { + Some(subModels.asScala.toArray.map(_.asScala.toArray)) + } else { + None + } + this + } + /** * @return submodels represented in two dimension array. The index of outer array is the * fold index, and the index of inner array corresponds to the ordering of diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 88ff0dfd75e96..8d1b9a8ddab59 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -262,6 +262,17 @@ class TrainValidationSplitModel private[ml] ( this } + // A Python-friendly auxiliary method + private[tuning] def setSubModels(subModels: JList[Model[_]]) + : TrainValidationSplitModel = { + _subModels = if (subModels != null) { + Some(subModels.asScala.toArray) + } else { + None + } + this + } + /** * @return submodels represented in array. The index of array corresponds to the ordering of * estimatorParamMaps diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index db951d81de1e7..6e9e0a34cdfde 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -157,6 +157,11 @@ def get$Name(self): "TypeConverters.toInt"), ("parallelism", "the number of threads to use when running parallel algorithms (>= 1).", "1", "TypeConverters.toInt"), + ("collectSubModels", "Param for whether to collect a list of sub-models trained during " + + "tuning. If set to false, then only the single best sub-model will be available after " + + "fitting. If set to true, then all sub-models will be available. Warning: For large " + + "models, collecting all sub-models can cause OOMs on the Spark driver.", + "False", "TypeConverters.toBoolean"), ("loss", "the loss function to be optimized.", None, "TypeConverters.toString")] code = [] diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 474c38764e5a1..08408ee8fbfcc 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -655,6 +655,30 @@ def getParallelism(self): return self.getOrDefault(self.parallelism) +class HasCollectSubModels(Params): + """ + Mixin for param collectSubModels: Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver. + """ + + collectSubModels = Param(Params._dummy(), "collectSubModels", "Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.", typeConverter=TypeConverters.toBoolean) + + def __init__(self): + super(HasCollectSubModels, self).__init__() + self._setDefault(collectSubModels=False) + + def setCollectSubModels(self, value): + """ + Sets the value of :py:attr:`collectSubModels`. + """ + return self._set(collectSubModels=value) + + def getCollectSubModels(self): + """ + Gets the value of collectSubModels or its default value. + """ + return self.getOrDefault(self.collectSubModels) + + class HasLoss(Params): """ Mixin for param loss: the loss function to be optimized. diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 4ce54547eab09..2ec0be60e9fa9 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1018,6 +1018,50 @@ def test_parallel_evaluation(self): cvParallelModel = cv.fit(dataset) self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + + numFolds = 3 + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + numFolds=numFolds, collectSubModels=True) + + def checkSubModels(subModels): + self.assertEqual(len(subModels), numFolds) + for i in range(numFolds): + self.assertEqual(len(subModels[i]), len(grid)) + + cvModel = cv.fit(dataset) + checkSubModels(cvModel.subModels) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testCrossValidatorSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + cvModel.save(savingPathWithSubModels) + cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) + checkSubModels(cvModel3.subModels) + cvModel4 = cvModel3.copy() + checkSubModels(cvModel4.subModels) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) + self.assertEqual(cvModel2.subModels, None) + + for i in range(numFolds): + for j in range(len(grid)): + self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid) + def test_save_load_nested_estimator(self): temp_path = tempfile.mkdtemp() dataset = self.spark.createDataFrame( @@ -1186,6 +1230,40 @@ def test_parallel_evaluation(self): tvsParallelModel = tvs.fit(dataset) self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + collectSubModels=True) + tvsModel = tvs.fit(dataset) + self.assertEqual(len(tvsModel.subModels), len(grid)) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testTrainValidationSplitSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + tvsModel.save(savingPathWithSubModels) + tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) + self.assertEqual(len(tvsModel3.subModels), len(grid)) + tvsModel4 = tvsModel3.copy() + self.assertEqual(len(tvsModel4.subModels), len(grid)) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) + self.assertEqual(tvsModel2.subModels, None) + + for i in range(len(grid)): + self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid) + def test_save_load_nested_estimator(self): # This tests saving and loading the trained model only. # Save/load for TrainValidationSplit will be added later: SPARK-13786 diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 545e24ca05aa5..0c8029f293cfe 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -24,7 +24,7 @@ from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java from pyspark.ml.param import Params, Param, TypeConverters -from pyspark.ml.param.shared import HasParallelism, HasSeed +from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand @@ -33,7 +33,7 @@ 'TrainValidationSplitModel'] -def _parallelFitTasks(est, train, eva, validation, epm): +def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel): """ Creates a list of callables which can be called from different threads to fit and evaluate an estimator in parallel. Each callable returns an `(index, metric)` pair. @@ -43,14 +43,15 @@ def _parallelFitTasks(est, train, eva, validation, epm): :param eva: Evaluator, used to compute `metric` :param validation: DataFrame, validation data set, used for evaluation. :param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation. - :return: (int, float), an index into `epm` and the associated metric value. + :param collectSubModel: Whether to collect sub model. + :return: (int, float, subModel), an index into `epm` and the associated metric value. """ modelIter = est.fitMultiple(train, epm) def singleTask(): index, model = next(modelIter) metric = eva.evaluate(model.transform(validation, epm[index])) - return index, metric + return index, metric, model if collectSubModel else None return [singleTask] * len(epm) @@ -194,7 +195,8 @@ def _to_java_impl(self): return java_estimator, java_epms, java_evaluator -class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): +class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, + MLReadable, MLWritable): """ K-fold cross validation performs model selection by splitting the dataset into a set of @@ -233,10 +235,10 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None, parallelism=1) + seed=None, parallelism=1, collectSubModels=False) """ super(CrossValidator, self).__init__() self._setDefault(numFolds=3, parallelism=1) @@ -246,10 +248,10 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF @keyword_only @since("1.4.0") def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): Sets params for cross validator. """ kwargs = self._input_kwargs @@ -282,6 +284,10 @@ def _fit(self, dataset): metrics = [0.0] * numModels pool = ThreadPool(processes=min(self.getParallelism(), numModels)) + subModels = None + collectSubModelsParam = self.getCollectSubModels() + if collectSubModelsParam: + subModels = [[None for j in range(numModels)] for i in range(nFolds)] for i in range(nFolds): validateLB = i * h @@ -290,9 +296,12 @@ def _fit(self, dataset): validation = df.filter(condition).cache() train = df.filter(~condition).cache() - tasks = _parallelFitTasks(est, train, eva, validation, epm) - for j, metric in pool.imap_unordered(lambda f: f(), tasks): + tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam) + for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks): metrics[j] += (metric / nFolds) + if collectSubModelsParam: + subModels[i][j] = subModel + validation.unpersist() train.unpersist() @@ -301,7 +310,7 @@ def _fit(self, dataset): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(CrossValidatorModel(bestModel, metrics)) + return self._copyValues(CrossValidatorModel(bestModel, metrics, subModels)) @since("1.4.0") def copy(self, extra=None): @@ -345,9 +354,11 @@ def _from_java(cls, java_stage): numFolds = java_stage.getNumFolds() seed = java_stage.getSeed() parallelism = java_stage.getParallelism() + collectSubModels = java_stage.getCollectSubModels() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - numFolds=numFolds, seed=seed, parallelism=parallelism) + numFolds=numFolds, seed=seed, parallelism=parallelism, + collectSubModels=collectSubModels) py_stage._resetUid(java_stage.uid()) return py_stage @@ -367,6 +378,7 @@ def _to_java(self): _java_obj.setSeed(self.getSeed()) _java_obj.setNumFolds(self.getNumFolds()) _java_obj.setParallelism(self.getParallelism()) + _java_obj.setCollectSubModels(self.getCollectSubModels()) return _java_obj @@ -381,13 +393,15 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): .. versionadded:: 1.4.0 """ - def __init__(self, bestModel, avgMetrics=[]): + def __init__(self, bestModel, avgMetrics=[], subModels=None): super(CrossValidatorModel, self).__init__() #: best model from cross validation self.bestModel = bestModel #: Average cross-validation metrics for each paramMap in #: CrossValidator.estimatorParamMaps, in the corresponding order. self.avgMetrics = avgMetrics + #: sub model list from cross validation + self.subModels = subModels def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -399,6 +413,7 @@ def copy(self, extra=None): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + It does not copy the extra Params into the subModels. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance @@ -407,7 +422,8 @@ def copy(self, extra=None): extra = dict() bestModel = self.bestModel.copy(extra) avgMetrics = self.avgMetrics - return CrossValidatorModel(bestModel, avgMetrics) + subModels = self.subModels + return CrossValidatorModel(bestModel, avgMetrics, subModels) @since("2.3.0") def write(self): @@ -426,13 +442,17 @@ def _from_java(cls, java_stage): Given a Java CrossValidatorModel, create and return a Python wrapper of it. Used for ML persistence. """ - bestModel = JavaParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) py_stage = cls(bestModel=bestModel).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) + if java_stage.hasSubModels(): + py_stage.subModels = [[JavaParams._from_java(sub_model) + for sub_model in fold_sub_models] + for fold_sub_models in java_stage.subModels()] + py_stage._resetUid(java_stage.uid()) return py_stage @@ -454,10 +474,16 @@ def _to_java(self): _java_obj.set("evaluator", evaluator) _java_obj.set("estimator", estimator) _java_obj.set("estimatorParamMaps", epms) + + if self.subModels is not None: + java_sub_models = [[sub_model._to_java() for sub_model in fold_sub_models] + for fold_sub_models in self.subModels] + _java_obj.setSubModels(java_sub_models) return _java_obj -class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): +class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, + MLReadable, MLWritable): """ .. note:: Experimental @@ -492,10 +518,10 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - parallelism=1, seed=None) + parallelism=1, collectSubModels=False, seed=None) """ super(TrainValidationSplit, self).__init__() self._setDefault(trainRatio=0.75, parallelism=1) @@ -505,10 +531,10 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trai @since("2.0.0") @keyword_only def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): Sets params for the train validation split. """ kwargs = self._input_kwargs @@ -541,11 +567,19 @@ def _fit(self, dataset): validation = df.filter(condition).cache() train = df.filter(~condition).cache() - tasks = _parallelFitTasks(est, train, eva, validation, epm) + subModels = None + collectSubModelsParam = self.getCollectSubModels() + if collectSubModelsParam: + subModels = [None for i in range(numModels)] + + tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam) pool = ThreadPool(processes=min(self.getParallelism(), numModels)) metrics = [None] * numModels - for j, metric in pool.imap_unordered(lambda f: f(), tasks): + for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks): metrics[j] = metric + if collectSubModelsParam: + subModels[j] = subModel + train.unpersist() validation.unpersist() @@ -554,7 +588,7 @@ def _fit(self, dataset): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(TrainValidationSplitModel(bestModel, metrics)) + return self._copyValues(TrainValidationSplitModel(bestModel, metrics, subModels)) @since("2.0.0") def copy(self, extra=None): @@ -598,9 +632,11 @@ def _from_java(cls, java_stage): trainRatio = java_stage.getTrainRatio() seed = java_stage.getSeed() parallelism = java_stage.getParallelism() + collectSubModels = java_stage.getCollectSubModels() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - trainRatio=trainRatio, seed=seed, parallelism=parallelism) + trainRatio=trainRatio, seed=seed, parallelism=parallelism, + collectSubModels=collectSubModels) py_stage._resetUid(java_stage.uid()) return py_stage @@ -620,7 +656,7 @@ def _to_java(self): _java_obj.setTrainRatio(self.getTrainRatio()) _java_obj.setSeed(self.getSeed()) _java_obj.setParallelism(self.getParallelism()) - + _java_obj.setCollectSubModels(self.getCollectSubModels()) return _java_obj @@ -633,12 +669,14 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): .. versionadded:: 2.0.0 """ - def __init__(self, bestModel, validationMetrics=[]): + def __init__(self, bestModel, validationMetrics=[], subModels=None): super(TrainValidationSplitModel, self).__init__() - #: best model from cross validation + #: best model from train validation split self.bestModel = bestModel #: evaluated validation metrics self.validationMetrics = validationMetrics + #: sub models from train validation split + self.subModels = subModels def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -651,6 +689,7 @@ def copy(self, extra=None): creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. And, this creates a shallow copy of the validationMetrics. + It does not copy the extra Params into the subModels. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance @@ -659,7 +698,8 @@ def copy(self, extra=None): extra = dict() bestModel = self.bestModel.copy(extra) validationMetrics = list(self.validationMetrics) - return TrainValidationSplitModel(bestModel, validationMetrics) + subModels = self.subModels + return TrainValidationSplitModel(bestModel, validationMetrics, subModels) @since("2.3.0") def write(self): @@ -687,6 +727,10 @@ def _from_java(cls, java_stage): py_stage = cls(bestModel=bestModel).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) + if java_stage.hasSubModels(): + py_stage.subModels = [JavaParams._from_java(sub_model) + for sub_model in java_stage.subModels()] + py_stage._resetUid(java_stage.uid()) return py_stage @@ -708,6 +752,11 @@ def _to_java(self): _java_obj.set("evaluator", evaluator) _java_obj.set("estimator", estimator) _java_obj.set("estimatorParamMaps", epms) + + if self.subModels is not None: + java_sub_models = [sub_model._to_java() for sub_model in self.subModels] + _java_obj.setSubModels(java_sub_models) + return _java_obj diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index c3c47bd79459a..a486c6a3fdeb5 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -169,6 +169,10 @@ def overwrite(self): self._jwrite.overwrite() return self + def option(self, key, value): + self._jwrite.option(key, value) + return self + def context(self, sqlContext): """ Sets the SQL context to use for saving. From fd990a908b94d1c90c4ca604604f35a13b453d44 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 Apr 2018 22:45:57 +0200 Subject: [PATCH 627/774] [SPARK-23873][SQL] Use accessors in interpreted LambdaVariable ## What changes were proposed in this pull request? Currently, interpreted execution of `LambdaVariable` just uses `InternalRow.get` to access element. We should use specified accessors if possible. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20981 from viirya/SPARK-23873. --- .../spark/sql/catalyst/InternalRow.scala | 26 ++++++++++++- .../catalyst/expressions/BoundAttribute.scala | 22 ++--------- .../expressions/objects/objects.scala | 8 +++- .../expressions/ExpressionEvalHelper.scala | 4 +- .../expressions/ObjectExpressionsSuite.scala | 38 ++++++++++++++++++- 5 files changed, 75 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 29110640d64f2..274d75e680f03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} -import org.apache.spark.sql.types.{DataType, Decimal, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -119,4 +119,28 @@ object InternalRow { case v: MapData => v.copy() case _ => value } + + /** + * Returns an accessor for an `InternalRow` with given data type. The returned accessor + * actually takes a `SpecializedGetters` input because it can be generalized to other classes + * that implements `SpecializedGetters` (e.g., `ArrayData`) too. + */ + def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match { + case BooleanType => (input, ordinal) => input.getBoolean(ordinal) + case ByteType => (input, ordinal) => input.getByte(ordinal) + case ShortType => (input, ordinal) => input.getShort(ordinal) + case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal) + case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal) + case FloatType => (input, ordinal) => input.getFloat(ordinal) + case DoubleType => (input, ordinal) => input.getDouble(ordinal) + case StringType => (input, ordinal) => input.getUTF8String(ordinal) + case BinaryType => (input, ordinal) => input.getBinary(ordinal) + case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal) + case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale) + case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size) + case _: ArrayType => (input, ordinal) => input.getArray(ordinal) + case _: MapType => (input, ordinal) => input.getMap(ordinal) + case u: UserDefinedType[_] => getAccessor(u.sqlType) + case _ => (input, ordinal) => input.get(ordinal, dataType) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 5021a567592e0..4cc84b27d9eb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -33,28 +33,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]" + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): Any = { - if (input.isNullAt(ordinal)) { + if (nullable && input.isNullAt(ordinal)) { null } else { - dataType match { - case BooleanType => input.getBoolean(ordinal) - case ByteType => input.getByte(ordinal) - case ShortType => input.getShort(ordinal) - case IntegerType | DateType => input.getInt(ordinal) - case LongType | TimestampType => input.getLong(ordinal) - case FloatType => input.getFloat(ordinal) - case DoubleType => input.getDouble(ordinal) - case StringType => input.getUTF8String(ordinal) - case BinaryType => input.getBinary(ordinal) - case CalendarIntervalType => input.getInterval(ordinal) - case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale) - case t: StructType => input.getStruct(ordinal, t.size) - case _: ArrayType => input.getArray(ordinal) - case _: MapType => input.getMap(ordinal) - case _ => input.get(ordinal, dataType) - } + accessor(input, ordinal) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 50e90ca550807..77802e89e942b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -560,11 +560,17 @@ case class LambdaVariable( dataType: DataType, nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. override def eval(input: InternalRow): Any = { assert(input.numFields == 1, "The input row of interpreted LambdaVariable should have only 1 field.") - input.get(0, dataType) + if (nullable && input.isNullAt(0)) { + null + } else { + accessor(input, 0) + } } override def genCode(ctx: CodegenContext): ExprCode = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a5ecd1b68fac4..b4bf6d7107d7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -70,7 +70,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { * Check the equality between result of expression and expected value, it will handle * Array[Byte], Spread[Double], MapData and Row. */ - protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = { + protected def checkResult(result: Any, expected: Any, exprDataType: DataType): Boolean = { + val dataType = UserDefinedType.sqlType(exprDataType) + (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index b1bc67dfac1b5..b0188b0098def 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -21,13 +21,14 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.sql.Row +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util._ @@ -381,6 +382,39 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) } } + + test("LambdaVariable should support interpreted execution") { + def genSchema(dt: DataType): Seq[StructType] = { + Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil), + StructType(StructField("col_1", dt, nullable = true) :: Nil)) + } + + val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType, + CalendarIntervalType, new ExamplePointUDT()) + val arrayTypes = elementTypes.flatMap { elementType => + Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true)) + } + val mapTypes = elementTypes.flatMap { elementType => + Seq(MapType(elementType, elementType, false), MapType(elementType, elementType, true)) + } + val structTypes = elementTypes.flatMap { elementType => + Seq(StructType(StructField("col1", elementType, false) :: Nil), + StructType(StructField("col1", elementType, true) :: Nil)) + } + + val testTypes = elementTypes ++ arrayTypes ++ mapTypes ++ structTypes + val random = new Random(100) + testTypes.foreach { dt => + genSchema(dt).map { schema => + val row = RandomDataGenerator.randomRow(random, schema) + val rowConverter = RowEncoder(schema) + val internalRow = rowConverter.toRow(row) + val lambda = LambdaVariable("dummy", "dummuIsNull", schema(0).dataType, schema(0).nullable) + checkEvaluationWithoutCodegen(lambda, internalRow.get(0, schema(0).dataType), internalRow) + } + } + } } class TestBean extends Serializable { From 14844a62c025e7299029d7452b8c4003bc221ac8 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 17 Apr 2018 17:55:35 +0900 Subject: [PATCH 628/774] [SPARK-23918][SQL] Add array_min function ## What changes were proposed in this pull request? The PR adds the SQL function `array_min`. It takes an array as argument and returns the minimum value in it. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21025 from mgaido91/SPARK-23918. --- python/pyspark/sql/functions.py | 17 ++++- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 17 +++++ .../expressions/collectionOperations.scala | 64 +++++++++++++++++++ .../CollectionExpressionsSuite.scala | 10 +++ .../org/apache/spark/sql/functions.scala | 8 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 14 ++++ 8 files changed, 131 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f3492ae42639c..6ca22b610843d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2080,6 +2080,21 @@ def size(col): return Column(sc._jvm.functions.size(_to_java_column(col))) +@since(2.4) +def array_min(col): + """ + Collection function: returns the minimum value of the array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_min(df.data).alias('min')).collect() + [Row(min=1), Row(min=-1)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_min(_to_java_column(col))) + + @since(2.4) def array_max(col): """ @@ -2108,7 +2123,7 @@ def sort_array(col, asc=True): [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])] - """ + """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 05bfa2dd45340..4dd1ca509bf2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -409,6 +409,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), CreateStruct.registryEntry, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 942dfd4292610..d4e322d23b95b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -595,11 +595,7 @@ case class Least(children: Seq[Expression]) extends Expression { val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && (${ev.isNull} || - | ${ctx.genGreater(dataType, ev.value, eval.value)})) { - | ${ev.isNull} = false; - | ${ev.value} = ${eval.value}; - |} + |${ctx.reassignIfSmaller(dataType, ev, eval)} """.stripMargin ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c86c5beded9d0..d97611c98ac91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -699,6 +699,23 @@ class CodegenContext { case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code for updating `partialResult` if `item` is smaller than it. + * + * @param dataType data type of the expressions + * @param partialResult `ExprCode` representing the partial result which has to be updated + * @param item `ExprCode` representing the new expression to evaluate for the result + */ + def reassignIfSmaller(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { + s""" + |if (!${item.isNull} && (${partialResult.isNull} || + | ${genGreater(dataType, partialResult.value, item.value)})) { + | ${partialResult.isNull} = false; + | ${partialResult.value} = ${item.value}; + |} + """.stripMargin + } + /** * Generates code for updating `partialResult` if `item` is greater than it. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e2614a179aad8..7c87777eed47a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -288,6 +288,70 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Returns the minimum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the minimum value in the array. NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 1 + """, since = "2.4.0") +case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode("", + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + s""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfSmaller(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var min: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (min == null || ordering.lt(item, min))) { + min = item + } + ) + min + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_min" +} /** * Returns the maximum value in the array. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a2384019533b7..5a31e3a30edd6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -106,6 +106,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("Array Min") { + checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11) + checkEvaluation( + ArrayMin(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "") + checkEvaluation(ArrayMin(Literal.create(Seq(null), ArrayType(LongType))), null) + checkEvaluation(ArrayMin(Literal.create(null, ArrayType(StringType))), null) + checkEvaluation( + ArrayMin(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 0.1234) + } + test("Array max") { checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10) checkEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index daf407926dca4..642ac056bb809 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3300,6 +3300,14 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns the minimum value in the array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_min(e: Column): Column = withExpr { ArrayMin(e.expr) } + /** * Returns the maximum value in the array. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5d5d92c84df6d..636e86baedf6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_min function") { + val df = Seq( + Seq[Option[Int]](Some(1), Some(3), Some(2)), + Seq.empty[Option[Int]], + Seq[Option[Int]](None), + Seq[Option[Int]](None, Some(1), Some(-100)) + ).toDF("a") + + val answer = Seq(Row(1), Row(null), Row(null), Row(-100)) + + checkAnswer(df.select(array_min(df("a"))), answer) + checkAnswer(df.selectExpr("array_min(a)"), answer) + } + test("array_max function") { val df = Seq( Seq[Option[Int]](Some(1), Some(3), Some(2)), From 1cc66a072b7fd3bf140fa41596f6b18f8d1bd7b9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 17 Apr 2018 01:59:38 -0700 Subject: [PATCH 629/774] [SPARK-23687][SS] Add a memory source for continuous processing. ## What changes were proposed in this pull request? Add a memory source for continuous processing. Note that only one of the ContinuousSuite tests is migrated to minimize the diff here. I'll submit a second PR for SPARK-23688 to change the rest and get rid of waitForRateSourceTriggers. ## How was this patch tested? unit test Author: Jose Torres Closes #20828 from jose-torres/continuousMemory. --- .../continuous/ContinuousExecution.scala | 5 +- .../sql/execution/streaming/memory.scala | 59 +++-- .../sources/ContinuousMemoryStream.scala | 211 ++++++++++++++++++ .../spark/sql/streaming/StreamTest.scala | 4 +- .../continuous/ContinuousSuite.scala | 31 +-- 5 files changed, 266 insertions(+), 44 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 1758b3844bd62..951d694355ec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} +import org.apache.spark.sql.sources.v2 import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -317,8 +318,10 @@ class ContinuousExecution( synchronized { if (queryExecutionThread.isAlive) { commitLog.add(epoch) - val offset = offsetLog.get(epoch).get.offsets(0).get + val offset = + continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json) committedOffsets ++= Seq(continuousSources(0) -> offset) + continuousSources(0).commit(offset.asInstanceOf[v2.reader.streaming.Offset]) } else { return } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 352d4ce9fbcaa..628923d367ce7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -24,17 +24,19 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.streaming.{OutputMode, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -47,16 +49,43 @@ object MemoryStream { new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) } +/** + * A base class for memory stream implementations. Supports adding data and resetting. + */ +abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource { + protected val encoder = encoderFor[A] + protected val attributes = encoder.schema.toAttributes + + def toDS(): Dataset[A] = { + Dataset[A](sqlContext.sparkSession, logicalPlan) + } + + def toDF(): DataFrame = { + Dataset.ofRows(sqlContext.sparkSession, logicalPlan) + } + + def addData(data: A*): Offset = { + addData(data.toTraversable) + } + + def readSchema(): StructType = encoder.schema + + protected def logicalPlan: LogicalPlan + + def addData(data: TraversableOnce[A]): Offset +} + /** * A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]] * is intended for use in unit tests as it can only replay data when the object is still * available. */ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MicroBatchReader with SupportsScanUnsafeRow with Logging { - protected val encoder = encoderFor[A] - private val attributes = encoder.schema.toAttributes - protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) + extends MemoryStreamBase[A](sqlContext) + with MicroBatchReader with SupportsScanUnsafeRow with Logging { + + protected val logicalPlan: LogicalPlan = + StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) protected val output = logicalPlan.output /** @@ -70,7 +99,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) protected var currentOffset: LongOffset = new LongOffset(-1) @GuardedBy("this") - private var startOffset = new LongOffset(-1) + protected var startOffset = new LongOffset(-1) @GuardedBy("this") private var endOffset = new LongOffset(-1) @@ -82,18 +111,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) @GuardedBy("this") protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) - def toDS(): Dataset[A] = { - Dataset(sqlContext.sparkSession, logicalPlan) - } - - def toDF(): DataFrame = { - Dataset.ofRows(sqlContext.sparkSession, logicalPlan) - } - - def addData(data: A*): Offset = { - addData(data.toTraversable) - } - def addData(data: TraversableOnce[A]): Offset = { val objects = data.toSeq val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray @@ -114,8 +131,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } - override def readSchema(): StructType = encoder.schema - override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) override def getStartOffset: OffsetV2 = synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala new file mode 100644 index 0000000000000..c28919b8b729b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.{util => ju} +import java.util.Optional +import java.util.concurrent.atomic.AtomicInteger +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer + +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + +import org.apache.spark.SparkEnv +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.{Encoder, Row, SQLContext} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.RpcUtils + +/** + * The overall strategy here is: + * * ContinuousMemoryStream maintains a list of records for each partition. addData() will + * distribute records evenly-ish across partitions. + * * RecordEndpoint is set up as an endpoint for executor-side + * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified + * offset within the list, or null if that offset doesn't yet have a record. + */ +class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) + extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { + private implicit val formats = Serialization.formats(NoTypeHints) + private val NUM_PARTITIONS = 2 + + protected val logicalPlan = + StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession) + + // ContinuousReader implementation + + @GuardedBy("this") + private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A]) + + @GuardedBy("this") + private var startOffset: ContinuousMemoryStreamOffset = _ + + private val recordEndpoint = new RecordEndpoint() + @volatile private var endpointRef: RpcEndpointRef = _ + + def addData(data: TraversableOnce[A]): Offset = synchronized { + // Distribute data evenly among partition lists. + data.toSeq.zipWithIndex.map { + case (item, index) => records(index % NUM_PARTITIONS) += item + } + + // The new target offset is the offset where all records in all partitions have been processed. + ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap) + } + + override def setStartOffset(start: Optional[Offset]): Unit = synchronized { + // Inferred initial offset is position 0 in each partition. + startOffset = start.orElse { + ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap) + }.asInstanceOf[ContinuousMemoryStreamOffset] + } + + override def getStartOffset: Offset = synchronized { + startOffset + } + + override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = { + ContinuousMemoryStreamOffset(Serialization.read[Map[Int, Int]](json)) + } + + override def mergeOffsets(offsets: Array[PartitionOffset]): ContinuousMemoryStreamOffset = { + ContinuousMemoryStreamOffset( + offsets.map { + case ContinuousMemoryStreamPartitionOffset(part, num) => (part, num) + }.toMap + ) + } + + override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = { + synchronized { + val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" + endpointRef = + recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) + + startOffset.partitionNums.map { + case (part, index) => + new ContinuousMemoryStreamDataReaderFactory( + endpointName, part, index): DataReaderFactory[Row] + }.toList.asJava + } + } + + override def stop(): Unit = { + if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) + } + + override def commit(end: Offset): Unit = {} + + // ContinuousReadSupport implementation + // This is necessary because of how StreamTest finds the source for AddDataMemory steps. + def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = { + this + } + + /** + * Endpoint for executors to poll for records. + */ + private class RecordEndpoint extends ThreadSafeRpcEndpoint { + override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) => + ContinuousMemoryStream.this.synchronized { + val buf = records(part) + val record = if (buf.size <= index) None else Some(buf(index)) + + context.reply(record.map(Row(_))) + } + } + } +} + +object ContinuousMemoryStream { + case class GetRecord(offset: ContinuousMemoryStreamPartitionOffset) + protected val memoryStreamId = new AtomicInteger(0) + + def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) +} + +/** + * Data reader factory for continuous memory stream. + */ +class ContinuousMemoryStreamDataReaderFactory( + driverEndpointName: String, + partition: Int, + startOffset: Int) extends DataReaderFactory[Row] { + override def createDataReader: ContinuousMemoryStreamDataReader = + new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset) +} + +/** + * Data reader for continuous memory stream. + * + * Polls the driver endpoint for new records. + */ +class ContinuousMemoryStreamDataReader( + driverEndpointName: String, + partition: Int, + startOffset: Int) extends ContinuousDataReader[Row] { + private val endpoint = RpcUtils.makeDriverRef( + driverEndpointName, + SparkEnv.get.conf, + SparkEnv.get.rpcEnv) + + private var currentOffset = startOffset + private var current: Option[Row] = None + + override def next(): Boolean = { + current = None + while (current.isEmpty) { + Thread.sleep(10) + current = endpoint.askSync[Option[Row]]( + GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) + } + currentOffset += 1 + true + } + + override def get(): Row = current.get + + override def close(): Unit = {} + + override def getOffset: ContinuousMemoryStreamPartitionOffset = + ContinuousMemoryStreamPartitionOffset(partition, currentOffset) +} + +case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) + extends Offset { + private implicit val formats = Serialization.formats(NoTypeHints) + override def json(): String = Serialization.write(partitionNums) +} + +case class ContinuousMemoryStreamPartitionOffset(partition: Int, numProcessed: Int) + extends PartitionOffset diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 00741d660dd2d..af0268fa47871 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -99,7 +99,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be * been processed. */ object AddData { - def apply[A](source: MemoryStream[A], data: A*): AddDataMemory[A] = + def apply[A](source: MemoryStreamBase[A], data: A*): AddDataMemory[A] = AddDataMemory(source, data) } @@ -131,7 +131,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def runAction(): Unit } - case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { + case class AddDataMemory[A](source: MemoryStreamBase[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index ef74efef156d5..c318b951ff992 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -53,32 +54,24 @@ class ContinuousSuiteBase extends StreamTest { // A continuous trigger that will only fire the initial time for the duration of a test. // This allows clean testing with manual epoch advancement. protected val longContinuousTrigger = Trigger.Continuous("1 hour") + + override protected val defaultTrigger = Trigger.Continuous(100) + override protected val defaultUseV2Sink = true } class ContinuousSuite extends ContinuousSuiteBase { import testImplicits._ - test("basic rate source") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + test("basic") { + val input = ContinuousMemoryStream[Int] - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + testStream(input.toDF())( + AddData(input, 0, 1, 2), + CheckAnswer(0, 1, 2), StopStream, - StartStream(longContinuousTrigger), - AwaitEpoch(2), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), - StopStream) + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(0, 1, 2, 3, 4, 5)) } test("map") { From 05ae74778a10fbdd7f2cbf7742de7855966b7d35 Mon Sep 17 00:00:00 2001 From: Efim Poberezkin Date: Tue, 17 Apr 2018 04:13:17 -0700 Subject: [PATCH 630/774] [SPARK-23747][STRUCTURED STREAMING] Add EpochCoordinator unit tests ## What changes were proposed in this pull request? Unit tests for EpochCoordinator that test correct sequencing of committed epochs. Several tests are ignored since they test functionality implemented in SPARK-23503 which is not yet merged, otherwise they fail. Author: Efim Poberezkin Closes #20983 from efimpoberezkin/pr/EpochCoordinator-tests. --- .../continuous/EpochCoordinatorSuite.scala | 224 ++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala new file mode 100644 index 0000000000000..99e30561f81d5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import org.mockito.InOrder +import org.mockito.Matchers.{any, eq => eqTo} +import org.mockito.Mockito._ +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.LocalSparkSession +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.test.TestSparkSession + +class EpochCoordinatorSuite + extends SparkFunSuite + with LocalSparkSession + with MockitoSugar + with BeforeAndAfterEach { + + private var epochCoordinator: RpcEndpointRef = _ + + private var writer: StreamWriter = _ + private var query: ContinuousExecution = _ + private var orderVerifier: InOrder = _ + + override def beforeEach(): Unit = { + val reader = mock[ContinuousReader] + writer = mock[StreamWriter] + query = mock[ContinuousExecution] + orderVerifier = inOrder(writer, query) + + spark = new TestSparkSession() + + epochCoordinator + = EpochCoordinatorRef.create(writer, reader, query, "test", 1, spark, SparkEnv.get) + } + + test("single epoch") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + commitPartitionEpoch(2, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + // Here and in subsequent tests this is called to make a synchronous call to EpochCoordinator + // so that mocks would have been acted upon by the time verification happens + makeSynchronousCall() + + verifyCommit(1) + } + + test("single epoch, all but one writer partition has committed") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + makeSynchronousCall() + + verifyNoCommitFor(1) + } + + test("single epoch, all but one reader partition has reported an offset") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + commitPartitionEpoch(2, 1) + reportPartitionOffset(0, 1) + + makeSynchronousCall() + + verifyNoCommitFor(1) + } + + test("consequent epochs, messages for epoch (k + 1) arrive after messages for epoch k") { + setWriterPartitions(2) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + commitPartitionEpoch(0, 2) + commitPartitionEpoch(1, 2) + reportPartitionOffset(0, 2) + reportPartitionOffset(1, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2)) + } + + ignore("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { + setWriterPartitions(2) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + + commitPartitionEpoch(0, 2) + commitPartitionEpoch(1, 2) + reportPartitionOffset(0, 2) + reportPartitionOffset(1, 2) + + // Message that arrives late + reportPartitionOffset(1, 1) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2)) + } + + ignore("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { + setWriterPartitions(1) + setReaderPartitions(1) + + commitPartitionEpoch(0, 1) + reportPartitionOffset(0, 1) + + commitPartitionEpoch(0, 3) + reportPartitionOffset(0, 3) + + commitPartitionEpoch(0, 4) + reportPartitionOffset(0, 4) + + commitPartitionEpoch(0, 2) + reportPartitionOffset(0, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2, 3, 4)) + } + + ignore("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { + setWriterPartitions(1) + setReaderPartitions(1) + + commitPartitionEpoch(0, 1) + reportPartitionOffset(0, 1) + + commitPartitionEpoch(0, 3) + reportPartitionOffset(0, 3) + + commitPartitionEpoch(0, 5) + reportPartitionOffset(0, 5) + + commitPartitionEpoch(0, 4) + reportPartitionOffset(0, 4) + + commitPartitionEpoch(0, 2) + reportPartitionOffset(0, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2, 3, 4, 5)) + } + + private def setWriterPartitions(numPartitions: Int): Unit = { + epochCoordinator.askSync[Unit](SetWriterPartitions(numPartitions)) + } + + private def setReaderPartitions(numPartitions: Int): Unit = { + epochCoordinator.askSync[Unit](SetReaderPartitions(numPartitions)) + } + + private def commitPartitionEpoch(partitionId: Int, epoch: Long): Unit = { + val dummyMessage: WriterCommitMessage = mock[WriterCommitMessage] + epochCoordinator.send(CommitPartitionEpoch(partitionId, epoch, dummyMessage)) + } + + private def reportPartitionOffset(partitionId: Int, epoch: Long): Unit = { + val dummyOffset: PartitionOffset = mock[PartitionOffset] + epochCoordinator.send(ReportPartitionOffset(partitionId, epoch, dummyOffset)) + } + + private def makeSynchronousCall(): Unit = { + epochCoordinator.askSync[Long](GetCurrentEpoch) + } + + private def verifyCommit(epoch: Long): Unit = { + orderVerifier.verify(writer).commit(eqTo(epoch), any()) + orderVerifier.verify(query).commit(epoch) + } + + private def verifyNoCommitFor(epoch: Long): Unit = { + verify(writer, never()).commit(eqTo(epoch), any()) + verify(query, never()).commit(epoch) + } + + private def verifyCommitsInOrderOf(epochs: Seq[Long]): Unit = { + epochs.foreach(verifyCommit) + } +} From 30ffb53cad84283b4f7694bfd60bdd7e1101b04e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 17 Apr 2018 15:09:36 +0200 Subject: [PATCH 631/774] [SPARK-23875][SQL] Add IndexedSeq wrapper for ArrayData ## What changes were proposed in this pull request? We don't have a good way to sequentially access `UnsafeArrayData` with a common interface such as `Seq`. An example is `MapObject` where we need to access several sequence collection types together. But `UnsafeArrayData` doesn't implement `ArrayData.array`. Calling `toArray` will copy the entire array. We can provide an `IndexedSeq` wrapper for `ArrayData`, so we can avoid copying the entire array. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20984 from viirya/SPARK-23875. --- .../expressions/objects/objects.scala | 2 +- .../spark/sql/catalyst/util/ArrayData.scala | 30 +++++- .../util/ArrayDataIndexedSeqSuite.scala | 100 ++++++++++++++++++ 3 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 77802e89e942b..72b202b3a5020 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -708,7 +708,7 @@ case class MapObjects private( } } case ArrayType(et, _) => - _.asInstanceOf[ArrayData].array + _.asInstanceOf[ArrayData].toSeq[Any](et) } private lazy val mapElements: Seq[_] => Any = customCollectionCls match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 9beef41d639f3..2cf59d567c08c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.util import scala.reflect.ClassTag +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types._ object ArrayData { def toArrayData(input: Any): ArrayData = input match { @@ -42,6 +43,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable { def array: Array[Any] + def toSeq[T](dataType: DataType): IndexedSeq[T] = + new ArrayDataIndexedSeq[T](this, dataType) + def setNullAt(i: Int): Unit def update(i: Int, value: Any): Unit @@ -164,3 +168,27 @@ abstract class ArrayData extends SpecializedGetters with Serializable { } } } + +/** + * Implements an `IndexedSeq` interface for `ArrayData`. Notice that if the original `ArrayData` + * is a primitive array and contains null elements, it is better to ask for `IndexedSeq[Any]`, + * instead of `IndexedSeq[Int]`, in order to keep the null elements. + */ +class ArrayDataIndexedSeq[T](arrayData: ArrayData, dataType: DataType) extends IndexedSeq[T] { + + private val accessor: (SpecializedGetters, Int) => Any = InternalRow.getAccessor(dataType) + + override def apply(idx: Int): T = + if (0 <= idx && idx < arrayData.numElements()) { + if (arrayData.isNullAt(idx)) { + null.asInstanceOf[T] + } else { + accessor(arrayData, idx).asInstanceOf[T] + } + } else { + throw new IndexOutOfBoundsException( + s"Index $idx must be between 0 and the length of the ArrayData.") + } + + override def length: Int = arrayData.numElements() +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala new file mode 100644 index 0000000000000..6400898343ae7 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection} +import org.apache.spark.sql.types._ + +class ArrayDataIndexedSeqSuite extends SparkFunSuite { + private def compArray(arrayData: ArrayData, elementDt: DataType, array: Array[Any]): Unit = { + assert(arrayData.numElements == array.length) + array.zipWithIndex.map { case (e, i) => + if (e != null) { + elementDt match { + // For NaN, etc. + case FloatType | DoubleType => assert(arrayData.get(i, elementDt).equals(e)) + case _ => assert(arrayData.get(i, elementDt) === e) + } + } else { + assert(arrayData.isNullAt(i)) + } + } + + val seq = arrayData.toSeq[Any](elementDt) + array.zipWithIndex.map { case (e, i) => + if (e != null) { + elementDt match { + // For Nan, etc. + case FloatType | DoubleType => assert(seq(i).equals(e)) + case _ => assert(seq(i) === e) + } + } else { + assert(seq(i) == null) + } + } + + intercept[IndexOutOfBoundsException] { + seq(-1) + }.getMessage().contains("must be between 0 and the length of the ArrayData.") + + intercept[IndexOutOfBoundsException] { + seq(seq.length) + }.getMessage().contains("must be between 0 and the length of the ArrayData.") + } + + private def testArrayData(): Unit = { + val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType, + CalendarIntervalType, new ExamplePointUDT()) + val arrayTypes = elementTypes.flatMap { elementType => + Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true)) + } + val random = new Random(100) + arrayTypes.foreach { dt => + val schema = StructType(StructField("col_1", dt, nullable = false) :: Nil) + val row = RandomDataGenerator.randomRow(random, schema) + val rowConverter = RowEncoder(schema) + val internalRow = rowConverter.toRow(row) + + val unsafeRowConverter = UnsafeProjection.create(schema) + val safeRowConverter = FromUnsafeProjection(schema) + + val unsafeRow = unsafeRowConverter(internalRow) + val safeRow = safeRowConverter(unsafeRow) + + val genericArrayData = safeRow.getArray(0).asInstanceOf[GenericArrayData] + val unsafeArrayData = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData] + + val elementType = dt.elementType + test("ArrayDataIndexedSeq - UnsafeArrayData - " + dt.toString) { + compArray(unsafeArrayData, elementType, unsafeArrayData.toArray[Any](elementType)) + } + + test("ArrayDataIndexedSeq - GenericArrayData - " + dt.toString) { + compArray(genericArrayData, elementType, genericArrayData.toArray[Any](elementType)) + } + } + } + + testArrayData() +} From 0a9172a05e604a4a94adbb9208c8c02362afca00 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 17 Apr 2018 21:45:20 +0800 Subject: [PATCH 632/774] [SPARK-23835][SQL] Add not-null check to Tuples' arguments deserialization ## What changes were proposed in this pull request? There was no check on nullability for arguments of `Tuple`s. This could lead to have weird behavior when a null value had to be deserialized into a non-nullable Scala object: in those cases, the `null` got silently transformed in a valid value (like `-1` for `Int`), corresponding to the default value we are using in the SQL codebase. This situation was very likely to happen when deserializing to a Tuple of primitive Scala types (like Double, Int, ...). The PR adds the `AssertNotNull` to arguments of tuples which have been asked to be converted to non-nullable types. ## How was this patch tested? added UT Author: Marco Gaido Closes #20976 from mgaido91/SPARK-23835. --- .../sql/kafka010/KafkaContinuousSinkSuite.scala | 6 +++--- .../apache/spark/sql/kafka010/KafkaSinkSuite.scala | 2 +- .../spark/sql/catalyst/ScalaReflection.scala | 14 +++++++------- .../spark/sql/catalyst/ScalaReflectionSuite.scala | 12 +++++++++++- .../scala/org/apache/spark/sql/DatasetSuite.scala | 5 +++++ 5 files changed, 27 insertions(+), 12 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index fc890a0cfdac3..ddfc0c1a4be2d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -79,7 +79,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val reader = createKafkaReader(topic) .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { @@ -119,7 +119,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val reader = createKafkaReader(topic) .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { @@ -167,7 +167,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val reader = createKafkaReader(topic) .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 42f8b4c7657e2..7079ac6453ffc 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -138,7 +138,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { val reader = createKafkaReader(topic) .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 1aae3aea3a31a..e4274aaa9727e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -382,22 +382,22 @@ object ScalaReflection extends ScalaReflection { val clsName = getClassNameFromType(fieldType) val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. - if (cls.getName startsWith "scala.Tuple") { + val constructor = if (cls.getName startsWith "scala.Tuple") { deserializerFor( fieldType, Some(addToPathOrdinal(i, dataType, newTypePath)), newTypePath) } else { - val constructor = deserializerFor( + deserializerFor( fieldType, Some(addToPath(fieldName, dataType, newTypePath)), newTypePath) + } - if (!nullable) { - AssertNotNull(constructor, newTypePath) - } else { - constructor - } + if (!nullable) { + AssertNotNull(constructor, newTypePath) + } else { + constructor } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 8c3db48a01f12..353b8344658f2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -365,4 +365,14 @@ class ScalaReflectionSuite extends SparkFunSuite { StructField("_2", NullType, nullable = true))), nullable = true)) } + + test("SPARK-23835: add null check to non-nullable types in Tuples") { + def numberOfCheckedArguments(deserializer: Expression): Int = { + assert(deserializer.isInstanceOf[NewInstance]) + deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull]) + } + assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 9b745befcb611..e0f4d2ba685e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1453,6 +1453,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val group2 = cached.groupBy("x").agg(min(col("z")) as "value") checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil) } + + test("SPARK-23835: null primitive data type should throw NullPointerException") { + val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() + intercept[NullPointerException](ds.as[(Int, Int)].collect()) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From ed4101d29f50d54fd7846421e4c00e9ecd3599d0 Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 17 Apr 2018 21:52:33 +0800 Subject: [PATCH 633/774] [SPARK-22676] Avoid iterating all partition paths when spark.sql.hive.verifyPartitionPath=true MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In current code, it will scanning all partition paths when spark.sql.hive.verifyPartitionPath=true. e.g. table like below: ``` CREATE TABLE `test`( `id` int, `age` int, `name` string) PARTITIONED BY ( `A` string, `B` string) load data local inpath '/tmp/data0' into table test partition(A='00', B='00') load data local inpath '/tmp/data1' into table test partition(A='01', B='01') load data local inpath '/tmp/data2' into table test partition(A='10', B='10') load data local inpath '/tmp/data3' into table test partition(A='11', B='11') ``` If I query with SQL – "select * from test where A='00' and B='01' ", current code will scan all partition paths including '/data/A=00/B=00', '/data/A=00/B=00', '/data/A=01/B=01', '/data/A=10/B=10', '/data/A=11/B=11'. It costs much time and memory cost. This pr proposes to avoid iterating all partition paths. Add a config `spark.files.ignoreMissingFiles` and ignore the `file not found` when `getPartitions/compute`(for hive table scan). This is much like the logic brought by `spark.sql.files.ignoreMissingFiles`(which is for datasource scan). ## How was this patch tested? UT Author: jinxing Closes #19868 from jinxing64/SPARK-22676. --- .../spark/internal/config/package.scala | 6 ++ .../org/apache/spark/rdd/HadoopRDD.scala | 43 +++++++++--- .../org/apache/spark/rdd/NewHadoopRDD.scala | 45 ++++++++---- .../scala/org/apache/spark/FileSuite.scala | 69 ++++++++++++++++++- .../apache/spark/sql/internal/SQLConf.scala | 3 +- .../spark/sql/hive/QueryPartitionSuite.scala | 40 +++++++++++ 6 files changed, 181 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 407545aa4a47a..99d779fb600e8 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -301,6 +301,12 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val IGNORE_MISSING_FILES = ConfigBuilder("spark.files.ignoreMissingFiles") + .doc("Whether to ignore missing files. If true, the Spark jobs will continue to run when " + + "encountering missing files and the contents that have been read will still be returned.") + .booleanConf + .createWithDefault(false) + private[spark] val APP_CALLER_CONTEXT = ConfigBuilder("spark.log.callerContext") .stringConf .createOptional diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 2480559a41b7a..44895abc7bd4d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.io.IOException +import java.io.{FileNotFoundException, IOException} import java.text.SimpleDateFormat import java.util.{Date, Locale} @@ -28,6 +28,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapred._ import org.apache.hadoop.mapred.lib.CombineFileSplit import org.apache.hadoop.mapreduce.TaskType +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ @@ -134,6 +135,8 @@ class HadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreMissingFiles = sparkContext.conf.get(IGNORE_MISSING_FILES) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. @@ -197,17 +200,24 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() // add the credentials here as this can be called before SparkContext initialized SparkHadoopUtil.get.addCredentials(jobConf) - val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions) - val inputSplits = if (ignoreEmptySplits) { - allInputSplits.filter(_.getLength > 0) - } else { - allInputSplits - } - val array = new Array[Partition](inputSplits.size) - for (i <- 0 until inputSplits.size) { - array(i) = new HadoopPartition(id, i, inputSplits(i)) + try { + val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions) + val inputSplits = if (ignoreEmptySplits) { + allInputSplits.filter(_.getLength > 0) + } else { + allInputSplits + } + val array = new Array[Partition](inputSplits.size) + for (i <- 0 until inputSplits.size) { + array(i) = new HadoopPartition(id, i, inputSplits(i)) + } + array + } catch { + case e: InvalidInputException if ignoreMissingFiles => + logWarning(s"${jobConf.get(FileInputFormat.INPUT_DIR)} doesn't exist and no" + + s" partitions returned from this path.", e) + Array.empty[Partition] } - array } override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { @@ -256,6 +266,12 @@ class HadoopRDD[K, V]( try { inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.inputSplit}", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) finished = true @@ -276,6 +292,11 @@ class HadoopRDD[K, V]( try { finished = !reader.next(key, value) } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.inputSplit}", e) + finished = true + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) finished = true diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index e4dd1b6a82498..ff66a04859d10 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.io.IOException +import java.io.{FileNotFoundException, IOException} import java.text.SimpleDateFormat import java.util.{Date, Locale} @@ -28,7 +28,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileInputFormat, FileSplit, InvalidInputException} import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} import org.apache.spark._ @@ -90,6 +90,8 @@ class NewHadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreMissingFiles = sparkContext.conf.get(IGNORE_MISSING_FILES) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) def getConf: Configuration = { @@ -124,17 +126,25 @@ class NewHadoopRDD[K, V]( configurable.setConf(_conf) case _ => } - val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala - val rawSplits = if (ignoreEmptySplits) { - allRowSplits.filter(_.getLength > 0) - } else { - allRowSplits - } - val result = new Array[Partition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + try { + val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala + val rawSplits = if (ignoreEmptySplits) { + allRowSplits.filter(_.getLength > 0) + } else { + allRowSplits + } + val result = new Array[Partition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = + new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } catch { + case e: InvalidInputException if ignoreMissingFiles => + logWarning(s"${_conf.get(FileInputFormat.INPUT_DIR)} doesn't exist and no" + + s" partitions returned from this path.", e) + Array.empty[Partition] } - result } override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { @@ -189,6 +199,12 @@ class NewHadoopRDD[K, V]( _reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) _reader } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.serializableHadoopSplit}", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning( s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", @@ -213,6 +229,11 @@ class NewHadoopRDD[K, V]( try { finished = !reader.nextKeyValue } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.serializableHadoopSplit}", e) + finished = true + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning( s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 55a9122cf9026..a441b9c8ab97a 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -23,6 +23,7 @@ import java.util.zip.GZIPOutputStream import scala.io.Source +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec @@ -32,7 +33,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInp import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.apache.spark.internal.config._ -import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} +import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -596,4 +597,70 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { actualPartitionNum = 5, expectedPartitionNum = 2) } + + test("spark.files.ignoreMissingFiles should work both HadoopRDD and NewHadoopRDD") { + // "file not found" can happen both when getPartitions or compute in HadoopRDD/NewHadoopRDD, + // We test both cases here. + + val deletedPath = new Path(tempDir.getAbsolutePath, "test-data-1") + val fs = deletedPath.getFileSystem(new Configuration()) + fs.delete(deletedPath, true) + intercept[FileNotFoundException](fs.open(deletedPath)) + + def collectRDDAndDeleteFileBeforeCompute(newApi: Boolean): Array[_] = { + val dataPath = new Path(tempDir.getAbsolutePath, "test-data-2") + val writer = new OutputStreamWriter(new FileOutputStream(new File(dataPath.toString))) + writer.write("hello\n") + writer.write("world\n") + writer.close() + val rdd = if (newApi) { + sc.newAPIHadoopFile(dataPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]) + } else { + sc.textFile(dataPath.toString) + } + rdd.partitions + fs.delete(dataPath, true) + // Exception happens when initialize record reader in HadoopRDD/NewHadoopRDD.compute + // because partitions' info already cached. + rdd.collect() + } + + // collect HadoopRDD and NewHadoopRDD when spark.files.ignoreMissingFiles=false by default. + sc = new SparkContext("local", "test") + intercept[org.apache.hadoop.mapred.InvalidInputException] { + // Exception happens when HadoopRDD.getPartitions + sc.textFile(deletedPath.toString).collect() + } + + var e = intercept[SparkException] { + collectRDDAndDeleteFileBeforeCompute(false) + } + assert(e.getCause.isInstanceOf[java.io.FileNotFoundException]) + + intercept[org.apache.hadoop.mapreduce.lib.input.InvalidInputException] { + // Exception happens when NewHadoopRDD.getPartitions + sc.newAPIHadoopFile(deletedPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]).collect + } + + e = intercept[SparkException] { + collectRDDAndDeleteFileBeforeCompute(true) + } + assert(e.getCause.isInstanceOf[java.io.FileNotFoundException]) + + sc.stop() + + // collect HadoopRDD and NewHadoopRDD when spark.files.ignoreMissingFiles=true. + val conf = new SparkConf().set(IGNORE_MISSING_FILES, true) + sc = new SparkContext("local", "test", conf) + assert(sc.textFile(deletedPath.toString).collect().isEmpty) + + assert(collectRDDAndDeleteFileBeforeCompute(false).isEmpty) + + assert(sc.newAPIHadoopFile(deletedPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]).collect().isEmpty) + + assert(collectRDDAndDeleteFileBeforeCompute(true).isEmpty) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0dc47bfe075d0..3729bd5293eca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -437,7 +437,8 @@ object SQLConf { val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + - "when reading data stored in HDFS.") + "when reading data stored in HDFS. This configuration will be deprecated in the future " + + "releases and replaced by spark.files.ignoreMissingFiles.") .booleanConf .createWithDefault(false) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index b2dc401ce1efc..78156b17fb43b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -23,6 +23,7 @@ import java.sql.Timestamp import com.google.common.io.Files import org.apache.hadoop.fs.FileSystem +import org.apache.spark.internal.config._ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -70,6 +71,45 @@ class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingl } } + test("Replace spark.sql.hive.verifyPartitionPath by spark.files.ignoreMissingFiles") { + withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "false")) { + sparkContext.conf.set(IGNORE_MISSING_FILES.key, "true") + val testData = sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.createOrReplaceTempView("testData") + + val tmpDir = Files.createTempDir() + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") + + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect + ++ testData.toDF.collect ++ testData.toDF.collect) + + // delete the path of one partition + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } + + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) + + sql("DROP TABLE IF EXISTS table_with_partition") + sql("DROP TABLE IF EXISTS createAndInsertTest") + } + } + test("SPARK-21739: Cast expression should initialize timezoneId") { withTable("table_with_timestamp_partition") { sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)") From 3990daaf3b6ca2c5a9f7790030096262efb12cb2 Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 17 Apr 2018 08:55:01 -0500 Subject: [PATCH 634/774] [SPARK-23948] Trigger mapstage's job listener in submitMissingTasks ## What changes were proposed in this pull request? SparkContext submitted a map stage from `submitMapStage` to `DAGScheduler`, `markMapStageJobAsFinished` is called only in (https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L933 and https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L1314); But think about below scenario: 1. stage0 and stage1 are all `ShuffleMapStage` and stage1 depends on stage0; 2. We submit stage1 by `submitMapStage`; 3. When stage 1 running, `FetchFailed` happened, stage0 and stage1 got resubmitted as stage0_1 and stage1_1; 4. When stage0_1 running, speculated tasks in old stage1 come as succeeded, but stage1 is not inside `runningStages`. So even though all splits(including the speculated tasks) in stage1 succeeded, job listener in stage1 will not be called; 5. stage0_1 finished, stage1_1 starts running. When `submitMissingTasks`, there is no missing tasks. But in current code, job listener is not triggered. We should call the job listener for map stage in `5`. ## How was this patch tested? Not added yet. Author: jinxing Closes #21019 from jinxing64/SPARK-23948. --- .../apache/spark/scheduler/DAGScheduler.scala | 33 ++++++------ .../spark/scheduler/DAGSchedulerSuite.scala | 52 +++++++++++++++++++ 2 files changed, 70 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8c46a84323392..78b6b34b5d2bb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1092,17 +1092,16 @@ class DAGScheduler( // the stage as completed here in case there are no tasks to run markStageAsFinished(stage, None) - val debugString = stage match { + stage match { case stage: ShuffleMapStage => - s"Stage ${stage} is actually done; " + - s"(available: ${stage.isAvailable}," + - s"available outputs: ${stage.numAvailableOutputs}," + - s"partitions: ${stage.numPartitions})" + logDebug(s"Stage ${stage} is actually done; " + + s"(available: ${stage.isAvailable}," + + s"available outputs: ${stage.numAvailableOutputs}," + + s"partitions: ${stage.numPartitions})") + markMapStageJobsAsFinished(stage) case stage : ResultStage => - s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})" + logDebug(s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})") } - logDebug(debugString) - submitWaitingChildStages(stage) } } @@ -1307,13 +1306,7 @@ class DAGScheduler( shuffleStage.findMissingPartitions().mkString(", ")) submitStage(shuffleStage) } else { - // Mark any map-stage jobs waiting on this stage as finished - if (shuffleStage.mapStageJobs.nonEmpty) { - val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) - for (job <- shuffleStage.mapStageJobs) { - markMapStageJobAsFinished(job, stats) - } - } + markMapStageJobsAsFinished(shuffleStage) submitWaitingChildStages(shuffleStage) } } @@ -1433,6 +1426,16 @@ class DAGScheduler( } } + private[scheduler] def markMapStageJobsAsFinished(shuffleStage: ShuffleMapStage): Unit = { + // Mark any map-stage jobs waiting on this stage as finished + if (shuffleStage.isAvailable && shuffleStage.mapStageJobs.nonEmpty) { + val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) + for (job <- shuffleStage.mapStageJobs) { + markMapStageJobAsFinished(job, stats) + } + } + } + /** * Responds to an executor being lost. This is called inside the event loop, so it assumes it can * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d812b5bd92c1b..8b6ec37625eec 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2146,6 +2146,58 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } + test("Trigger mapstage's job listener in submitMissingTasks") { + val rdd1 = new MyRDD(sc, 2, Nil) + val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(2)) + val rdd2 = new MyRDD(sc, 2, List(dep1), tracker = mapOutputTracker) + val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(2)) + + val listener1 = new SimpleListener + val listener2 = new SimpleListener + + submitMapStage(dep1, listener1) + submitMapStage(dep2, listener2) + + // Complete the stage0. + assert(taskSets(0).stageId === 0) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", rdd1.partitions.length)), + (Success, makeMapStatus("hostB", rdd1.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(listener1.results.size === 1) + + // When attempting stage1, trigger a fetch failure. + assert(taskSets(1).stageId === 1) + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.length)), + (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + // Stage1 listener should not have a result yet + assert(listener2.results.size === 0) + + // Speculative task succeeded in stage1. + runEvent(makeCompletionEvent( + taskSets(1).tasks(1), + Success, + makeMapStatus("hostD", rdd2.partitions.length))) + // stage1 listener still should not have a result, though there's no missing partitions + // in it. Because stage1 has been failed and is not inside `runningStages` at this moment. + assert(listener2.results.size === 0) + + // Stage0 should now be running as task set 2; make its task succeed + assert(taskSets(2).stageId === 0) + complete(taskSets(2), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + Set(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + + // After stage0 is finished, stage1 will be submitted and found there is no missing + // partitions in it. Then listener got triggered. + assert(listener2.results.size === 1) + assertDataStructuresEmpty() + } + /** * In this test, we run a map stage where one of the executors fails but we still receive a * "zombie" complete message from that executor. We want to make sure the stage is not reported From f39e82ce150b6a7ea038e6858ba7adbaba3cad88 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 18 Apr 2018 00:35:44 +0800 Subject: [PATCH 635/774] [SPARK-23986][SQL] freshName can generate non-unique names ## What changes were proposed in this pull request? We are using `CodegenContext.freshName` to get a unique name for any new variable we are adding. Unfortunately, this method currently fails to create a unique name when we request more than one instance of variables with starting name `name1` and an instance with starting name `name11`. The PR changes the way a new name is generated by `CodegenContext.freshName` so that we generate unique names in this scenario too. ## How was this patch tested? added UT Author: Marco Gaido Closes #21080 from mgaido91/SPARK-23986. --- .../catalyst/expressions/codegen/CodeGenerator.scala | 11 +++-------- .../catalyst/expressions/CodeGenerationSuite.scala | 10 ++++++++++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d97611c98ac91..f6b6775923ac6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -572,14 +572,9 @@ class CodegenContext { } else { s"${freshNamePrefix}_$name" } - if (freshNameIds.contains(fullName)) { - val id = freshNameIds(fullName) - freshNameIds(fullName) = id + 1 - s"$fullName$id" - } else { - freshNameIds += fullName -> 1 - fullName - } + val id = freshNameIds.getOrElse(fullName, 0) + freshNameIds(fullName) = id + 1 + s"${fullName}_$id" } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index f7c023111ff59..5b71becee2de0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -489,4 +489,14 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(!ctx.subExprEliminationExprs.contains(ref)) } } + + test("SPARK-23986: freshName can generate duplicated names") { + val ctx = new CodegenContext + val names1 = ctx.freshName("myName1") :: ctx.freshName("myName1") :: + ctx.freshName("myName11") :: Nil + assert(names1.distinct.length == 3) + val names2 = ctx.freshName("a") :: ctx.freshName("a") :: + ctx.freshName("a_1") :: ctx.freshName("a_0") :: Nil + assert(names2.distinct.length == 4) + } } From 1ca3c50fefb34532c78427fa74872db3ecbf7ba2 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 17 Apr 2018 10:11:08 -0700 Subject: [PATCH 636/774] [SPARK-21741][ML][PYSPARK] Python API for DataFrame-based multivariate summarizer ## What changes were proposed in this pull request? Python API for DataFrame-based multivariate summarizer. ## How was this patch tested? doctest added. Author: WeichenXu Closes #20695 from WeichenXu123/py_summarizer. --- python/pyspark/ml/stat.py | 193 +++++++++++++++++++++++++++++++++++++- 1 file changed, 192 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 93d0f4fd9148f..a06ab31a7a56a 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -19,7 +19,9 @@ from pyspark import since, SparkContext from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.wrapper import _jvm +from pyspark.ml.wrapper import JavaWrapper, _jvm +from pyspark.sql.column import Column, _to_seq +from pyspark.sql.functions import lit class ChiSquareTest(object): @@ -195,6 +197,195 @@ def test(dataset, sampleCol, distName, *params): _jvm().PythonUtils.toSeq(params))) +class Summarizer(object): + """ + .. note:: Experimental + + Tools for vectorized statistics on MLlib Vectors. + The methods in this package provide various statistics for Vectors contained inside DataFrames. + This class lets users pick the statistics they would like to extract for a given column. + + >>> from pyspark.ml.stat import Summarizer + >>> from pyspark.sql import Row + >>> from pyspark.ml.linalg import Vectors + >>> summarizer = Summarizer.metrics("mean", "count") + >>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)), + ... Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF() + >>> df.select(summarizer.summary(df.features, df.weight)).show(truncate=False) + +-----------------------------------+ + |aggregate_metrics(features, weight)| + +-----------------------------------+ + |[[1.0,1.0,1.0], 1] | + +-----------------------------------+ + + >>> df.select(summarizer.summary(df.features)).show(truncate=False) + +--------------------------------+ + |aggregate_metrics(features, 1.0)| + +--------------------------------+ + |[[1.0,1.5,2.0], 2] | + +--------------------------------+ + + >>> df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False) + +--------------+ + |mean(features)| + +--------------+ + |[1.0,1.0,1.0] | + +--------------+ + + >>> df.select(Summarizer.mean(df.features)).show(truncate=False) + +--------------+ + |mean(features)| + +--------------+ + |[1.0,1.5,2.0] | + +--------------+ + + + .. versionadded:: 2.4.0 + + """ + @staticmethod + @since("2.4.0") + def mean(col, weightCol=None): + """ + return a column of mean summary + """ + return Summarizer._get_single_metric(col, weightCol, "mean") + + @staticmethod + @since("2.4.0") + def variance(col, weightCol=None): + """ + return a column of variance summary + """ + return Summarizer._get_single_metric(col, weightCol, "variance") + + @staticmethod + @since("2.4.0") + def count(col, weightCol=None): + """ + return a column of count summary + """ + return Summarizer._get_single_metric(col, weightCol, "count") + + @staticmethod + @since("2.4.0") + def numNonZeros(col, weightCol=None): + """ + return a column of numNonZero summary + """ + return Summarizer._get_single_metric(col, weightCol, "numNonZeros") + + @staticmethod + @since("2.4.0") + def max(col, weightCol=None): + """ + return a column of max summary + """ + return Summarizer._get_single_metric(col, weightCol, "max") + + @staticmethod + @since("2.4.0") + def min(col, weightCol=None): + """ + return a column of min summary + """ + return Summarizer._get_single_metric(col, weightCol, "min") + + @staticmethod + @since("2.4.0") + def normL1(col, weightCol=None): + """ + return a column of normL1 summary + """ + return Summarizer._get_single_metric(col, weightCol, "normL1") + + @staticmethod + @since("2.4.0") + def normL2(col, weightCol=None): + """ + return a column of normL2 summary + """ + return Summarizer._get_single_metric(col, weightCol, "normL2") + + @staticmethod + def _check_param(featuresCol, weightCol): + if weightCol is None: + weightCol = lit(1.0) + if not isinstance(featuresCol, Column) or not isinstance(weightCol, Column): + raise TypeError("featureCol and weightCol should be a Column") + return featuresCol, weightCol + + @staticmethod + def _get_single_metric(col, weightCol, metric): + col, weightCol = Summarizer._check_param(col, weightCol) + return Column(JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer." + metric, + col._jc, weightCol._jc)) + + @staticmethod + @since("2.4.0") + def metrics(*metrics): + """ + Given a list of metrics, provides a builder that it turns computes metrics from a column. + + See the documentation of [[Summarizer]] for an example. + + The following metrics are accepted (case sensitive): + - mean: a vector that contains the coefficient-wise mean. + - variance: a vector tha contains the coefficient-wise variance. + - count: the count of all vectors seen. + - numNonzeros: a vector with the number of non-zeros for each coefficients + - max: the maximum for each coefficient. + - min: the minimum for each coefficient. + - normL2: the Euclidian norm for each coefficient. + - normL1: the L1 norm of each coefficient (sum of the absolute values). + + :param metrics: + metrics that can be provided. + :return: + an object of :py:class:`pyspark.ml.stat.SummaryBuilder` + + Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD + interface. + """ + sc = SparkContext._active_spark_context + js = JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer.metrics", + _to_seq(sc, metrics)) + return SummaryBuilder(js) + + +class SummaryBuilder(JavaWrapper): + """ + .. note:: Experimental + + A builder object that provides summary statistics about a given column. + + Users should not directly create such builders, but instead use one of the methods in + :py:class:`pyspark.ml.stat.Summarizer` + + .. versionadded:: 2.4.0 + + """ + def __init__(self, jSummaryBuilder): + super(SummaryBuilder, self).__init__(jSummaryBuilder) + + @since("2.4.0") + def summary(self, featuresCol, weightCol=None): + """ + Returns an aggregate object that contains the summary of the column with the requested + metrics. + + :param featuresCol: + a column that contains features Vector object. + :param weightCol: + a column that contains weight value. Default weight is 1.0. + :return: + an aggregate column that contains the statistics. The exact content of this + structure is determined during the creation of the builder. + """ + featuresCol, weightCol = Summarizer._check_param(featuresCol, weightCol) + return Column(self._java_obj.summary(featuresCol._jc, weightCol._jc)) + + if __name__ == "__main__": import doctest import pyspark.ml.stat From 5fccdae18911793967b315c02c058eb737e46174 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 17 Apr 2018 21:08:42 -0500 Subject: [PATCH 637/774] [SPARK-22968][DSTREAM] Throw an exception on partition revoking issue ## What changes were proposed in this pull request? Kafka partitions can be revoked when new consumers joined in the consumer group to rebalance the partitions. But current Spark Kafka connector code makes sure there's no partition revoking scenarios, so trying to get latest offset from revoked partitions will throw exceptions as JIRA mentioned. Partition revoking happens when new consumer joined the consumer group, which means different streaming apps are trying to use same group id. This is fundamentally not correct, different apps should use different consumer group. So instead of throwing an confused exception from Kafka, improve the exception message by identifying revoked partition and directly throw an meaningful exception when partition is revoked. Besides, this PR also fixes bugs in `DirectKafkaWordCount`, this example simply cannot be worked without the fix. ``` 8/01/05 09:48:27 INFO internals.ConsumerCoordinator: Revoking previously assigned partitions [kssh-7, kssh-4, kssh-3, kssh-6, kssh-5, kssh-0, kssh-2, kssh-1] for group use_a_separate_group_id_for_each_stream 18/01/05 09:48:27 INFO internals.AbstractCoordinator: (Re-)joining group use_a_separate_group_id_for_each_stream 18/01/05 09:48:27 INFO internals.AbstractCoordinator: Successfully joined group use_a_separate_group_id_for_each_stream with generation 4 18/01/05 09:48:27 INFO internals.ConsumerCoordinator: Setting newly assigned partitions [kssh-7, kssh-4, kssh-6, kssh-5] for group use_a_separate_group_id_for_each_stream ``` ## How was this patch tested? This is manually verified in local cluster, unfortunately I'm not sure how to simulate it in UT, so propose the PR without UT added. Author: jerryshao Closes #21038 from jerryshao/SPARK-22968. --- .../streaming/DirectKafkaWordCount.scala | 17 +++++++++++++---- .../kafka010/DirectKafkaInputDStream.scala | 12 ++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index def06026bde96..2082fb71afdf1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -18,6 +18,9 @@ // scalastyle:off println package org.apache.spark.examples.streaming +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.serialization.StringDeserializer + import org.apache.spark.SparkConf import org.apache.spark.streaming._ import org.apache.spark.streaming.kafka010._ @@ -26,18 +29,20 @@ import org.apache.spark.streaming.kafka010._ * Consumes messages from one or more topics in Kafka and does wordcount. * Usage: DirectKafkaWordCount * is a list of one or more Kafka brokers + * is a consumer group name to consume from topics * is a list of one or more kafka topics to consume from * * Example: * $ bin/run-example streaming.DirectKafkaWordCount broker1-host:port,broker2-host:port \ - * topic1,topic2 + * consumer-group topic1,topic2 */ object DirectKafkaWordCount { def main(args: Array[String]) { - if (args.length < 2) { + if (args.length < 3) { System.err.println(s""" |Usage: DirectKafkaWordCount | is a list of one or more Kafka brokers + | is a consumer group name to consume from topics | is a list of one or more kafka topics to consume from | """.stripMargin) @@ -46,7 +51,7 @@ object DirectKafkaWordCount { StreamingExamples.setStreamingLogLevels() - val Array(brokers, topics) = args + val Array(brokers, groupId, topics) = args // Create context with 2 second batch interval val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount") @@ -54,7 +59,11 @@ object DirectKafkaWordCount { // Create direct kafka stream with brokers and topics val topicsSet = topics.split(",").toSet - val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers) + val kafkaParams = Map[String, Object]( + ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG -> brokers, + ConsumerConfig.GROUP_ID_CONFIG -> groupId, + ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG -> classOf[StringDeserializer], + ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[StringDeserializer]) val messages = KafkaUtils.createDirectStream[String, String]( ssc, LocationStrategies.PreferConsistent, diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 215b7cab703fb..c3221481556f5 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -190,8 +190,20 @@ private[spark] class DirectKafkaInputDStream[K, V]( // make sure new partitions are reflected in currentOffsets val newPartitions = parts.diff(currentOffsets.keySet) + + // Check if there's any partition been revoked because of consumer rebalance. + val revokedPartitions = currentOffsets.keySet.diff(parts) + if (revokedPartitions.nonEmpty) { + throw new IllegalStateException(s"Previously tracked partitions " + + s"${revokedPartitions.mkString("[", ",", "]")} been revoked by Kafka because of consumer " + + s"rebalance. This is mostly due to another stream with same group id joined, " + + s"please check if there're different streaming application misconfigure to use same " + + s"group id. Fundamentally different stream should use different group id") + } + // position for new partitions determined by auto.offset.reset if no commit currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap + // don't want to consume messages, so pause c.pause(newPartitions.asJava) // find latest available offsets From 1e3b8762a854a07c317f69fba7fa1a7bcdc58ff3 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Wed, 18 Apr 2018 10:36:41 +0800 Subject: [PATCH 638/774] [SPARK-21479][SQL] Outer join filter pushdown in null supplying table when condition is on one of the joined columns ## What changes were proposed in this pull request? Added `TransitPredicateInOuterJoin` optimization rule that transits constraints from the preserved side of an outer join to the null-supplying side. The constraints of the join operator will remain unchanged. ## How was this patch tested? Added 3 tests in `InferFiltersFromConstraintsSuite`. Author: maryannxue Closes #20816 from maryannxue/spark-21479. --- .../sql/catalyst/optimizer/Optimizer.scala | 42 +++++++++++++++++-- .../plans/logical/QueryPlanConstraints.scala | 25 +++++++++-- .../InferFiltersFromConstraintsSuite.scala | 36 ++++++++++++++++ 3 files changed, 96 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5fb59ef350b8b..913354e4df0e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -637,8 +637,11 @@ object CollapseWindow extends Rule[LogicalPlan] { * constraints. These filters are currently inserted to the existing conditions in the Filter * operators and on either side of Join operators. * - * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and - * LeftSemi joins. + * In addition, for left/right outer joins, infer predicate from the preserved side of the Join + * operator and push the inferred filter over to the null-supplying side. For example, if the + * preserved side has constraints of the form 'a > 5' and the join condition is 'a = b', in + * which 'b' is an attribute from the null-supplying side, a [[Filter]] operator of 'b > 5' will + * be applied to the null-supplying side. */ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { @@ -671,11 +674,42 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe val newConditionOpt = conditionOpt match { case Some(condition) => val newFilters = additionalConstraints -- splitConjunctivePredicates(condition) - if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else None + if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else conditionOpt case None => additionalConstraints.reduceOption(And) } - if (newConditionOpt.isDefined) Join(left, right, joinType, newConditionOpt) else join + // Infer filter for left/right outer joins + val newLeftOpt = joinType match { + case RightOuter if newConditionOpt.isDefined => + val inferredConstraints = left.getRelevantConstraints( + left.constraints + .union(right.constraints) + .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) + val newFilters = inferredConstraints + .filterNot(left.constraints.contains) + .reduceLeftOption(And) + newFilters.map(Filter(_, left)) + case _ => None + } + val newRightOpt = joinType match { + case LeftOuter if newConditionOpt.isDefined => + val inferredConstraints = right.getRelevantConstraints( + right.constraints + .union(left.constraints) + .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) + val newFilters = inferredConstraints + .filterNot(right.constraints.contains) + .reduceLeftOption(And) + newFilters.map(Filter(_, right)) + case _ => None + } + + if ((newConditionOpt.isDefined && (newConditionOpt ne conditionOpt)) + || newLeftOpt.isDefined || newRightOpt.isDefined) { + Join(newLeftOpt.getOrElse(left), newRightOpt.getOrElse(right), joinType, newConditionOpt) + } else { + join + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 046848875548b..a29f3d29236c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -41,9 +41,7 @@ trait QueryPlanConstraints { self: LogicalPlan => * example, if this set contains the expression `a = 2` then that expression is guaranteed to * evaluate to `true` for all rows produced. */ - lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter { c => - c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic - }) + lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter(selfReferenceOnly)) /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints @@ -55,6 +53,23 @@ trait QueryPlanConstraints { self: LogicalPlan => */ protected def validConstraints: Set[Expression] = Set.empty + /** + * Returns an [[ExpressionSet]] that contains an additional set of constraints, such as + * equality constraints and `isNotNull` constraints, etc., and that only contains references + * to this [[LogicalPlan]] node. + */ + def getRelevantConstraints(constraints: Set[Expression]): ExpressionSet = { + val allRelevantConstraints = + if (conf.constraintPropagationEnabled) { + constraints + .union(inferAdditionalConstraints(constraints)) + .union(constructIsNotNullConstraints(constraints)) + } else { + constraints + } + ExpressionSet(allRelevantConstraints.filter(selfReferenceOnly)) + } + /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this @@ -120,4 +135,8 @@ trait QueryPlanConstraints { self: LogicalPlan => destination: Attribute): Set[Expression] = constraints.map(_ transform { case e: Expression if e.semanticEquals(source) => destination }) + + private def selfReferenceOnly(e: Expression): Boolean = { + e.references.nonEmpty && e.references.subsetOf(outputSet) && e.deterministic + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index f78c2356e35a5..e068f51044589 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -204,4 +204,40 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } + + test("SPARK-21479: Outer join after-join filters push down to null-supplying side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, LeftOuter, condition).where("x.a".attr === 2).analyze + val left = x.where(IsNotNull('a) && 'a === 2) + val right = y.where(IsNotNull('a) && 'a === 2) + val correctAnswer = left.join(right, LeftOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-21479: Outer join pre-existing filters push down to null-supplying side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y.where("y.a".attr > 5), RightOuter, condition).analyze + val left = x.where(IsNotNull('a) && 'a > 5) + val right = y.where(IsNotNull('a) && 'a > 5) + val correctAnswer = left.join(right, RightOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-21479: Outer join no filter push down to preserved side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y.where("y.a".attr === 1), LeftOuter, condition).analyze + val left = x + val right = y.where(IsNotNull('a) && 'a === 1) + val correctAnswer = left.join(right, LeftOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } } From 310a8cd06299e434d94a1e391a6eb62944112446 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Apr 2018 11:51:10 +0800 Subject: [PATCH 639/774] [SPARK-23341][SQL] define some standard options for data source v2 ## What changes were proposed in this pull request? Each data source implementation can define its own options and teach its users how to set them. Spark doesn't have any restrictions about what options a data source should or should not have. It's possible that some options are very common and many data sources use them. However different data sources may define the common options(key and meaning) differently, which is quite confusing to end users. This PR defines some standard options that data sources can optionally adopt: path, table and database. ## How was this patch tested? a new test case. Author: Wenchen Fan Closes #20535 from cloud-fan/options. --- .../sql/sources/v2/DataSourceOptions.java | 100 ++++++++++++++++++ .../apache/spark/sql/DataFrameReader.scala | 14 ++- .../sources/v2/DataSourceOptionsSuite.scala | 25 +++++ 3 files changed, 135 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java index c32053580f016..83df3be747085 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java @@ -17,16 +17,61 @@ package org.apache.spark.sql.sources.v2; +import java.io.IOException; import java.util.HashMap; import java.util.Locale; import java.util.Map; import java.util.Optional; +import java.util.stream.Stream; + +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.spark.annotation.InterfaceStability; /** * An immutable string-to-string map in which keys are case-insensitive. This is used to represent * data source options. + * + * Each data source implementation can define its own options and teach its users how to set them. + * Spark doesn't have any restrictions about what options a data source should or should not have. + * Instead Spark defines some standard options that data sources can optionally adopt. It's possible + * that some options are very common and many data sources use them. However different data + * sources may define the common options(key and meaning) differently, which is quite confusing to + * end users. + * + * The standard options defined by Spark: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
    Option keyOption value
    pathA path string of the data files/directories, like + * path1, /absolute/file2, path3/*. The path can + * either be relative or absolute, points to either file or directory, and can contain + * wildcards. This option is commonly used by file-based data sources.
    pathsA JSON array style paths string of the data files/directories, like + * ["path1", "/absolute/file2"]. The format of each path is same as the + * path option, plus it should follow JSON string literal format, e.g. quotes + * should be escaped, pa\"th means pa"th. + *
    tableA table name string representing the table name directly without any interpretation. + * For example, db.tbl means a table called db.tbl, not a table called tbl + * inside database db. `t*b.l` means a table called `t*b.l`, not t*b.l.
    databaseA database name string representing the database name directly without any + * interpretation, which is very similar to the table name option.
    */ @InterfaceStability.Evolving public class DataSourceOptions { @@ -97,4 +142,59 @@ public double getDouble(String key, double defaultValue) { return keyLowerCasedMap.containsKey(lcaseKey) ? Double.parseDouble(keyLowerCasedMap.get(lcaseKey)) : defaultValue; } + + /** + * The option key for singular path. + */ + public static final String PATH_KEY = "path"; + + /** + * The option key for multiple paths. + */ + public static final String PATHS_KEY = "paths"; + + /** + * The option key for table name. + */ + public static final String TABLE_KEY = "table"; + + /** + * The option key for database name. + */ + public static final String DATABASE_KEY = "database"; + + /** + * Returns all the paths specified by both the singular path option and the multiple + * paths option. + */ + public String[] paths() { + String[] singularPath = + get(PATH_KEY).map(s -> new String[]{s}).orElseGet(() -> new String[0]); + Optional pathsStr = get(PATHS_KEY); + if (pathsStr.isPresent()) { + ObjectMapper objectMapper = new ObjectMapper(); + try { + String[] paths = objectMapper.readValue(pathsStr.get(), String[].class); + return Stream.of(singularPath, paths).flatMap(Stream::of).toArray(String[]::new); + } catch (IOException e) { + return singularPath; + } + } else { + return singularPath; + } + } + + /** + * Returns the value of the table name option. + */ + public Optional tableName() { + return get(TABLE_KEY); + } + + /** + * Returns the value of the database name option. + */ + public Optional databaseName() { + return get(DATABASE_KEY); + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ae3ba1690f696..d640fdc530ce2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -21,6 +21,8 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ +import com.fasterxml.jackson.databind.ObjectMapper + import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.java.JavaRDD @@ -34,7 +36,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2.{DataSourceV2, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -171,7 +173,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 1.4.0 */ def load(path: String): DataFrame = { - option("path", path).load(Seq.empty: _*) // force invocation of `load(...varargs...)` + // force invocation of `load(...varargs...)` + option(DataSourceOptions.PATH_KEY, path).load(Seq.empty: _*) } /** @@ -193,10 +196,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( ds = ds, conf = sparkSession.sessionState.conf) + val pathsOption = { + val objectMapper = new ObjectMapper() + DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) + } Dataset.ofRows(sparkSession, DataSourceV2Relation.create( - ds, extraOptions.toMap ++ sessionOptions, + ds, extraOptions.toMap ++ sessionOptions + pathsOption, userSpecifiedSchema = userSpecifiedSchema)) - } else { loadV1Source(paths: _*) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala index 31dfc55b23361..cfa69a86de1a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala @@ -79,4 +79,29 @@ class DataSourceOptionsSuite extends SparkFunSuite { options.getDouble("foo", 0.1d) } } + + test("standard options") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATH_KEY -> "abc", + DataSourceOptions.TABLE_KEY -> "tbl").asJava) + + assert(options.paths().toSeq == Seq("abc")) + assert(options.tableName().get() == "tbl") + assert(!options.databaseName().isPresent) + } + + test("standard options with both singular path and multi-paths") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATH_KEY -> "abc", + DataSourceOptions.PATHS_KEY -> """["c", "d"]""").asJava) + + assert(options.paths().toSeq == Seq("abc", "c", "d")) + } + + test("standard options with only multi-paths") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATHS_KEY -> """["c", "d\"e"]""").asJava) + + assert(options.paths().toSeq == Seq("c", "d\"e")) + } } From cce469435d61bda5893d9aa6cfdf7ea46fa717df Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 17 Apr 2018 21:03:57 -0700 Subject: [PATCH 640/774] [SPARK-24002][SQL] Task not serializable caused by org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.getBytes ## What changes were proposed in this pull request? ``` Py4JJavaError: An error occurred while calling o153.sql. : org.apache.spark.SparkException: Job aborted. at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:223) at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:189) at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:70) at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:68) at org.apache.spark.sql.execution.command.ExecutedCommandExec.executeCollect(commands.scala:79) at org.apache.spark.sql.Dataset$$anonfun$6.apply(Dataset.scala:190) at org.apache.spark.sql.Dataset$$anonfun$6.apply(Dataset.scala:190) at org.apache.spark.sql.Dataset$$anonfun$59.apply(Dataset.scala:3021) at org.apache.spark.sql.execution.SQLExecution$.withCustomExecutionEnv(SQLExecution.scala:89) at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:127) at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3020) at org.apache.spark.sql.Dataset.(Dataset.scala:190) at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:74) at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:646) at sun.reflect.GeneratedMethodAccessor153.invoke(Unknown Source) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244) at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380) at py4j.Gateway.invoke(Gateway.java:293) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:226) at java.lang.Thread.run(Thread.java:748) Caused by: org.apache.spark.SparkException: Exception thrown in Future.get: at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:190) at org.apache.spark.sql.execution.InputAdapter.doExecuteBroadcast(WholeStageCodegenExec.scala:267) at org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec.doConsume(BroadcastNestedLoopJoinExec.scala:530) at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:155) at org.apache.spark.sql.execution.ProjectExec.consume(basicPhysicalOperators.scala:37) at org.apache.spark.sql.execution.ProjectExec.doConsume(basicPhysicalOperators.scala:69) at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:155) at org.apache.spark.sql.execution.FilterExec.consume(basicPhysicalOperators.scala:144) ... at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:190) ... 23 more Caused by: java.util.concurrent.ExecutionException: org.apache.spark.SparkException: Task not serializable at java.util.concurrent.FutureTask.report(FutureTask.java:122) at java.util.concurrent.FutureTask.get(FutureTask.java:206) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:179) ... 276 more Caused by: org.apache.spark.SparkException: Task not serializable at org.apache.spark.util.ClosureCleaner$.ensureSerializable(ClosureCleaner.scala:340) at org.apache.spark.util.ClosureCleaner$.org$apache$spark$util$ClosureCleaner$$clean(ClosureCleaner.scala:330) at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:156) at org.apache.spark.SparkContext.clean(SparkContext.scala:2380) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1.apply(RDD.scala:850) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1.apply(RDD.scala:849) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112) at org.apache.spark.rdd.RDD.withScope(RDD.scala:371) at org.apache.spark.rdd.RDD.mapPartitionsWithIndex(RDD.scala:849) at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:417) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:123) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:118) at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$3.apply(SparkPlan.scala:152) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:149) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:118) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.prepareShuffleDependency(ShuffleExchangeExec.scala:89) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:125) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:116) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.doExecute(ShuffleExchangeExec.scala:116) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:123) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:118) at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$3.apply(SparkPlan.scala:152) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:149) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:118) at org.apache.spark.sql.execution.InputAdapter.inputRDDs(WholeStageCodegenExec.scala:271) at org.apache.spark.sql.execution.aggregate.HashAggregateExec.inputRDDs(HashAggregateExec.scala:181) at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:414) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:123) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:118) at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$3.apply(SparkPlan.scala:152) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:149) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:118) at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:61) at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:70) at org.apache.spark.sql.execution.SparkPlan.executeCollectResult(SparkPlan.scala:264) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1$$anonfun$call$1.apply(BroadcastExchangeExec.scala:93) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1$$anonfun$call$1.apply(BroadcastExchangeExec.scala:81) at org.apache.spark.sql.execution.SQLExecution$.withExecutionId(SQLExecution.scala:150) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1.call(BroadcastExchangeExec.scala:80) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1.call(BroadcastExchangeExec.scala:76) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) ... 1 more Caused by: java.nio.BufferUnderflowException at java.nio.HeapByteBuffer.get(HeapByteBuffer.java:151) at java.nio.ByteBuffer.get(ByteBuffer.java:715) at org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.getBytes(Binary.java:405) at org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.getBytesUnsafe(Binary.java:414) at org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.writeObject(Binary.java:484) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at java.io.ObjectStreamClass.invokeWriteObject(ObjectStreamClass.java:1128) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1496) ``` The Parquet filters are serializable but not thread safe. SparkPlan.prepare() could be called in different threads (BroadcastExchange will call it in a thread pool). Thus, we could serialize the same Parquet filter at the same time. This is not easily reproduced. The fix is to avoid serializing these Parquet filters in the driver. This PR is to avoid serializing these Parquet filters by moving the parquet filter generation from the driver to executors. ## How was this patch tested? Having two queries one is a 1000-line SQL query and a 3000-line SQL query. Need to run at least one hour with a heavy write workload to reproduce once. Author: gatorsmile Closes #21086 from gatorsmile/taskNotSerializable. --- .../parquet/ParquetFileFormat.scala | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 476bd02374364..d8f47eec952de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -321,19 +321,6 @@ class ParquetFileFormat SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, sparkSession.sessionState.conf.isParquetINT96AsTimestamp) - // Try to push down filters when filter push-down is enabled. - val pushed = - if (sparkSession.sessionState.conf.parquetFilterPushDown) { - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(ParquetFilters.createFilter(requiredSchema, _)) - .reduceOption(FilterApi.and) - } else { - None - } - val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -351,12 +338,26 @@ class ParquetFileFormat val timestampConversion: Boolean = sparkSession.sessionState.conf.isParquetINT96TimestampConversion val capacity = sqlConf.parquetVectorizedReaderBatchSize + val enableParquetFilterPushDown: Boolean = + sparkSession.sessionState.conf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(ParquetFilters.createFilter(requiredSchema, _)) + .reduceOption(FilterApi.and) + } else { + None + } + val fileSplit = new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) From f81fa478ff990146e2a8e463ac252271448d96f5 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Wed, 18 Apr 2018 18:41:55 +0900 Subject: [PATCH 641/774] [SPARK-23926][SQL] Extending reverse function to support ArrayType arguments ## What changes were proposed in this pull request? This PR extends `reverse` functions to be able to operate over array columns and covers: - Introduction of `Reverse` expression that represents logic for reversing arrays and also strings - Removal of `StringReverse` expression - A wrapper for PySpark ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite ## Codegen examples ### Primitive type ``` val df = Seq( Seq(1, 3, 4, 2), null ).toDF("i") df.filter($"i".isNotNull || $"i".isNull).select(reverse($"i")).debugCodegen ``` Result: ``` /* 032 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 033 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 034 */ null : (inputadapter_row.getArray(0)); /* 035 */ /* 036 */ boolean filter_value = true; /* 037 */ /* 038 */ if (!(!inputadapter_isNull)) { /* 039 */ filter_value = inputadapter_isNull; /* 040 */ } /* 041 */ if (!filter_value) continue; /* 042 */ /* 043 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 044 */ /* 045 */ boolean project_isNull = inputadapter_isNull; /* 046 */ ArrayData project_value = null; /* 047 */ /* 048 */ if (!inputadapter_isNull) { /* 049 */ final int project_length = inputadapter_value.numElements(); /* 050 */ project_value = inputadapter_value.copy(); /* 051 */ for(int k = 0; k < project_length / 2; k++) { /* 052 */ int l = project_length - k - 1; /* 053 */ boolean isNullAtK = project_value.isNullAt(k); /* 054 */ boolean isNullAtL = project_value.isNullAt(l); /* 055 */ if(!isNullAtK) { /* 056 */ int el = project_value.getInt(k); /* 057 */ if(!isNullAtL) { /* 058 */ project_value.setInt(k, project_value.getInt(l)); /* 059 */ } else { /* 060 */ project_value.setNullAt(k); /* 061 */ } /* 062 */ project_value.setInt(l, el); /* 063 */ } else if (!isNullAtL) { /* 064 */ project_value.setInt(k, project_value.getInt(l)); /* 065 */ project_value.setNullAt(l); /* 066 */ } /* 067 */ } /* 068 */ /* 069 */ } ``` ### Non-primitive type ``` val df = Seq( Seq("a", "c", "d", "b"), null ).toDF("s") df.filter($"s".isNotNull || $"s".isNull).select(reverse($"s")).debugCodegen ``` Result: ``` /* 032 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 033 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 034 */ null : (inputadapter_row.getArray(0)); /* 035 */ /* 036 */ boolean filter_value = true; /* 037 */ /* 038 */ if (!(!inputadapter_isNull)) { /* 039 */ filter_value = inputadapter_isNull; /* 040 */ } /* 041 */ if (!filter_value) continue; /* 042 */ /* 043 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 044 */ /* 045 */ boolean project_isNull = inputadapter_isNull; /* 046 */ ArrayData project_value = null; /* 047 */ /* 048 */ if (!inputadapter_isNull) { /* 049 */ final int project_length = inputadapter_value.numElements(); /* 050 */ project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(new Object[project_length]); /* 051 */ for(int k = 0; k < project_length; k++) { /* 052 */ int l = project_length - k - 1; /* 053 */ project_value.update(k, inputadapter_value.getUTF8String(l)); /* 054 */ } /* 055 */ /* 056 */ } ``` Author: mn-mikke Closes #21034 from mn-mikke/feature/array-api-reverse-to-master. --- python/pyspark/sql/functions.py | 20 +++- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/collectionOperations.scala | 88 +++++++++++++++++ .../expressions/stringExpressions.scala | 20 ---- .../CollectionExpressionsSuite.scala | 44 +++++++++ .../expressions/StringExpressionsSuite.scala | 6 +- .../org/apache/spark/sql/functions.scala | 15 ++- .../spark/sql/DataFrameFunctionsSuite.scala | 94 +++++++++++++++++++ 8 files changed, 256 insertions(+), 33 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6ca22b610843d..d3bb0a5d6b36a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1414,7 +1414,6 @@ def hash(*cols): 'uppercase. Words are delimited by whitespace.', 'lower': 'Converts a string column to lower case.', 'upper': 'Converts a string column to upper case.', - 'reverse': 'Reverses the string column and returns it as a new string column.', 'ltrim': 'Trim the spaces from left end for the specified string value.', 'rtrim': 'Trim the spaces from right end for the specified string value.', 'trim': 'Trim the spaces from both ends for the specified string column.', @@ -2128,6 +2127,25 @@ def sort_array(col, asc=True): return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(1.5) +@ignore_unicode_prefix +def reverse(col): + """ + Collection function: returns a reversed string or an array with reverse order of elements. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([('Spark SQL',)], ['data']) + >>> df.select(reverse(df.data).alias('s')).collect() + [Row(s=u'LQS krapS')] + >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data']) + >>> df.select(reverse(df.data).alias('r')).collect() + [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.reverse(_to_java_column(col))) + + @since(2.3) def map_keys(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 4dd1ca509bf2c..38c874ad948e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -336,7 +336,6 @@ object FunctionRegistry { expression[RegExpReplace]("regexp_replace"), expression[StringRepeat]("repeat"), expression[StringReplace]("replace"), - expression[StringReverse]("reverse"), expression[RLike]("rlike"), expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), @@ -411,6 +410,7 @@ object FunctionRegistry { expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), + expression[Reverse]("reverse"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7c87777eed47a..76b71f5b86074 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Given an array or map, returns its size. Returns -1 if null. @@ -212,6 +213,93 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } +/** + * Returns a reversed string or an array with reverse order of elements. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.", + examples = """ + Examples: + > SELECT _FUNC_('Spark SQL'); + LQS krapS + > SELECT _FUNC_(array(2, 1, 4, 3)); + [3, 4, 1, 2] + """, + since = "1.5.0", + note = "Reverse logic for arrays is available since 2.4.0." +) +case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + // Input types are utilized by type coercion in ImplicitTypeCasts. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) + + override def dataType: DataType = child.dataType + + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(input: Any): Any = input match { + case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) + case s: UTF8String => s.reverse() + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => dataType match { + case _: StringType => stringCodeGen(ev, c) + case _: ArrayType => arrayCodeGen(ctx, ev, c) + }) + } + + private def stringCodeGen(ev: ExprCode, childName: String): String = { + s"${ev.value} = ($childName).reverse();" + } + + private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { + val length = ctx.freshName("length") + val javaElementType = CodeGenerator.javaType(elementType) + val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + + val initialization = if (isPrimitiveType) { + s"$childName.copy()" + } else { + s"new ${classOf[GenericArrayData].getName()}(new Object[$length])" + } + + val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length + + val swapAssigments = if (isPrimitiveType) { + val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType) + val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index) + s"""|boolean isNullAtK = ${ev.value}.isNullAt(k); + |boolean isNullAtL = ${ev.value}.isNullAt(l); + |if(!isNullAtK) { + | $javaElementType el = ${getCall("k")}; + | if(!isNullAtL) { + | ${ev.value}.$setFunc(k, ${getCall("l")}); + | } else { + | ${ev.value}.setNullAt(k); + | } + | ${ev.value}.$setFunc(l, el); + |} else if (!isNullAtL) { + | ${ev.value}.$setFunc(k, ${getCall("l")}); + | ${ev.value}.setNullAt(l); + |}""".stripMargin + } else { + s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});" + } + + s""" + |final int $length = $childName.numElements(); + |${ev.value} = $initialization; + |for(int k = 0; k < $numberOfIterations; k++) { + | int l = $length - k - 1; + | $swapAssigments + |} + """.stripMargin + } + + override def prettyName: String = "reverse" +} + /** * Checks if the array (left) has the element (right) */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 22fbb8998ed89..5a02ca0d6862c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1504,26 +1504,6 @@ case class StringRepeat(str: Expression, times: Expression) } } -/** - * Returns the reversed given string. - */ -@ExpressionDescription( - usage = "_FUNC_(str) - Returns the reversed given string.", - examples = """ - Examples: - > SELECT _FUNC_('Spark SQL'); - LQS krapS - """) -case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { - override def convert(v: UTF8String): UTF8String = v.reverse() - - override def prettyName: String = "reverse" - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).reverse()") - } -} - /** * Returns a string consisting of n spaces. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 5a31e3a30edd6..517639dbc7232 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -125,4 +125,48 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) } + + test("Reverse") { + // Primitive-type elements + val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType)) + val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType)) + val ai5 = Literal.create(Seq(1), ArrayType(IntegerType)) + val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType)) + val ai7 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(Reverse(ai0), Seq(3, 4, 1, 2)) + checkEvaluation(Reverse(ai1), Seq(3, 1, 2)) + checkEvaluation(Reverse(ai2), Seq(3, null, 1, null)) + checkEvaluation(Reverse(ai3), Seq(null, 4, null, 2)) + checkEvaluation(Reverse(ai4), Seq(null, null, null)) + checkEvaluation(Reverse(ai5), Seq(1)) + checkEvaluation(Reverse(ai6), Seq.empty) + checkEvaluation(Reverse(ai7), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("b", "a", "d", "c"), ArrayType(StringType)) + val as1 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType)) + val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType)) + val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType)) + val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType)) + val as5 = Literal.create(Seq("a"), ArrayType(StringType)) + val as6 = Literal.create(Seq.empty, ArrayType(StringType)) + val as7 = Literal.create(null, ArrayType(StringType)) + val aa = Literal.create( + Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(Reverse(as0), Seq("c", "d", "a", "b")) + checkEvaluation(Reverse(as1), Seq("c", "a", "b")) + checkEvaluation(Reverse(as2), Seq("c", null, "a", null)) + checkEvaluation(Reverse(as3), Seq(null, "d", null, "b")) + checkEvaluation(Reverse(as4), Seq(null, null, null)) + checkEvaluation(Reverse(as5), Seq("a")) + checkEvaluation(Reverse(as6), Seq.empty) + checkEvaluation(Reverse(as7), null) + checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b"))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 9a1a4da074ce3..f1a6f9b8889fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -629,9 +629,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("REVERSE") { val s = 'a.string.at(0) val row1 = create_row("abccc") - checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1) - checkEvaluation(StringReverse(s), "cccba", row1) - checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1) + checkEvaluation(Reverse(Literal("abccc")), "cccba", row1) + checkEvaluation(Reverse(s), "cccba", row1) + checkEvaluation(Reverse(Literal.create(null, StringType)), null, row1) } test("SPACE") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 642ac056bb809..a55a800f48245 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2464,14 +2464,6 @@ object functions { StringRepeat(str.expr, lit(n).expr) } - /** - * Reverses the string column and returns it as a new string column. - * - * @group string_funcs - * @since 1.5.0 - */ - def reverse(str: Column): Column = withExpr { StringReverse(str.expr) } - /** * Trim the spaces from right end for the specified string value. * @@ -3316,6 +3308,13 @@ object functions { */ def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + /** + * Returns a reversed string or an array with reverse order of elements. + * @group collection_funcs + * @since 1.5.0 + */ + def reverse(e: Column): Column = withExpr { Reverse(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 636e86baedf6f..74c42f2599dca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -441,6 +441,100 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("array_max(a)"), answer) } + test("reverse function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on + + // String test cases + val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i") + + checkAnswer( + oneRowDF.select(reverse('s)), + Seq(Row("krapS")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(s)"), + Seq(Row("krapS")) + ) + checkAnswer( + oneRowDF.select(reverse('i)), + Seq(Row("5123")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(i)"), + Seq(Row("5123")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(null)"), + Seq(Row(null)) + ) + + // Array test cases (primitive-type elements) + val idf = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + checkAnswer( + idf.select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.filter(dummyFilter('i)).select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("reverse(i)"), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(array(1, null, 2, null))"), + Seq(Row(Seq(null, 2, null, 1))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"), + Seq(Row(Seq(null, 2, null, 1))) + ) + + // Array test cases (non-primitive-type elements) + val sdf = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + checkAnswer( + sdf.select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.filter(dummyFilter('s)).select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.selectExpr("reverse(s)"), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"), + Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"), + Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.selectExpr("reverse(struct(1, 'a'))") + } + intercept[AnalysisException] { + oneRowDF.selectExpr("reverse(map(1, 'a'))") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From f09a9e9418c1697d198de18f340b1288f5eb025c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 18 Apr 2018 08:22:05 -0700 Subject: [PATCH 642/774] [SPARK-24007][SQL] EqualNullSafe for FloatType and DoubleType might generate a wrong result by codegen. ## What changes were proposed in this pull request? `EqualNullSafe` for `FloatType` and `DoubleType` might generate a wrong result by codegen. ```scala scala> val df = Seq((Some(-1.0d), None), (None, Some(-1.0d))).toDF() df: org.apache.spark.sql.DataFrame = [_1: double, _2: double] scala> df.show() +----+----+ | _1| _2| +----+----+ |-1.0|null| |null|-1.0| +----+----+ scala> df.filter("_1 <=> _2").show() +----+----+ | _1| _2| +----+----+ |-1.0|null| |null|-1.0| +----+----+ ``` The result should be empty but the result remains two rows. ## How was this patch tested? Added a test. Author: Takuya UESHIN Closes #21094 from ueshin/issues/SPARK-24007/equalnullsafe. --- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 6 ++++-- .../spark/sql/catalyst/expressions/PredicateSuite.scala | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f6b6775923ac6..cf0a91ff00626 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -582,8 +582,10 @@ class CodegenContext { */ def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match { case BinaryType => s"java.util.Arrays.equals($c1, $c2)" - case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" - case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" + case FloatType => + s"((java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2)" + case DoubleType => + s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)" case array: ArrayType => genComp(array, c1, c2) + " == 0" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 8a8f8e10225fa..1bfd180ae4393 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -442,4 +442,11 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } + + test("SPARK-24007: EqualNullSafe for FloatType and DoubleType might generate a wrong result") { + checkEvaluation(EqualNullSafe(Literal(null, FloatType), Literal(-1.0f)), false) + checkEvaluation(EqualNullSafe(Literal(-1.0f), Literal(null, FloatType)), false) + checkEvaluation(EqualNullSafe(Literal(null, DoubleType), Literal(-1.0d)), false) + checkEvaluation(EqualNullSafe(Literal(-1.0d), Literal(null, DoubleType)), false) + } } From a9066478f6d98c3ae634c3bb9b09ee20bd60e111 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 19 Apr 2018 00:05:47 +0200 Subject: [PATCH 643/774] [SPARK-23875][SQL][FOLLOWUP] Add IndexedSeq wrapper for ArrayData ## What changes were proposed in this pull request? Use specified accessor in `ArrayData.foreach` and `toArray`. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #21099 from viirya/SPARK-23875-followup. --- .../org/apache/spark/sql/catalyst/util/ArrayData.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 2cf59d567c08c..104b428614849 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -141,28 +141,29 @@ abstract class ArrayData extends SpecializedGetters with Serializable { def toArray[T: ClassTag](elementType: DataType): Array[T] = { val size = numElements() + val accessor = InternalRow.getAccessor(elementType) val values = new Array[T](size) var i = 0 while (i < size) { if (isNullAt(i)) { values(i) = null.asInstanceOf[T] } else { - values(i) = get(i, elementType).asInstanceOf[T] + values(i) = accessor(this, i).asInstanceOf[T] } i += 1 } values } - // todo: specialize this. def foreach(elementType: DataType, f: (Int, Any) => Unit): Unit = { val size = numElements() + val accessor = InternalRow.getAccessor(elementType) var i = 0 while (i < size) { if (isNullAt(i)) { f(i, null) } else { - f(i, get(i, elementType)) + f(i, accessor(this, i)) } i += 1 } From 0c94e48bc50717e1627c0d2acd5382d9adc73c97 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 18 Apr 2018 16:37:41 -0700 Subject: [PATCH 644/774] [SPARK-23775][TEST] Make DataFrameRangeSuite not flaky ## What changes were proposed in this pull request? DataFrameRangeSuite.test("Cancelling stage in a query with Range.") stays sometimes in an infinite loop and times out the build. There were multiple issues with the test: 1. The first valid stageId is zero when the test started alone and not in a suite and the following code waits until timeout: ``` eventually(timeout(10.seconds), interval(1.millis)) { assert(DataFrameRangeSuite.stageToKill > 0) } ``` 2. The `DataFrameRangeSuite.stageToKill` was overwritten by the task's thread after the reset which ended up in canceling the same stage 2 times. This caused the infinite wait. This PR solves this mentioned flakyness by removing the shared `DataFrameRangeSuite.stageToKill` and using `wait` and `CountDownLatch` for synhronization. ## How was this patch tested? Existing unit test. Author: Gabor Somogyi Closes #20888 from gaborgsomogyi/SPARK-23775. --- .../spark/sql/DataFrameRangeSuite.scala | 78 +++++++++++-------- 1 file changed, 45 insertions(+), 33 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 57a930dfaf320..a0fd74088ce8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.sql +import java.util.concurrent.{CountDownLatch, TimeUnit} + import scala.concurrent.duration._ import scala.math.abs import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -152,39 +154,53 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } test("Cancelling stage in a query with Range.") { - val listener = new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - eventually(timeout(10.seconds), interval(1.millis)) { - assert(DataFrameRangeSuite.stageToKill > 0) + // Save and restore the value because SparkContext is shared + val savedInterruptOnCancel = sparkContext + .getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) + + try { + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") + + for (codegen <- Seq(true, false)) { + // This countdown latch used to make sure with all the stages cancelStage called in listener + val latch = new CountDownLatch(2) + + val listener = new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sparkContext.cancelStage(taskStart.stageId) + latch.countDown() + } } - sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) - } - } - sparkContext.addSparkListener(listener) - for (codegen <- Seq(true, false)) { - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - DataFrameRangeSuite.stageToKill = -1 - val ex = intercept[SparkException] { - spark.range(0, 100000000000L, 1, 1).map { x => - DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() - x - }.toDF("id").agg(sum("id")).collect() + sparkContext.addSparkListener(listener) + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { + val ex = intercept[SparkException] { + sparkContext.range(0, 10000L, numSlices = 10).mapPartitions { x => + x.synchronized { + x.wait() + } + x + }.toDF("id").agg(sum("id")).collect() + } + ex.getCause() match { + case null => + assert(ex.getMessage().contains("cancelled")) + case cause: SparkException => + assert(cause.getMessage().contains("cancelled")) + case cause: Throwable => + fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") + } } - ex.getCause() match { - case null => - assert(ex.getMessage().contains("cancelled")) - case cause: SparkException => - assert(cause.getMessage().contains("cancelled")) - case cause: Throwable => - fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") + latch.await(20, TimeUnit.SECONDS) + eventually(timeout(20.seconds)) { + assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } + sparkContext.removeSparkListener(listener) } - eventually(timeout(20.seconds)) { - assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) - } + } finally { + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, + savedInterruptOnCancel) } - sparkContext.removeSparkListener(listener) } test("SPARK-20430 Initialize Range parameters in a driver side") { @@ -204,7 +220,3 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } - -object DataFrameRangeSuite { - @volatile var stageToKill = -1 -} From 8bb0df2c65355dfdcd28e362ff661c6c7ebc99c0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 19 Apr 2018 10:00:57 +0800 Subject: [PATCH 645/774] [SPARK-24014][PYSPARK] Add onStreamingStarted method to StreamingListener ## What changes were proposed in this pull request? The `StreamingListener` in PySpark side seems to be lack of `onStreamingStarted` method. This patch adds it and a test for it. This patch also includes a trivial doc improvement for `createDirectStream`. Original PR is #21057. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21098 from viirya/SPARK-24014. --- python/pyspark/streaming/kafka.py | 3 ++- python/pyspark/streaming/listener.py | 6 ++++++ python/pyspark/streaming/tests.py | 7 +++++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index fdb9308604489..ed2e0e7d10fa2 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -104,7 +104,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, :param topics: list of topic_name to consume. :param kafkaParams: Additional params for Kafka. :param fromOffsets: Per-topic/partition Kafka offsets defining the (inclusive) starting - point of the stream. + point of the stream (a dictionary mapping `TopicAndPartition` to + integers). :param keyDecoder: A function used to decode key (default is utf8_decoder). :param valueDecoder: A function used to decode value (default is utf8_decoder). :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess diff --git a/python/pyspark/streaming/listener.py b/python/pyspark/streaming/listener.py index b830797f5c0a0..d4ecc215aea99 100644 --- a/python/pyspark/streaming/listener.py +++ b/python/pyspark/streaming/listener.py @@ -23,6 +23,12 @@ class StreamingListener(object): def __init__(self): pass + def onStreamingStarted(self, streamingStarted): + """ + Called when the streaming has been started. + """ + pass + def onReceiverStarted(self, receiverStarted): """ Called when a receiver has been started diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 7dde7c0928c08..103940923dd4d 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -507,6 +507,10 @@ def __init__(self): self.batchInfosCompleted = [] self.batchInfosStarted = [] self.batchInfosSubmitted = [] + self.streamingStartedTime = [] + + def onStreamingStarted(self, streamingStarted): + self.streamingStartedTime.append(streamingStarted.time) def onBatchSubmitted(self, batchSubmitted): self.batchInfosSubmitted.append(batchSubmitted.batchInfo()) @@ -530,9 +534,12 @@ def func(dstream): batchInfosSubmitted = batch_collector.batchInfosSubmitted batchInfosStarted = batch_collector.batchInfosStarted batchInfosCompleted = batch_collector.batchInfosCompleted + streamingStartedTime = batch_collector.streamingStartedTime self.wait_for(batchInfosCompleted, 4) + self.assertEqual(len(streamingStartedTime), 1) + self.assertGreaterEqual(len(batchInfosSubmitted), 4) for info in batchInfosSubmitted: self.assertGreaterEqual(info.batchTime().milliseconds(), 0) From d5bec48b9cb225c19b43935c07b24090c51cacce Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 11:59:17 +0900 Subject: [PATCH 646/774] [SPARK-23919][SQL] Add array_position function ## What changes were proposed in this pull request? The PR adds the SQL function `array_position`. The behavior of the function is based on Presto's one. The function returns the position of the first occurrence of the element in array x (or 0 if not found) using 1-based index as BigInt. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21037 from kiszk/SPARK-23919. --- python/pyspark/sql/functions.py | 17 ++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 56 +++++++++++++++++++ .../CollectionExpressionsSuite.scala | 22 ++++++++ .../org/apache/spark/sql/functions.scala | 14 +++++ .../spark/sql/DataFrameFunctionsSuite.scala | 34 +++++++++++ 6 files changed, 144 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d3bb0a5d6b36a..36dcabc6766d8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1845,6 +1845,23 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def array_position(col, value): + """ + Collection function: Locates the position of the first occurrence of the given value + in the given array. Returns null if either of the arguments are null. + + .. note:: The position is not zero based, but 1 based index. Returns 0 if the given + value could not be found in the array. + + >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) + >>> df.select(array_position(df.data, "a")).collect() + [Row(array_position(data, a)=3), Row(array_position(data, a)=0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 38c874ad948e1..74095fe697b6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -402,6 +402,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[MapKeys]("map_keys"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 76b71f5b86074..e6a05f535cb1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -505,3 +505,59 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def prettyName: String = "array_max" } + + +/** + * Returns the position of the first occurrence of element in the given array as long. + * Returns 0 if the given value could not be found in the array. Returns null if either of + * the arguments are null + * + * NOTE: that this is not zero based, but 1-based index. The first element in the array has + * index 1. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, element) - Returns the (1-based) index of the first element of the array as long. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(3, 2, 1), 1); + 3 + """, + since = "2.4.0") +case class ArrayPosition(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = LongType + override def inputTypes: Seq[AbstractDataType] = + Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + + override def nullSafeEval(arr: Any, value: Any): Any = { + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == value) { + return (i + 1).toLong + } + ) + 0L + } + + override def prettyName: String = "array_position" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val pos = ctx.freshName("arrayPosition") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, right.dataType, i) + s""" + |int $pos = 0; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) { + | $pos = $i + 1; + | break; + | } + |} + |${ev.value} = (long) $pos; + """.stripMargin + }) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 517639dbc7232..916cd3bb4cca5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -169,4 +169,26 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Reverse(as7), null) checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b"))) } + + test("Array Position") { + val a0 = Literal.create(Seq(1, null, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayPosition(a0, Literal(3)), 4L) + checkEvaluation(ArrayPosition(a0, Literal(1)), 1L) + checkEvaluation(ArrayPosition(a0, Literal(0)), 0L) + checkEvaluation(ArrayPosition(a0, Literal.create(null, IntegerType)), null) + + checkEvaluation(ArrayPosition(a1, Literal("")), 2L) + checkEvaluation(ArrayPosition(a1, Literal("a")), 0L) + checkEvaluation(ArrayPosition(a1, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayPosition(a2, Literal(1L)), 0L) + checkEvaluation(ArrayPosition(a2, Literal.create(null, LongType)), null) + + checkEvaluation(ArrayPosition(a3, Literal("")), null) + checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a55a800f48245..3a09ec4f1982e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3038,6 +3038,20 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Locates the position of the first occurrence of the value in the given array as long. + * Returns null if either of the arguments are null. + * + * @note The position is not zero based, but 1 based index. Returns 0 if value + * could not be found in array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_position(column: Column, value: Any): Column = withExpr { + ArrayPosition(column.expr, Literal(value)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 74c42f2599dca..13161e7e24cfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -535,6 +535,40 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("array position function") { + val df = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "x") + ).toDF("a", "b") + + checkAnswer( + df.select(array_position(df("a"), 1)), + Seq(Row(1L), Row(0L)) + ) + checkAnswer( + df.selectExpr("array_position(a, 1)"), + Seq(Row(1L), Row(0L)) + ) + + checkAnswer( + df.select(array_position(df("a"), null)), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("array_position(a, null)"), + Seq(Row(null), Row(null)) + ) + + checkAnswer( + df.selectExpr("array_position(array(array(1), null)[0], 1)"), + Seq(Row(1L), Row(1L)) + ) + checkAnswer( + df.selectExpr("array_position(array(1, null), array(1, null)[0])"), + Seq(Row(1L), Row(1L)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 46bb2b5129833cc5829089bf1174a76cb7b81741 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 21:00:10 +0900 Subject: [PATCH 647/774] [SPARK-23924][SQL] Add element_at function ## What changes were proposed in this pull request? The PR adds the SQL function `element_at`. The behavior of the function is based on Presto's one. This function returns element of array at given index in value if column is array, or returns value for the given key in value if column is map. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21053 from kiszk/SPARK-23924. --- python/pyspark/sql/functions.py | 24 ++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 104 ++++++++++++++++++ .../expressions/complexTypeExtractors.scala | 64 +++++++---- .../CollectionExpressionsSuite.scala | 48 ++++++++ .../org/apache/spark/sql/functions.scala | 11 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 48 ++++++++ 7 files changed, 276 insertions(+), 24 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 36dcabc6766d8..1be68f2a4a448 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1862,6 +1862,30 @@ def array_position(col, value): return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) +@ignore_unicode_prefix +@since(2.4) +def element_at(col, extraction): + """ + Collection function: Returns element of array at given index in extraction if col is array. + Returns value for the given key in extraction if col is map. + + :param col: name of column containing array or map + :param extraction: index to check for in array or key to check for in map + + .. note:: The position is not zero based, but 1 based index. + + >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df.select(element_at(df.data, 1)).collect() + [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)] + + >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data']) + >>> df.select(element_at(df.data, "a")).collect() + [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 74095fe697b6a..a44f2d5272b8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -405,6 +405,7 @@ object FunctionRegistry { expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), + expression[ElementAt]("element_at"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e6a05f535cb1c..dba426e999dda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -561,3 +561,107 @@ case class ArrayPosition(left: Expression, right: Expression) }) } } + +/** + * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, + accesses elements from the last to the first. Returns NULL if the index exceeds the length + of the array. + + _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 2); + 2 + > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2); + "b" + """, + since = "2.4.0") +case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + + override def dataType: DataType = left.dataType match { + case ArrayType(elementType, _) => elementType + case MapType(_, valueType, _) => valueType + } + + override def inputTypes: Seq[AbstractDataType] = { + Seq(TypeCollection(ArrayType, MapType), + left.dataType match { + case _: ArrayType => IntegerType + case _: MapType => left.dataType.asInstanceOf[MapType].keyType + } + ) + } + + override def nullable: Boolean = true + + override def nullSafeEval(value: Any, ordinal: Any): Any = { + left.dataType match { + case _: ArrayType => + val array = value.asInstanceOf[ArrayData] + val index = ordinal.asInstanceOf[Int] + if (array.numElements() < math.abs(index)) { + null + } else { + val idx = if (index == 0) { + throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") + } else if (index > 0) { + index - 1 + } else { + array.numElements() + index + } + if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { + null + } else { + array.get(idx, dataType) + } + } + case _: MapType => + getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + left.dataType match { + case _: ArrayType => + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("elementAtIndex") + val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($eval1.isNullAt($index)) { + | ${ev.isNull} = true; + |} else + """.stripMargin + } else { + "" + } + s""" + |int $index = (int) $eval2; + |if ($eval1.numElements() < Math.abs($index)) { + | ${ev.isNull} = true; + |} else { + | if ($index == 0) { + | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); + | } else if ($index > 0) { + | $index--; + | } else { + | $index += $eval1.numElements(); + | } + | $nullCheck + | { + | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + | } + |} + """.stripMargin + }) + case _: MapType => + doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType]) + } + } + + override def prettyName: String = "element_at" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 6cdad19168dce..3fba52d745453 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -268,31 +268,12 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } /** - * Returns the value of key `key` in Map `child`. - * - * We need to do type checking here as `key` expression maybe unresolved. + * Common base class for [[GetMapValue]] and [[ElementAt]]. */ -case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant { - - private def keyType = child.dataType.asInstanceOf[MapType].keyType - - // We have done type checking for child in `ExtractValue`, so only need to check the `key`. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) - - override def toString: String = s"$child[$key]" - override def sql: String = s"${child.sql}[${key.sql}]" - - override def left: Expression = child - override def right: Expression = key - - /** `Null` is returned for invalid ordinals. */ - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType +abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. - protected override def nullSafeEval(value: Any, ordinal: Any): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = { val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() @@ -315,14 +296,15 @@ case class GetMapValue(child: Expression, key: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = { val index = ctx.freshName("index") val length = ctx.freshName("length") val keys = ctx.freshName("keys") val found = ctx.freshName("found") val key = ctx.freshName("key") val values = ctx.freshName("values") - val nullCheck = if (child.dataType.asInstanceOf[MapType].valueContainsNull) { + val keyType = mapType.keyType + val nullCheck = if (mapType.valueContainsNull) { s" || $values.isNullAt($index)" } else { "" @@ -354,3 +336,37 @@ case class GetMapValue(child: Expression, key: Expression) }) } } + +/** + * Returns the value of key `key` in Map `child`. + * + * We need to do type checking here as `key` expression maybe unresolved. + */ +case class GetMapValue(child: Expression, key: Expression) + extends GetMapValueUtil with ExtractValue with NullIntolerant { + + private def keyType = child.dataType.asInstanceOf[MapType].keyType + + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) + + override def toString: String = s"$child[$key]" + override def sql: String = s"${child.sql}[${key.sql}]" + + override def left: Expression = child + override def right: Expression = key + + /** `Null` is returned for invalid ordinals. */ + override def nullable: Boolean = true + + override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType + + // todo: current search is O(n), improve it. + override def nullSafeEval(value: Any, ordinal: Any): Any = { + getValueEval(value, ordinal, keyType) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType]) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 916cd3bb4cca5..7d8fe211858b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -191,4 +191,52 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPosition(a3, Literal("")), null) checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) } + + test("elementAt") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + intercept[Exception] { + checkEvaluation(ElementAt(a0, Literal(0)), null) + }.getMessage.contains("SQL array indices start at 1") + intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) } + checkEvaluation(ElementAt(a0, Literal(4)), null) + checkEvaluation(ElementAt(a0, Literal(-4)), null) + + checkEvaluation(ElementAt(a0, Literal(1)), 1) + checkEvaluation(ElementAt(a0, Literal(2)), 2) + checkEvaluation(ElementAt(a0, Literal(3)), 3) + checkEvaluation(ElementAt(a0, Literal(-3)), 1) + checkEvaluation(ElementAt(a0, Literal(-2)), 2) + checkEvaluation(ElementAt(a0, Literal(-1)), 3) + + checkEvaluation(ElementAt(a1, Literal(1)), null) + checkEvaluation(ElementAt(a1, Literal(2)), "") + checkEvaluation(ElementAt(a1, Literal(-2)), null) + checkEvaluation(ElementAt(a1, Literal(-1)), "") + + checkEvaluation(ElementAt(a2, Literal(1)), null) + + checkEvaluation(ElementAt(a3, Literal(1)), null) + + + val m0 = + Literal.create(Map("a" -> "1", "b" -> "2", "c" -> null), MapType(StringType, StringType)) + val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) + val m2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(ElementAt(m0, Literal(1.0)), null) + + checkEvaluation(ElementAt(m0, Literal("d")), null) + + checkEvaluation(ElementAt(m1, Literal("a")), null) + + checkEvaluation(ElementAt(m0, Literal("a")), "1") + checkEvaluation(ElementAt(m0, Literal("b")), "2") + checkEvaluation(ElementAt(m0, Literal("c")), null) + + checkEvaluation(ElementAt(m2, Literal("a")), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3a09ec4f1982e..9c8580378303e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3052,6 +3052,17 @@ object functions { ArrayPosition(column.expr, Literal(value)) } + /** + * Returns element of array at given index in value if column is array. Returns value for + * the given key in value if column is map. + * + * @group collection_funcs + * @since 2.4.0 + */ + def element_at(column: Column, value: Any): Column = withExpr { + ElementAt(column.expr, Literal(value)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 13161e7e24cfe..7c976c1b7f915 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -569,6 +569,54 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("element_at function") { + val df = Seq( + (Seq[String]("1", "2", "3")), + (Seq[String](null, "")), + (Seq[String]()) + ).toDF("a") + + intercept[Exception] { + checkAnswer( + df.select(element_at(df("a"), 0)), + Seq(Row(null), Row(null), Row(null)) + ) + }.getMessage.contains("SQL array indices start at 1") + intercept[Exception] { + checkAnswer( + df.select(element_at(df("a"), 1.1)), + Seq(Row(null), Row(null), Row(null)) + ) + } + checkAnswer( + df.select(element_at(df("a"), 4)), + Seq(Row(null), Row(null), Row(null)) + ) + + checkAnswer( + df.select(element_at(df("a"), 1)), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.select(element_at(df("a"), -1)), + Seq(Row("3"), Row(""), Row(null)) + ) + + checkAnswer( + df.selectExpr("element_at(a, 4)"), + Seq(Row(null), Row(null), Row(null)) + ) + + checkAnswer( + df.selectExpr("element_at(a, 1)"), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, -1)"), + Seq(Row("3"), Row(""), Row(null)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 1b08c4393cf48e21fea9914d130d8d3bf544061d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 19 Apr 2018 14:38:26 +0200 Subject: [PATCH 648/774] [SPARK-23584][SQL] NewInstance should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `NewInstance`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20778 from maropu/SPARK-23584. --- .../spark/sql/catalyst/ScalaReflection.scala | 13 +++++++ .../expressions/objects/objects.scala | 28 +++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 36 +++++++++++++++++++ 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index e4274aaa9727e..818cc2fb1e8a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.catalyst +import java.lang.reflect.Constructor + +import org.apache.commons.lang3.reflect.ConstructorUtils + import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ @@ -781,6 +785,15 @@ object ScalaReflection extends ScalaReflection { } } + /** + * Finds an accessible constructor with compatible parameters. This is a more flexible search + * than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible + * matching constructor is returned. Otherwise, it returns `None`. + */ + def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = { + Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) + } + /** * Whether the fields of the given type is defined entirely by its constructor parameters. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 72b202b3a5020..1645bd7d57b1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -449,8 +449,32 @@ case class NewInstance( childrenResolved && !needOuterPointer } - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + @transient private lazy val constructor: (Seq[AnyRef]) => Any = { + val paramTypes = ScalaReflection.expressionJavaClasses(arguments) + val getConstructor = (paramClazz: Seq[Class[_]]) => { + ScalaReflection.findConstructor(cls, paramClazz).getOrElse { + sys.error(s"Couldn't find a valid constructor on $cls") + } + } + outerPointer.map { p => + val outerObj = p() + val d = outerObj.getClass +: paramTypes + val c = getConstructor(outerObj.getClass +: paramTypes) + (args: Seq[AnyRef]) => { + c.newInstance(outerObj +: args: _*) + } + }.getOrElse { + val c = getConstructor(paramTypes) + (args: Seq[AnyRef]) => { + c.newInstance(args: _*) + } + } + } + + override def eval(input: InternalRow): Any = { + val argValues = arguments.map(_.eval(input)) + constructor(argValues.map(_.asInstanceOf[AnyRef])) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index b0188b0098def..bf805f4f29ac5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -47,6 +47,20 @@ class InvokeTargetSubClass extends InvokeTargetClass { override def binOp(e1: Int, e2: Double): Double = e1 - e2 } +// Tests for NewInstance +class Outer extends Serializable { + class Inner(val value: Int) { + override def hashCode(): Int = super.hashCode() + override def equals(other: Any): Boolean = { + if (other.isInstanceOf[Inner]) { + value == other.asInstanceOf[Inner].value + } else { + false + } + } + } +} + class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-16622: The returned value of the called method in Invoke can be null") { @@ -383,6 +397,27 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-23584 NewInstance should support interpreted execution") { + // Normal case test + val newInst1 = NewInstance( + cls = classOf[GenericArrayData], + arguments = Literal.fromObject(List(1, 2, 3)) :: Nil, + propagateNull = false, + dataType = ArrayType(IntegerType), + outerPointer = None) + checkObjectExprEvaluation(newInst1, new GenericArrayData(List(1, 2, 3))) + + // Inner class case test + val outerObj = new Outer() + val newInst2 = NewInstance( + cls = classOf[outerObj.Inner], + arguments = Literal(1) :: Nil, + propagateNull = false, + dataType = ObjectType(classOf[outerObj.Inner]), + outerPointer = Some(() => outerObj)) + checkObjectExprEvaluation(newInst2, new outerObj.Inner(1)) + } + test("LambdaVariable should support interpreted execution") { def genSchema(dt: DataType): Seq[StructType] = { Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil), @@ -421,6 +456,7 @@ class TestBean extends Serializable { private var x: Int = 0 def setX(i: Int): Unit = x = i + def setNonPrimitive(i: AnyRef): Unit = assert(i != null, "this setter should not be called with null.") } From e13416502f814b04d59bb650953a0114332d163a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 19 Apr 2018 14:42:50 +0200 Subject: [PATCH 649/774] [SPARK-23588][SQL] CatalystToExternalMap should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `CatalystToExternalMap`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20979 from maropu/SPARK-23588. --- .../expressions/objects/objects.scala | 39 +++++++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 34 +++++++++++++--- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 1645bd7d57b1d..bc17d1229420a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -28,12 +28,12 @@ import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -1033,8 +1033,39 @@ case class CatalystToExternalMap private( override def children: Seq[Expression] = keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private lazy val inputMapType = inputData.dataType.asInstanceOf[MapType] + + private lazy val keyConverter = + CatalystTypeConverters.createToScalaConverter(inputMapType.keyType) + private lazy val valueConverter = + CatalystTypeConverters.createToScalaConverter(inputMapType.valueType) + + private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + val clazz = Utils.classForName(collClass.getCanonicalName + "$") + val module = clazz.getField("MODULE$").get(null) + val method = clazz.getMethod("newBuilder") + method.invoke(module).asInstanceOf[Builder[AnyRef, AnyRef]] + } + + override def eval(input: InternalRow): Any = { + val result = inputData.eval(input).asInstanceOf[MapData] + if (result != null) { + val builder = newMapBuilder() + builder.sizeHint(result.numElements()) + val keyArray = result.keyArray() + val valueArray = result.valueArray() + var i = 0 + while (i < result.numElements()) { + val key = keyConverter(keyArray.get(i, inputMapType.keyType)) + val value = valueConverter(valueArray.get(i, inputMapType.valueType)) + builder += Tuple2(key, value) + i += 1 + } + builder.result() + } else { + null + } + } override def dataType: DataType = ObjectType(collClass) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index bf805f4f29ac5..bcd035c1eba0b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -27,12 +27,14 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone -import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -162,9 +164,10 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "fromPrimitiveArray", ObjectType(classOf[Array[Int]]), Array[Int](1, 2, 3), UnsafeArrayData.fromPrimitiveArray(Array[Int](1, 2, 3))), (DateTimeUtils.getClass, ObjectType(classOf[Date]), - "toJavaDate", ObjectType(classOf[SQLDate]), 77777, DateTimeUtils.toJavaDate(77777)), + "toJavaDate", ObjectType(classOf[DateTimeUtils.SQLDate]), 77777, + DateTimeUtils.toJavaDate(77777)), (DateTimeUtils.getClass, ObjectType(classOf[Timestamp]), - "toJavaTimestamp", ObjectType(classOf[SQLTimestamp]), + "toJavaTimestamp", ObjectType(classOf[DateTimeUtils.SQLTimestamp]), 88888888.toLong, DateTimeUtils.toJavaTimestamp(88888888)) ).foreach { case (cls, dataType, methodName, argType, arg, expected) => checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName, @@ -450,6 +453,25 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + + implicit private def mapIntStrEncoder = ExpressionEncoder[Map[Int, String]]() + + test("SPARK-23588 CatalystToExternalMap should support interpreted execution") { + // To get a resolved `CatalystToExternalMap` expression, we build a deserializer plan + // with dummy input, resolve the plan by the analyzer, and replace the dummy input + // with a literal for tests. + val unresolvedDeser = UnresolvedDeserializer(encoderFor[Map[Int, String]].deserializer) + val dummyInputPlan = LocalRelation('value.map(MapType(IntegerType, StringType))) + val plan = Project(Alias(unresolvedDeser, "none")() :: Nil, dummyInputPlan) + + val analyzedPlan = SimpleAnalyzer.execute(plan) + val Alias(toMapExpr: CatalystToExternalMap, _) = analyzedPlan.expressions.head + + // Replaces the dummy input with a literal for tests here + val data = Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3") + val deserializer = toMapExpr.copy(inputData = Literal.create(data)) + checkObjectExprEvaluation(deserializer, expected = data) + } } class TestBean extends Serializable { From 9e10f69df52abde2de5d93435bab54e97dd59d9c Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 19 Apr 2018 21:07:21 +0800 Subject: [PATCH 650/774] [SPARK-22676][FOLLOW-UP] fix code style for test. ## What changes were proposed in this pull request? This pr address comments in https://github.com/apache/spark/pull/19868 ; Fix the code style for `org.apache.spark.sql.hive.QueryPartitionSuite` by using: `withTempView`, `withTempDir`, `withTable`... Author: jinxing Closes #21091 from jinxing64/SPARK-22676-FOLLOW-UP. --- .../spark/sql/hive/QueryPartitionSuite.scala | 109 +++++++----------- 1 file changed, 41 insertions(+), 68 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 78156b17fb43b..1e396553c9c52 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -33,80 +33,53 @@ import org.apache.spark.util.Utils class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import spark.implicits._ - test("SPARK-5068: query data when path doesn't exist") { - withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) { - val testData = sparkContext.parallelize( - (1 to 10).map(i => TestData(i, i.toString))).toDF() - testData.createOrReplaceTempView("testData") - - val tmpDir = Files.createTempDir() - // create the table for test - sql(s"CREATE TABLE table_with_partition(key int,value string) " + - s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + - "SELECT key,value FROM testData") - - // test for the exist path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect - ++ testData.toDF.collect ++ testData.toDF.collect) - - // delete the path of one partition - tmpDir.listFiles - .find { f => f.isDirectory && f.getName().startsWith("ds=") } - .foreach { f => Utils.deleteRecursively(f) } - - // test for after delete the path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) + private def queryWhenPathNotExist(): Unit = { + withTempView("testData") { + withTable("table_with_partition", "createAndInsertTest") { + withTempDir { tmpDir => + val testData = sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.createOrReplaceTempView("testData") + + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") + + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.union(testData).union(testData).union(testData)) + + // delete the path of one partition + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } + + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.union(testData).union(testData)) + } + } + } + } - sql("DROP TABLE IF EXISTS table_with_partition") - sql("DROP TABLE IF EXISTS createAndInsertTest") + test("SPARK-5068: query data when path doesn't exist") { + withSQLConf(SQLConf.HIVE_VERIFY_PARTITION_PATH.key -> "true") { + queryWhenPathNotExist() } } test("Replace spark.sql.hive.verifyPartitionPath by spark.files.ignoreMissingFiles") { - withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "false")) { + withSQLConf(SQLConf.HIVE_VERIFY_PARTITION_PATH.key -> "false") { sparkContext.conf.set(IGNORE_MISSING_FILES.key, "true") - val testData = sparkContext.parallelize( - (1 to 10).map(i => TestData(i, i.toString))).toDF() - testData.createOrReplaceTempView("testData") - - val tmpDir = Files.createTempDir() - // create the table for test - sql(s"CREATE TABLE table_with_partition(key int,value string) " + - s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + - "SELECT key,value FROM testData") - - // test for the exist path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect - ++ testData.toDF.collect ++ testData.toDF.collect) - - // delete the path of one partition - tmpDir.listFiles - .find { f => f.isDirectory && f.getName().startsWith("ds=") } - .foreach { f => Utils.deleteRecursively(f) } - - // test for after delete the path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) - - sql("DROP TABLE IF EXISTS table_with_partition") - sql("DROP TABLE IF EXISTS createAndInsertTest") + queryWhenPathNotExist() } } From d96c3e33cc2a95de8e15e1a2ddf50a8d0cc66dd2 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 19 Apr 2018 21:21:22 +0800 Subject: [PATCH 651/774] [SPARK-21811][SQL] Fix the inconsistency behavior when finding the widest common type ## What changes were proposed in this pull request? Currently we find the wider common type by comparing the two types from left to right, this can be a problem when you have two data types which don't have a common type but each can be promoted to StringType. For instance, if you have a table with the schema: [c1: date, c2: string, c3: int] The following succeeds: SELECT coalesce(c1, c2, c3) FROM table While the following produces an exception: SELECT coalesce(c1, c3, c2) FROM table This is only a issue when the seq of dataTypes contains `StringType` and all the types can do string promotion. close #19033 ## How was this patch tested? Add test in `TypeCoercionSuite` Author: Xingbo Jiang Closes #21074 from jiangxb1987/typeCoercion. --- docs/sql-programming-guide.md | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 24 +++++++++++++++---- .../catalyst/analysis/TypeCoercionSuite.scala | 13 ++++++++++ 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 55d35b9dd31db..e8ff1470970f7 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1810,7 +1810,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - + - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index ec7e7761dc4c2..281f206e8d59e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -175,11 +175,27 @@ object TypeCoercion { }) } + /** + * Whether the data type contains StringType. + */ + def hasStringType(dt: DataType): Boolean = dt match { + case StringType => true + case ArrayType(et, _) => hasStringType(et) + // Add StructType if we support string promotion for struct fields in the future. + case _ => false + } + private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { - types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeForTwo(d, c) - case None => None - }) + // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal + // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. + // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, + // (TimestampType, IntegerType, StringType) should have StringType as the wider common type. + val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) + (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => + r match { + case Some(d) => findWiderTypeForTwo(d, c) + case _ => None + }) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 8ac49dc05e3cf..fd6a3121663ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -539,6 +539,9 @@ class TypeCoercionSuite extends AnalysisTest { val floatLit = Literal.create(1.0f, FloatType) val timestampLit = Literal.create("2017-04-12", TimestampType) val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) + val tsArrayLit = Literal(Array(new Timestamp(System.currentTimeMillis()))) + val strArrayLit = Literal(Array("c")) + val intArrayLit = Literal(Array(1)) ruleTest(rule, Coalesce(Seq(doubleLit, intLit, floatLit)), @@ -572,6 +575,16 @@ class TypeCoercionSuite extends AnalysisTest { Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), Cast(doubleLit, StringType), Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(timestampLit, intLit, stringLit)), + Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), + Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit)), + Coalesce(Seq(Cast(tsArrayLit, ArrayType(StringType)), + Cast(intArrayLit, ArrayType(StringType)), Cast(strArrayLit, ArrayType(StringType))))) } test("CreateArray casts") { From 0deaa5251326a32a3d2d2b8851193ca926303972 Mon Sep 17 00:00:00 2001 From: wuyi Date: Thu, 19 Apr 2018 09:00:33 -0500 Subject: [PATCH 652/774] [SPARK-24021][CORE] fix bug in BlacklistTracker's updateBlacklistForFetchFailure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? There‘s a miswrite in BlacklistTracker's updateBlacklistForFetchFailure: ``` val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]()) blacklistedExecsOnNode += exec ``` where first **exec** should be **host**. ## How was this patch tested? adjust existed test. Author: wuyi Closes #21104 from Ngone51/SPARK-24021. --- .../scala/org/apache/spark/scheduler/BlacklistTracker.scala | 2 +- .../org/apache/spark/scheduler/BlacklistTrackerSuite.scala | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index 952598f6de19d..30cf75d43ee09 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -210,7 +210,7 @@ private[scheduler] class BlacklistTracker ( updateNextExpiryTime() killBlacklistedExecutor(exec) - val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]()) + val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(host, HashSet[String]()) blacklistedExecsOnNode += exec } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 06d7afaaff55c..96c8404327e24 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -574,6 +574,9 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M verify(allocationClientMock, never).killExecutors(any(), any(), any(), any()) verify(allocationClientMock, never).killExecutorsOnHost(any()) + assert(blacklist.nodeToBlacklistedExecs.contains("hostA")) + assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1")) + // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. conf.set(config.BLACKLIST_KILL_ENABLED, true) blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) @@ -589,6 +592,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) assert(blacklist.nodeIdToBlacklistExpiryTime.isEmpty) + assert(blacklist.nodeToBlacklistedExecs.contains("hostA")) + assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1")) // Enable external shuffle service to see if all the executors on this node will be killed. conf.set(config.SHUFFLE_SERVICE_ENABLED, true) From 6e19f7683fc73fabe7cdaac4eb1982d2e3e607b7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Apr 2018 17:54:53 +0200 Subject: [PATCH 653/774] [SPARK-23989][SQL] exchange should copy data before non-serialized shuffle ## What changes were proposed in this pull request? In Spark SQL, we usually reuse the `UnsafeRow` instance and need to copy the data when a place buffers non-serialized objects. Shuffle may buffer objects if we don't make it to the bypass merge shuffle or unsafe shuffle. `ShuffleExchangeExec.needToCopyObjectsBeforeShuffle` misses the case that, if `spark.sql.shuffle.partitions` is large enough, we could fail to run unsafe shuffle and go with the non-serialized shuffle. This bug is very hard to hit since users wouldn't set such a large number of partitions(16 million) for Spark SQL exchange. TODO: test ## How was this patch tested? todo. Author: Wenchen Fan Closes #21101 from cloud-fan/shuffle. --- .../exchange/ShuffleExchangeExec.scala | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 4d95ee34f30de..b89203719541b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -153,12 +153,9 @@ object ShuffleExchangeExec { * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue. * * @param partitioner the partitioner for the shuffle - * @param serializer the serializer that will be used to write rows * @return true if rows should be copied before being shuffled, false otherwise */ - private def needToCopyObjectsBeforeShuffle( - partitioner: Partitioner, - serializer: Serializer): Boolean = { + private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner): Boolean = { // Note: even though we only use the partitioner's `numPartitions` field, we require it to be // passed instead of directly passing the number of partitions in order to guard against // corner-cases where a partitioner constructed with `numPartitions` partitions may output @@ -167,22 +164,24 @@ object ShuffleExchangeExec { val shuffleManager = SparkEnv.get.shuffleManager val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + val numParts = partitioner.numPartitions if (sortBasedShuffleOn) { - val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] - if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { + if (numParts <= bypassMergeThreshold) { // If we're using the original SortShuffleManager and the number of output partitions is // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which // doesn't buffer deserialized records. // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. false - } else if (serializer.supportsRelocationOfSerializedObjects) { + } else if (numParts <= SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records // prior to sorting them. This optimization is only applied in cases where shuffle // dependency does not specify an aggregator or ordering and the record serializer has - // certain properties. If this optimization is enabled, we can safely avoid the copy. + // certain properties and the number of partitions doesn't exceed the limitation. If this + // optimization is enabled, we can safely avoid the copy. // - // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only - // need to check whether the optimization is enabled and supported by our serializer. + // Exchange never configures its ShuffledRDDs with aggregators or key orderings, and the + // serializer in Spark SQL always satisfy the properties, so we only need to check whether + // the number of partitions exceeds the limitation. false } else { // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must @@ -298,7 +297,7 @@ object ShuffleExchangeExec { rdd } - if (needToCopyObjectsBeforeShuffle(part, serializer)) { + if (needToCopyObjectsBeforeShuffle(part)) { newRdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } From a471880afbeafd4ef54c15a97e72ea7ff784a88d Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Thu, 19 Apr 2018 09:40:20 -0700 Subject: [PATCH 654/774] [SPARK-24026][ML] Add Power Iteration Clustering to spark.ml ## What changes were proposed in this pull request? This PR adds PowerIterationClustering as a Transformer to spark.ml. In the transform method, it calls spark.mllib's PowerIterationClustering.run() method and transforms the return value assignments (the Kmeans output of the pseudo-eigenvector) as a DataFrame (id: LongType, cluster: IntegerType). This PR is copied and modified from https://github.com/apache/spark/pull/15770 The primary author is wangmiao1981 ## How was this patch tested? This PR has 2 types of tests: * Copies of tests from spark.mllib's PIC tests * New tests specific to the spark.ml APIs Author: wm624@hotmail.com Author: wangmiao1981 Author: Joseph K. Bradley Closes #21090 from jkbradley/wangmiao1981-pic. --- .../clustering/PowerIterationClustering.scala | 256 ++++++++++++++++++ .../PowerIterationClusteringSuite.scala | 238 ++++++++++++++++ 2 files changed, 494 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala new file mode 100644 index 0000000000000..2c30a1d9aa947 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ + +/** + * Common params for PowerIterationClustering + */ +private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter + with HasPredictionCol { + + /** + * The number of clusters to create (k). Must be > 1. Default: 2. + * @group param + */ + @Since("2.4.0") + final val k = new IntParam(this, "k", "The number of clusters to create. " + + "Must be > 1.", ParamValidators.gt(1)) + + /** @group getParam */ + @Since("2.4.0") + def getK: Int = $(k) + + /** + * Param for the initialization algorithm. This can be either "random" to use a random vector + * as vertex properties, or "degree" to use a normalized sum of similarities with other vertices. + * Default: random. + * @group expertParam + */ + @Since("2.4.0") + final val initMode = { + val allowedParams = ParamValidators.inArray(Array("random", "degree")) + new Param[String](this, "initMode", "The initialization algorithm. This can be either " + + "'random' to use a random vector as vertex properties, or 'degree' to use a normalized sum " + + "of similarities with other vertices. Supported options: 'random' and 'degree'.", + allowedParams) + } + + /** @group expertGetParam */ + @Since("2.4.0") + def getInitMode: String = $(initMode) + + /** + * Param for the name of the input column for vertex IDs. + * Default: "id" + * @group param + */ + @Since("2.4.0") + val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.", + (value: String) => value.nonEmpty) + + setDefault(idCol, "id") + + /** @group getParam */ + @Since("2.4.0") + def getIdCol: String = getOrDefault(idCol) + + /** + * Param for the name of the input column for neighbors in the adjacency list representation. + * Default: "neighbors" + * @group param + */ + @Since("2.4.0") + val neighborsCol = new Param[String](this, "neighborsCol", + "Name of the input column for neighbors in the adjacency list representation.", + (value: String) => value.nonEmpty) + + setDefault(neighborsCol, "neighbors") + + /** @group getParam */ + @Since("2.4.0") + def getNeighborsCol: String = $(neighborsCol) + + /** + * Param for the name of the input column for neighbors in the adjacency list representation. + * Default: "similarities" + * @group param + */ + @Since("2.4.0") + val similaritiesCol = new Param[String](this, "similaritiesCol", + "Name of the input column for neighbors in the adjacency list representation.", + (value: String) => value.nonEmpty) + + setDefault(similaritiesCol, "similarities") + + /** @group getParam */ + @Since("2.4.0") + def getSimilaritiesCol: String = $(similaritiesCol) + + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType)) + SchemaUtils.checkColumnTypes(schema, $(neighborsCol), + Seq(ArrayType(IntegerType, containsNull = false), + ArrayType(LongType, containsNull = false))) + SchemaUtils.checkColumnTypes(schema, $(similaritiesCol), + Seq(ArrayType(FloatType, containsNull = false), + ArrayType(DoubleType, containsNull = false))) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + } +} + +/** + * :: Experimental :: + * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by + * Lin and Cohen. From the abstract: + * PIC finds a very low-dimensional embedding of a dataset using truncated power + * iteration on a normalized pair-wise similarity matrix of the data. + * + * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix + * is a symmetric matrix whose entries are non-negative similarities between items. + * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes: + * - `idCol`: vertex ID + * - `neighborsCol`: neighbors of vertex in `idCol` + * - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex + * in `idCol` and each neighbor in `neighborsCol` + * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol` + * containing the cluster assignment in `[0,k)` for each row (vertex). + * + * Notes: + * - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation. + * Transform runs the iterative PIC algorithm to cluster the whole input dataset. + * - Input validation: This validates that similarities are non-negative but does NOT validate + * that the input matrix is symmetric. + * + * @see + * Spectral clustering (Wikipedia) + */ +@Since("2.4.0") +@Experimental +class PowerIterationClustering private[clustering] ( + @Since("2.4.0") override val uid: String) + extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable { + + setDefault( + k -> 2, + maxIter -> 20, + initMode -> "random") + + @Since("2.4.0") + def this() = this(Identifiable.randomUID("PowerIterationClustering")) + + /** @group setParam */ + @Since("2.4.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setK(value: Int): this.type = set(k, value) + + /** @group expertSetParam */ + @Since("2.4.0") + def setInitMode(value: String): this.type = set(initMode, value) + + /** @group setParam */ + @Since("2.4.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("2.4.0") + def setIdCol(value: String): this.type = set(idCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setNeighborsCol(value: String): this.type = set(neighborsCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value) + + @Since("2.4.0") + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + + val sparkSession = dataset.sparkSession + val idColValue = $(idCol) + val rdd: RDD[(Long, Long, Double)] = + dataset.select( + col($(idCol)).cast(LongType), + col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)), + col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false)) + ).rdd.flatMap { + case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) => + require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " + + s"equal to the the length of the neighbor similarity list. Row for ID " + + s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " + + s"of length ${sims.length}.") + nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map { + case (nbr, similarity) => (id, nbr, similarity) + } + } + val algorithm = new MLlibPowerIterationClustering() + .setK($(k)) + .setInitializationMode($(initMode)) + .setMaxIterations($(maxIter)) + val model = algorithm.run(rdd) + + val predictionsRDD: RDD[Row] = model.assignments.map { assignment => + Row(assignment.id, assignment.cluster) + } + + val predictionsSchema = StructType(Seq( + StructField($(idCol), LongType, nullable = false), + StructField($(predictionCol), IntegerType, nullable = false))) + val predictions = { + val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) + dataset.schema($(idCol)).dataType match { + case _: LongType => + uncastPredictions + case otherType => + uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol))) + } + } + + dataset.join(predictions, $(idCol)) + } + + @Since("2.4.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + @Since("2.4.0") + override def copy(extra: ParamMap): PowerIterationClustering = defaultCopy(extra) +} + +@Since("2.4.0") +object PowerIterationClustering extends DefaultParamsReadable[PowerIterationClustering] { + + @Since("2.4.0") + override def load(path: String): PowerIterationClustering = super.load(path) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala new file mode 100644 index 0000000000000..65328df17baff --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import scala.collection.mutable + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ + + +class PowerIterationClusteringSuite extends SparkFunSuite + with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var data: Dataset[_] = _ + final val r1 = 1.0 + final val n1 = 10 + final val r2 = 4.0 + final val n2 = 40 + + override def beforeAll(): Unit = { + super.beforeAll() + + data = PowerIterationClusteringSuite.generatePICData(spark, r1, r2, n1, n2) + } + + test("default parameters") { + val pic = new PowerIterationClustering() + + assert(pic.getK === 2) + assert(pic.getMaxIter === 20) + assert(pic.getInitMode === "random") + assert(pic.getPredictionCol === "prediction") + assert(pic.getIdCol === "id") + assert(pic.getNeighborsCol === "neighbors") + assert(pic.getSimilaritiesCol === "similarities") + } + + test("parameter validation") { + intercept[IllegalArgumentException] { + new PowerIterationClustering().setK(1) + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setInitMode("no_such_a_mode") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setIdCol("") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setNeighborsCol("") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setSimilaritiesCol("") + } + } + + test("power iteration clustering") { + val n = n1 + n2 + + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + val result = model.transform(data) + + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + result.select("id", "prediction").collect().foreach { + case Row(id: Long, cluster: Integer) => predictions(cluster) += id + } + assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + + val result2 = new PowerIterationClustering() + .setK(2) + .setMaxIter(10) + .setInitMode("degree") + .transform(data) + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + result2.select("id", "prediction").collect().foreach { + case Row(id: Long, cluster: Integer) => predictions2(cluster) += id + } + assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + } + + test("supported input types") { + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(1) + + def runTest(idType: DataType, neighborType: DataType, similarityType: DataType): Unit = { + val typedData = data.select( + col("id").cast(idType).alias("id"), + col("neighbors").cast(ArrayType(neighborType, containsNull = false)).alias("neighbors"), + col("similarities").cast(ArrayType(similarityType, containsNull = false)) + .alias("similarities") + ) + model.transform(typedData).collect() + } + + for (idType <- Seq(IntegerType, LongType)) { + runTest(idType, LongType, DoubleType) + } + for (neighborType <- Seq(IntegerType, LongType)) { + runTest(LongType, neighborType, DoubleType) + } + for (similarityType <- Seq(FloatType, DoubleType)) { + runTest(LongType, LongType, similarityType) + } + } + + test("invalid input: wrong types") { + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(1) + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id").cast(DoubleType).alias("id"), + col("neighbors"), + col("similarities") + ) + model.transform(typedData) + } + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id"), + col("neighbors").cast(ArrayType(DoubleType, containsNull = false)).alias("neighbors"), + col("similarities") + ) + model.transform(typedData) + } + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id"), + col("neighbors"), + col("neighbors").alias("similarities") + ) + model.transform(typedData) + } + } + + test("invalid input: negative similarity") { + val model = new PowerIterationClustering() + .setMaxIter(1) + val badData = spark.createDataFrame(Seq( + (0, Array(1), Array(-1.0)), + (1, Array(0), Array(-1.0)) + )).toDF("id", "neighbors", "similarities") + val msg = intercept[SparkException] { + model.transform(badData) + }.getCause.getMessage + assert(msg.contains("Similarity must be nonnegative")) + } + + test("invalid input: mismatched lengths for neighbor and similarity arrays") { + val model = new PowerIterationClustering() + .setMaxIter(1) + val badData = spark.createDataFrame(Seq( + (0, Array(1), Array(0.5)), + (1, Array(0, 2), Array(0.5)), + (2, Array(1), Array(0.5)) + )).toDF("id", "neighbors", "similarities") + val msg = intercept[SparkException] { + model.transform(badData) + }.getCause.getMessage + assert(msg.contains("The length of the neighbor ID list must be equal to the the length of " + + "the neighbor similarity list.")) + assert(msg.contains(s"Row for ID ${model.getIdCol}=1")) + } + + test("read/write") { + val t = new PowerIterationClustering() + .setK(4) + .setMaxIter(100) + .setInitMode("degree") + .setIdCol("test_id") + .setNeighborsCol("myNeighborsCol") + .setSimilaritiesCol("mySimilaritiesCol") + .setPredictionCol("test_prediction") + testDefaultReadWrite(t) + } +} + +object PowerIterationClusteringSuite { + + /** Generates a circle of points. */ + private def genCircle(r: Double, n: Int): Array[(Double, Double)] = { + Array.tabulate(n) { i => + val theta = 2.0 * math.Pi * i / n + (r * math.cos(theta), r * math.sin(theta)) + } + } + + /** Computes Gaussian similarity. */ + private def sim(x: (Double, Double), y: (Double, Double)): Double = { + val dist2 = (x._1 - y._1) * (x._1 - y._1) + (x._2 - y._2) * (x._2 - y._2) + math.exp(-dist2 / 2.0) + } + + def generatePICData( + spark: SparkSession, + r1: Double, + r2: Double, + n1: Int, + n2: Int): DataFrame = { + // Generate two circles following the example in the PIC paper. + val n = n1 + n2 + val points = genCircle(r1, n1) ++ genCircle(r2, n2) + + val rows = for (i <- 1 until n) yield { + val neighbors = for (j <- 0 until i) yield { + j.toLong + } + val similarities = for (j <- 0 until i) yield { + sim(points(i), points(j)) + } + (i.toLong, neighbors.toArray, similarities.toArray) + } + + spark.createDataFrame(rows).toDF("id", "neighbors", "similarities") + } + +} From 9ea8d3d31b75246bf61118ac7934bc92c18b5f19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Thu, 19 Apr 2018 18:55:59 +0200 Subject: [PATCH 655/774] [SPARK-22362][SQL] Add unit test for Window Aggregate Functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Improving the test coverage of window functions focusing on missing test for window aggregate functions. No new UDAF test is added as it has been tested already. ## How was this patch tested? Only new tests were added, automated tests were executed. Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #20046 from attilapiros/SPARK-22362. --- .../resources/sql-tests/inputs/window.sql | 10 +- .../sql-tests/results/window.sql.out | 30 +- .../sql/DataFrameWindowFunctionsSuite.scala | 266 ++++++++++++++++++ 3 files changed, 294 insertions(+), 12 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql index c4bea34ec4cf3..cda4db4b449fe 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/window.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -76,7 +76,15 @@ ntile(2) OVER w AS ntile, row_number() OVER w AS row_number, var_pop(val) OVER w AS var_pop, var_samp(val) OVER w AS var_samp, -approx_count_distinct(val) OVER w AS approx_count_distinct +approx_count_distinct(val) OVER w AS approx_count_distinct, +covar_pop(val, val_long) OVER w AS covar_pop, +corr(val, val_long) OVER w AS corr, +stddev_samp(val) OVER w AS stddev_samp, +stddev_pop(val) OVER w AS stddev_pop, +collect_list(val) OVER w AS collect_list, +collect_set(val) OVER w AS collect_set, +skewness(val_double) OVER w AS skewness, +kurtosis(val_double) OVER w AS kurtosis FROM testData WINDOW w AS (PARTITION BY cate ORDER BY val) ORDER BY cate, val; diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index 133458ae9303b..4afbcd62853dc 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -273,22 +273,30 @@ ntile(2) OVER w AS ntile, row_number() OVER w AS row_number, var_pop(val) OVER w AS var_pop, var_samp(val) OVER w AS var_samp, -approx_count_distinct(val) OVER w AS approx_count_distinct +approx_count_distinct(val) OVER w AS approx_count_distinct, +covar_pop(val, val_long) OVER w AS covar_pop, +corr(val, val_long) OVER w AS corr, +stddev_samp(val) OVER w AS stddev_samp, +stddev_pop(val) OVER w AS stddev_pop, +collect_list(val) OVER w AS collect_list, +collect_set(val) OVER w AS collect_set, +skewness(val_double) OVER w AS skewness, +kurtosis(val_double) OVER w AS kurtosis FROM testData WINDOW w AS (PARTITION BY cate ORDER BY val) ORDER BY cate, val -- !query 17 schema -struct +struct,collect_set:array,skewness:double,kurtosis:double> -- !query 17 output -NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 -3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 -NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 -1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 -1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 -2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 -1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 -2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 -3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 +NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NULL NULL +3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 0.0 NaN NaN 0.0 [3] [3] NaN NaN +NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NaN NaN +1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5 +1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5 +2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 4.772185885555555E8 1.0 0.5773502691896258 0.4714045207910317 [1,1,2] [1,2] 1.1539890888012805 -0.6672217220327235 +1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 NULL NULL NaN 0.0 [1] [1] NaN NaN +2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 0.0 NaN 0.7071067811865476 0.5 [1,2] [1,2] 0.0 -2.0000000000000013 +3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 5.3687091175E8 1.0 1.0 0.816496580927726 [1,2,3] [1,2,3] 0.7057890433107311 -1.4999999999999984 -- !query 18 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 281147835abde..3ea398aad7375 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import scala.collection.mutable + import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ @@ -86,6 +88,236 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("requires window to be ordered")) } + test("corr, covar_pop, stddev_pop functions in specific window") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") + checkAnswer( + df.select( + $"key", + corr("value1", "value2").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + covar_pop("value1", "value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + + // As stddev_pop(expr) = sqrt(var_pop(expr)) + // the "stddev_pop" column can be calculated from the "var_pop" column. + // + // As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2)) + // the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns. + Seq( + Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("i", Double.NaN, 0.0, 0.0, 0.0, 0.0, 0.0))) + } + + test("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") + checkAnswer( + df.select( + $"key", + covar_samp("value1", "value2").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_samp("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + variance("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_samp("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)) + ), + Seq( + Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("c", 0.0, 0.0, 0.0, 0.0, 0.0 ), + Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ), + Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))) + } + + test("collect_list in ascending ordered window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", null), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20"), + ("i", "p4", null)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + sort_array( + collect_list("value").over(Window.partitionBy($"partition").orderBy($"value") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "2", "3")), + Row("b", Array("1", "2", "2", "3")), + Row("c", Array("1", "2", "2", "3")), + Row("d", Array("1", "2", "2", "3")), + Row("e", Array("1", "2", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")), + Row("i", Array()))) + } + + test("collect_list in descending ordered window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", null), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20"), + ("i", "p4", null)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + sort_array( + collect_list("value").over(Window.partitionBy($"partition").orderBy($"value".desc) + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "2", "3")), + Row("b", Array("1", "2", "2", "3")), + Row("c", Array("1", "2", "2", "3")), + Row("d", Array("1", "2", "2", "3")), + Row("e", Array("1", "2", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")), + Row("i", Array()))) + } + + test("collect_set in window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", "3"), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20")).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + sort_array( + collect_set("value").over(Window.partitionBy($"partition").orderBy($"value") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "3")), + Row("b", Array("1", "2", "3")), + Row("c", Array("1", "2", "3")), + Row("d", Array("1", "2", "3")), + Row("e", Array("1", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")))) + } + + test("skewness and kurtosis functions in window") { + val df = Seq( + ("a", "p1", 1.0), + ("b", "p1", 1.0), + ("c", "p1", 2.0), + ("d", "p1", 2.0), + ("e", "p1", 3.0), + ("f", "p1", 3.0), + ("g", "p1", 3.0), + ("h", "p2", 1.0), + ("i", "p2", 2.0), + ("j", "p2", 5.0)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + skewness("value").over(Window.partitionBy("partition").orderBy($"key") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + kurtosis("value").over(Window.partitionBy("partition").orderBy($"key") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + // results are checked by scipy.stats.skew() and scipy.stats.kurtosis() + Seq( + Row("a", -0.27238010581457267, -1.506920415224914), + Row("b", -0.27238010581457267, -1.506920415224914), + Row("c", -0.27238010581457267, -1.506920415224914), + Row("d", -0.27238010581457267, -1.506920415224914), + Row("e", -0.27238010581457267, -1.506920415224914), + Row("f", -0.27238010581457267, -1.506920415224914), + Row("g", -0.27238010581457267, -1.506920415224914), + Row("h", 0.5280049792181881, -1.5000000000000013), + Row("i", 0.5280049792181881, -1.5000000000000013), + Row("j", 0.5280049792181881, -1.5000000000000013))) + } + + test("aggregation function on invalid column") { + val df = Seq((1, "1")).toDF("key", "value") + val e = intercept[AnalysisException]( + df.select($"key", count("invalid").over())) + assert(e.message.contains("cannot resolve '`invalid`' given input columns: [key, value]")) + } + + test("numerical aggregate functions on string column") { + val df = Seq((1, "a", "b")).toDF("key", "value1", "value2") + checkAnswer( + df.select($"key", + var_pop("value1").over(), + variance("value1").over(), + stddev_pop("value1").over(), + stddev("value1").over(), + sum("value1").over(), + mean("value1").over(), + avg("value1").over(), + corr("value1", "value2").over(), + covar_pop("value1", "value2").over(), + covar_samp("value1", "value2").over(), + skewness("value1").over(), + kurtosis("value1").over()), + Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null))) + } + test("statistical functions") { val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). toDF("key", "value") @@ -232,6 +464,40 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row("b", 2, null, null, null, null, null, null))) } + test("last/first on descending ordered window") { + val nullStr: String = null + val df = Seq( + ("a", 0, nullStr), + ("a", 1, "x"), + ("a", 2, "y"), + ("a", 3, "z"), + ("a", 4, "v"), + ("b", 1, "k"), + ("b", 2, "l"), + ("b", 3, nullStr)). + toDF("key", "order", "value") + val window = Window.partitionBy($"key").orderBy($"order".desc) + checkAnswer( + df.select( + $"key", + $"order", + first($"value").over(window), + first($"value", ignoreNulls = false).over(window), + first($"value", ignoreNulls = true).over(window), + last($"value").over(window), + last($"value", ignoreNulls = false).over(window), + last($"value", ignoreNulls = true).over(window)), + Seq( + Row("a", 0, "v", "v", "v", null, null, "x"), + Row("a", 1, "v", "v", "v", "x", "x", "x"), + Row("a", 2, "v", "v", "v", "y", "y", "y"), + Row("a", 3, "v", "v", "v", "z", "z", "z"), + Row("a", 4, "v", "v", "v", "v", "v", "v"), + Row("b", 1, null, null, "l", "k", "k", "k"), + Row("b", 2, null, null, "l", "l", "l", "l"), + Row("b", 3, null, null, null, null, null, null))) + } + test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") { val src = Seq((0, 3, 5)).toDF("a", "b", "c") .withColumn("Data", struct("a", "b")) From e55953b0bf2a80b34127ba123417ee54955a6064 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 19 Apr 2018 15:06:27 -0700 Subject: [PATCH 656/774] [SPARK-24022][TEST] Make SparkContextSuite not flaky ## What changes were proposed in this pull request? SparkContextSuite.test("Cancelling stages/jobs with custom reasons.") could stay in an infinite loop because of the problem found and fixed in [SPARK-23775](https://issues.apache.org/jira/browse/SPARK-23775). This PR solves this mentioned flakyness by removing shared variable usages when cancel happens in a loop and using wait and CountDownLatch for synhronization. ## How was this patch tested? Existing unit test. Author: Gabor Somogyi Closes #21105 from gaborgsomogyi/SPARK-24022. --- .../org/apache/spark/SparkContextSuite.scala | 61 ++++++++----------- 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index b30bd74812b36..ce9f2be1c02dd 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io.File import java.net.{MalformedURLException, URI} import java.nio.charset.StandardCharsets -import java.util.concurrent.{Semaphore, TimeUnit} +import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit} import scala.concurrent.duration._ @@ -498,45 +498,36 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu test("Cancelling stages/jobs with custom reasons.") { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") val REASON = "You shall not pass" - val slices = 10 - val listener = new SparkListener { - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { - if (SparkContextSuite.cancelStage) { - eventually(timeout(10.seconds)) { - assert(SparkContextSuite.isTaskStarted) + for (cancelWhat <- Seq("stage", "job")) { + // This countdown latch used to make sure stage or job canceled in listener + val latch = new CountDownLatch(1) + + val listener = cancelWhat match { + case "stage" => + new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sc.cancelStage(taskStart.stageId, REASON) + latch.countDown() + } } - sc.cancelStage(taskStart.stageId, REASON) - SparkContextSuite.cancelStage = false - SparkContextSuite.semaphore.release(slices) - } - } - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - if (SparkContextSuite.cancelJob) { - eventually(timeout(10.seconds)) { - assert(SparkContextSuite.isTaskStarted) + case "job" => + new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + sc.cancelJob(jobStart.jobId, REASON) + latch.countDown() + } } - sc.cancelJob(jobStart.jobId, REASON) - SparkContextSuite.cancelJob = false - SparkContextSuite.semaphore.release(slices) - } } - } - sc.addSparkListener(listener) - - for (cancelWhat <- Seq("stage", "job")) { - SparkContextSuite.semaphore.drainPermits() - SparkContextSuite.isTaskStarted = false - SparkContextSuite.cancelStage = (cancelWhat == "stage") - SparkContextSuite.cancelJob = (cancelWhat == "job") + sc.addSparkListener(listener) val ex = intercept[SparkException] { - sc.range(0, 10000L, numSlices = slices).mapPartitions { x => - SparkContextSuite.isTaskStarted = true - // Block waiting for the listener to cancel the stage or job. - SparkContextSuite.semaphore.acquire() + sc.range(0, 10000L, numSlices = 10).mapPartitions { x => + x.synchronized { + x.wait() + } x }.count() } @@ -550,9 +541,11 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } + latch.await(20, TimeUnit.SECONDS) eventually(timeout(20.seconds)) { assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } + sc.removeSparkListener(listener) } } @@ -637,8 +630,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } object SparkContextSuite { - @volatile var cancelJob = false - @volatile var cancelStage = false @volatile var isTaskStarted = false @volatile var taskKilled = false @volatile var taskSucceeded = false From b3fde5a41ee625141b9d21ce32ea68c082449430 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 20 Apr 2018 12:06:41 +0800 Subject: [PATCH 657/774] [SPARK-23877][SQL] Use filter predicates to prune partitions in metadata-only queries ## What changes were proposed in this pull request? This updates the OptimizeMetadataOnlyQuery rule to use filter expressions when listing partitions, if there are filter nodes in the logical plan. This avoids listing all partitions for large tables on the driver. This also fixes a minor bug where the partitions returned from fsRelation cannot be serialized without hitting a stack level too deep error. This is caused by serializing a stream to executors, where the stream is a recursive structure. If the stream is too long, the serialization stack reaches the maximum level of depth. The fix is to create a LocalRelation using an Array instead of the incoming Seq. ## How was this patch tested? Existing tests for metadata-only queries. Author: Ryan Blue Closes #20988 from rdblue/SPARK-23877-metadata-only-push-filters. --- .../execution/OptimizeMetadataOnlyQuery.scala | 94 +++++++++++++------ .../OptimizeHiveMetadataOnlyQuerySuite.scala | 68 ++++++++++++++ 2 files changed, 132 insertions(+), 30 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index dc4aff9f12580..acbd4becb8549 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -49,9 +49,9 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic } plan.transform { - case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(partAttrs, relation)) => + case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(_, attrs, filters, rel)) => // We only apply this optimization when only partitioned attributes are scanned. - if (a.references.subsetOf(partAttrs)) { + if (a.references.subsetOf(attrs)) { val aggFunctions = aggExprs.flatMap(_.collect { case agg: AggregateExpression => agg }) @@ -67,7 +67,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic }) } if (isAllDistinctAgg) { - a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, relation))) + a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, rel, filters))) } else { a } @@ -98,14 +98,27 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic */ private def replaceTableScanWithPartitionMetadata( child: LogicalPlan, - relation: LogicalPlan): LogicalPlan = { + relation: LogicalPlan, + partFilters: Seq[Expression]): LogicalPlan = { + // this logic comes from PruneFileSourcePartitions. it ensures that the filter names match the + // relation's schema. PartitionedRelation ensures that the filters only reference partition cols + val relFilters = partFilters.map { e => + e transform { + case a: AttributeReference => + a.withName(relation.output.find(_.semanticEquals(a)).get.name) + } + } + child transform { case plan if plan eq relation => relation match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - val partitionData = fsRelation.location.listFiles(Nil, Nil) - LocalRelation(partAttrs, partitionData.map(_.values), isStreaming) + val partitionData = fsRelation.location.listFiles(relFilters, Nil) + // partition data may be a stream, which can cause serialization to hit stack level too + // deep exceptions because it is a recursive structure in memory. converting to array + // avoids the problem. + LocalRelation(partAttrs, partitionData.map(_.values).toArray, isStreaming) case relation: HiveTableRelation => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) @@ -113,12 +126,21 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic CaseInsensitiveMap(relation.tableMeta.storage.properties) val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(SQLConf.get.sessionLocalTimeZone) - val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p => + val partitions = if (partFilters.nonEmpty) { + catalog.listPartitionsByFilter(relation.tableMeta.identifier, relFilters) + } else { + catalog.listPartitions(relation.tableMeta.identifier) + } + + val partitionData = partitions.map { p => InternalRow.fromSeq(partAttrs.map { attr => Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() }) } - LocalRelation(partAttrs, partitionData) + // partition data may be a stream, which can cause serialization to hit stack level too + // deep exceptions because it is a recursive structure in memory. converting to array + // avoids the problem. + LocalRelation(partAttrs, partitionData.toArray) case _ => throw new IllegalStateException(s"unrecognized table scan node: $relation, " + @@ -129,35 +151,47 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic /** * A pattern that finds the partitioned table relation node inside the given plan, and returns a - * pair of the partition attributes and the table relation node. + * pair of the partition attributes, partition filters, and the table relation node. * * It keeps traversing down the given plan tree if there is a [[Project]] or [[Filter]] with * deterministic expressions, and returns result after reaching the partitioned table relation * node. */ - object PartitionedRelation { - - def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = plan match { - case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) - if fsRelation.partitionSchema.nonEmpty => - val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - Some((AttributeSet(partAttrs), l)) - - case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => - val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) - Some((AttributeSet(partAttrs), relation)) - - case p @ Project(projectList, child) if projectList.forall(_.deterministic) => - unapply(child).flatMap { case (partAttrs, relation) => - if (p.references.subsetOf(partAttrs)) Some((p.outputSet, relation)) else None - } + object PartitionedRelation extends PredicateHelper { + + def unapply( + plan: LogicalPlan): Option[(AttributeSet, AttributeSet, Seq[Expression], LogicalPlan)] = { + plan match { + case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) + if fsRelation.partitionSchema.nonEmpty => + val partAttrs = AttributeSet(getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)) + Some((partAttrs, partAttrs, Nil, l)) + + case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => + val partAttrs = AttributeSet( + getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)) + Some((partAttrs, partAttrs, Nil, relation)) + + case p @ Project(projectList, child) if projectList.forall(_.deterministic) => + unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => + if (p.references.subsetOf(attrs)) { + Some((partAttrs, p.outputSet, filters, relation)) + } else { + None + } + } - case f @ Filter(condition, child) if condition.deterministic => - unapply(child).flatMap { case (partAttrs, relation) => - if (f.references.subsetOf(partAttrs)) Some((partAttrs, relation)) else None - } + case f @ Filter(condition, child) if condition.deterministic => + unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => + if (f.references.subsetOf(partAttrs)) { + Some((partAttrs, attrs, splitConjunctivePredicates(condition) ++ filters, relation)) + } else { + None + } + } - case _ => None + case _ => None + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala new file mode 100644 index 0000000000000..95f192f0e40e2 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleton + with BeforeAndAfter with SQLTestUtils { + + import spark.implicits._ + + before { + sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)") + (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)")) + } + + test("SPARK-23877: validate metadata-only query pushes filters to metastore") { + withTable("metadata_only") { + val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount + + // verify the number of matching partitions + assert(sql("SELECT DISTINCT part FROM metadata_only WHERE part < 5").collect().length === 5) + + // verify that the partition predicate was pushed down to the metastore + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount === 5) + } + } + + test("SPARK-23877: filter on projected expression") { + withTable("metadata_only") { + val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount + + // verify the matching partitions + val partitions = spark.internalCreateDataFrame(Distinct(Filter(($"x" < 5).expr, + Project(Seq(($"part" + 1).as("x").expr.asInstanceOf[NamedExpression]), + spark.table("metadata_only").logicalPlan.asInstanceOf[SubqueryAlias].child))) + .queryExecution.toRdd, StructType(Seq(StructField("x", IntegerType)))) + + checkAnswer(partitions, Seq(1, 2, 3, 4).toDF("x")) + + // verify that the partition predicate was not pushed down to the metastore + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount == 11) + } + } +} From e6b466084c26fbb9b9e50dd5cc8b25da7533ac72 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Fri, 20 Apr 2018 14:58:11 +0900 Subject: [PATCH 658/774] [SPARK-23736][SQL] Extending the concat function to support array columns ## What changes were proposed in this pull request? The PR adds a logic for easy concatenation of multiple array columns and covers: - Concat expression has been extended to support array columns - A Python wrapper ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite - typeCoercion/native/concat.sql ## Codegen examples ### Primitive-type elements ``` val df = Seq( (Seq(1 ,2), Seq(3, 4)), (Seq(1, 2, 3), null) ).toDF("a", "b") df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen() ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ if (!(!inputadapter_isNull)) continue; /* 038 */ /* 039 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 040 */ /* 041 */ ArrayData[] project_args = new ArrayData[2]; /* 042 */ /* 043 */ if (!false) { /* 044 */ project_args[0] = inputadapter_value; /* 045 */ } /* 046 */ /* 047 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1); /* 048 */ ArrayData inputadapter_value1 = inputadapter_isNull1 ? /* 049 */ null : (inputadapter_row.getArray(1)); /* 050 */ if (!inputadapter_isNull1) { /* 051 */ project_args[1] = inputadapter_value1; /* 052 */ } /* 053 */ /* 054 */ ArrayData project_value = new Object() { /* 055 */ public ArrayData concat(ArrayData[] args) { /* 056 */ for (int z = 0; z < 2; z++) { /* 057 */ if (args[z] == null) return null; /* 058 */ } /* 059 */ /* 060 */ long project_numElements = 0L; /* 061 */ for (int z = 0; z < 2; z++) { /* 062 */ project_numElements += args[z].numElements(); /* 063 */ } /* 064 */ if (project_numElements > 2147483632) { /* 065 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements + /* 066 */ " elements due to exceeding the array size limit 2147483632."); /* 067 */ } /* 068 */ /* 069 */ long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( /* 070 */ project_numElements, /* 071 */ 4); /* 072 */ if (project_size > 2147483632) { /* 073 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_size + /* 074 */ " bytes of data due to exceeding the limit 2147483632 bytes" + /* 075 */ " for UnsafeArrayData."); /* 076 */ } /* 077 */ /* 078 */ byte[] project_array = new byte[(int)project_size]; /* 079 */ UnsafeArrayData project_arrayData = new UnsafeArrayData(); /* 080 */ Platform.putLong(project_array, 16, project_numElements); /* 081 */ project_arrayData.pointTo(project_array, 16, (int)project_size); /* 082 */ int project_counter = 0; /* 083 */ for (int y = 0; y < 2; y++) { /* 084 */ for (int z = 0; z < args[y].numElements(); z++) { /* 085 */ if (args[y].isNullAt(z)) { /* 086 */ project_arrayData.setNullAt(project_counter); /* 087 */ } else { /* 088 */ project_arrayData.setInt( /* 089 */ project_counter, /* 090 */ args[y].getInt(z) /* 091 */ ); /* 092 */ } /* 093 */ project_counter++; /* 094 */ } /* 095 */ } /* 096 */ return project_arrayData; /* 097 */ } /* 098 */ }.concat(project_args); /* 099 */ boolean project_isNull = project_value == null; ``` ### Non-primitive-type elements ``` val df = Seq( (Seq("aa" ,"bb"), Seq("ccc", "ddd")), (Seq("x", "y"), null) ).toDF("a", "b") df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen() ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ if (!(!inputadapter_isNull)) continue; /* 038 */ /* 039 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 040 */ /* 041 */ ArrayData[] project_args = new ArrayData[2]; /* 042 */ /* 043 */ if (!false) { /* 044 */ project_args[0] = inputadapter_value; /* 045 */ } /* 046 */ /* 047 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1); /* 048 */ ArrayData inputadapter_value1 = inputadapter_isNull1 ? /* 049 */ null : (inputadapter_row.getArray(1)); /* 050 */ if (!inputadapter_isNull1) { /* 051 */ project_args[1] = inputadapter_value1; /* 052 */ } /* 053 */ /* 054 */ ArrayData project_value = new Object() { /* 055 */ public ArrayData concat(ArrayData[] args) { /* 056 */ for (int z = 0; z < 2; z++) { /* 057 */ if (args[z] == null) return null; /* 058 */ } /* 059 */ /* 060 */ long project_numElements = 0L; /* 061 */ for (int z = 0; z < 2; z++) { /* 062 */ project_numElements += args[z].numElements(); /* 063 */ } /* 064 */ if (project_numElements > 2147483632) { /* 065 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements + /* 066 */ " elements due to exceeding the array size limit 2147483632."); /* 067 */ } /* 068 */ /* 069 */ Object[] project_arrayObjects = new Object[(int)project_numElements]; /* 070 */ int project_counter = 0; /* 071 */ for (int y = 0; y < 2; y++) { /* 072 */ for (int z = 0; z < args[y].numElements(); z++) { /* 073 */ project_arrayObjects[project_counter] = args[y].getUTF8String(z); /* 074 */ project_counter++; /* 075 */ } /* 076 */ } /* 077 */ return new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObjects); /* 078 */ } /* 079 */ }.concat(project_args); /* 080 */ boolean project_isNull = project_value == null; ``` Author: mn-mikke Closes #20858 from mn-mikke/feature/array-api-concat_arrays-to-master. --- .../spark/unsafe/array/ByteArrayMethods.java | 6 +- python/pyspark/sql/functions.py | 34 +-- .../catalyst/expressions/UnsafeArrayData.java | 10 + .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 8 + .../expressions/collectionOperations.scala | 220 +++++++++++++++++- .../expressions/stringExpressions.scala | 81 ------- .../CollectionExpressionsSuite.scala | 41 ++++ .../org/apache/spark/sql/functions.scala | 20 +- .../inputs/typeCoercion/native/concat.sql | 62 +++++ .../typeCoercion/native/concat.sql.out | 78 +++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 74 ++++++ .../sql/execution/command/DDLSuite.scala | 4 +- 13 files changed, 529 insertions(+), 111 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 4bc9955090fd7..ef0f78d95d1ee 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -33,7 +33,11 @@ public static long nextPowerOf2(long num) { } public static int roundNumberOfBytesToNearestWord(int numBytes) { - int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` + return (int)roundNumberOfBytesToNearestWord((long)numBytes); + } + + public static long roundNumberOfBytesToNearestWord(long numBytes) { + long remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` if (remainder == 0) { return numBytes; } else { diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1be68f2a4a448..da32ab25cad0c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1425,21 +1425,6 @@ def hash(*cols): del _name, _doc -@since(1.5) -@ignore_unicode_prefix -def concat(*cols): - """ - Concatenates multiple input columns together into a single column. - If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - - >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) - >>> df.select(concat(df.s, df.d).alias('s')).collect() - [Row(s=u'abcd123')] - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) - - @since(1.5) @ignore_unicode_prefix def concat_ws(sep, *cols): @@ -1845,6 +1830,25 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(1.5) +@ignore_unicode_prefix +def concat(*cols): + """ + Concatenates multiple input columns together into a single column. + The function works with strings, binary and compatible array columns. + + >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df.select(concat(df.s, df.d).alias('s')).collect() + [Row(s=u'abcd123')] + + >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) + >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect() + [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) + + @since(2.4) def array_position(col, value): """ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 8546c28335536..d5d934bc91cab 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -56,9 +56,19 @@ public final class UnsafeArrayData extends ArrayData { public static int calculateHeaderPortionInBytes(int numFields) { + return (int)calculateHeaderPortionInBytes((long)numFields); + } + + public static long calculateHeaderPortionInBytes(long numFields) { return 8 + ((numFields + 63)/ 64) * 8; } + public static long calculateSizeOfUnderlyingByteArray(long numFields, int elementSize) { + long size = UnsafeArrayData.calculateHeaderPortionInBytes(numFields) + + ByteArrayMethods.roundNumberOfBytesToNearestWord(numFields * elementSize); + return size; + } + private Object baseObject; private long baseOffset; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a44f2d5272b8e..c41f16c61d7a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -308,7 +308,6 @@ object FunctionRegistry { expression[BitLength]("bit_length"), expression[Length]("char_length"), expression[Length]("character_length"), - expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), expression[Decode]("decode"), expression[Elt]("elt"), @@ -413,6 +412,7 @@ object FunctionRegistry { expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), expression[Reverse]("reverse"), + expression[Concat]("concat"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 281f206e8d59e..cfcbd8db559a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -520,6 +520,14 @@ object TypeCoercion { case None => a } + case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && + !haveSameType(children) => + val types = children.map(_.dataType) + findWiderCommonType(types) match { + case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) + case None => c + } + case m @ CreateMap(children) if m.keys.length == m.values.length && (!haveSameType(m.keys) || !haveSameType(m.values)) => val newKeys = if (haveSameType(m.keys)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index dba426e999dda..c16793bda028e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -23,7 +23,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.types.{ByteArray, UTF8String} /** * Given an array or map, returns its size. Returns -1 if null. @@ -665,3 +667,219 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti override def prettyName: String = "element_at" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ + Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess + } else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + + s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") + } + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { + case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) + case StringType => + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) + case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { + null + } else { + val arrayData = inputs.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numberOfElements > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + + s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + val finalData = new Array[AnyRef](numberOfElements.toInt) + var position = 0 + for(ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, finalData, position, arr.length) + position += arr.length + } + new GenericArrayData(finalData) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val args = ctx.freshName("args") + + val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" + ${eval.code} + if (!${eval.isNull}) { + $args[$index] = ${eval.value}; + } + """ + } + + val (concatenator, initCode) = dataType match { + case BinaryType => + (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => + ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => + val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrays(ctx, elementType) + } else { + genCodeForNonPrimitiveArrays(ctx, elementType) + } + (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") + } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"$javaType[]", args) :: Nil) + ev.copy(s""" + $initCode + $codes + $javaType ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; + """) + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { + val numElements = ctx.freshName("numElements") + val code = s""" + |long $numElements = 0L; + |for (int z = 0; z < ${children.length}; z++) { + | $numElements += args[z].numElements(); + |} + |if ($numElements > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + + | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + + (code, numElements) + } + + private def nullArgumentProtection() : String = { + if (nullable) { + s""" + |for (int z = 0; z < ${children.length}; z++) { + | if (args[z] == null) return null; + |} + """.stripMargin + } else { + "" + } + } + + private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val arrayName = ctx.freshName("array") + val arraySizeName = ctx.freshName("size") + val counter = ctx.freshName("counter") + val arrayData = ctx.freshName("arrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + + val unsafeArraySizeInBytes = s""" + |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElemName, + | ${elementType.defaultSize}); + |if ($arraySizeName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName + + | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" + + | " for UnsafeArrayData."); + |} + """.stripMargin + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |new Object() { + | public ArrayData concat($javaType[] args) { + | ${nullArgumentProtection()} + | $numElemCode + | $unsafeArraySizeInBytes + | byte[] $arrayName = new byte[(int)$arraySizeName]; + | UnsafeArrayData $arrayData = new UnsafeArrayData(); + | Platform.putLong($arrayName, $baseOffset, $numElemName); + | $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | if (args[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + | } else { + | $arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} + | ); + | } + | $counter++; + | } + | } + | return $arrayData; + | } + |}""".stripMargin.stripPrefix("\n") + } + + private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayData = ctx.freshName("arrayObjects") + val counter = ctx.freshName("counter") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + + s""" + |new Object() { + | public ArrayData concat($javaType[] args) { + | ${nullArgumentProtection()} + | $numElemCode + | Object[] $arrayData = new Object[(int)$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + | } + |}""".stripMargin.stripPrefix("\n") + } + + override def toString: String = s"concat(${children.mkString(", ")})" + + override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5a02ca0d6862c..ea005a26a4c8b 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -36,87 +36,6 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// -/** - * An expression that concatenates multiple inputs into a single output. - * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - * If any input is null, concat returns null. - */ -@ExpressionDescription( - usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN.", - examples = """ - Examples: - > SELECT _FUNC_('Spark', 'SQL'); - SparkSQL - """) -case class Concat(children: Seq[Expression]) extends Expression { - - private lazy val isBinaryMode: Boolean = dataType == BinaryType - - override def checkInputDataTypes(): TypeCheckResult = { - if (children.isEmpty) { - TypeCheckResult.TypeCheckSuccess - } else { - val childTypes = children.map(_.dataType) - if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { - return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have StringType or BinaryType, but it's " + - childTypes.map(_.simpleString).mkString("[", ", ", "]")) - } - TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") - } - } - - override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) - - override def nullable: Boolean = children.exists(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) - - override def eval(input: InternalRow): Any = { - if (isBinaryMode) { - val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) - ByteArray.concat(inputs: _*) - } else { - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs : _*) - } - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val evals = children.map(_.genCode(ctx)) - val args = ctx.freshName("args") - - val inputs = evals.zipWithIndex.map { case (eval, index) => - s""" - ${eval.code} - if (!${eval.isNull}) { - $args[$index] = ${eval.value}; - } - """ - } - - val (concatenator, initCode) = if (isBinaryMode) { - (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") - } else { - ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") - } - val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = inputs, - funcName = "valueConcat", - extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil) - ev.copy(s""" - $initCode - $codes - ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args); - boolean ${ev.isNull} = ${ev.value} == null; - """) - } - - override def toString: String = s"concat(${children.mkString(", ")})" - - override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" -} - - /** * An expression that concatenates multiple input strings or array of strings into a single string, * using a given separator (the first child). diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 7d8fe211858b2..43c5dda2e4a48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -239,4 +239,45 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(m2, Literal("a")), null) } + + test("Concat") { + // Primitive-type elements + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val ai4 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai1, ai0)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai0)), Seq(1, 2, 3, 1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai2)), Seq(1, 2, 3, 4, null, 5)) + checkEvaluation(Concat(Seq(ai0, ai3, ai2)), Seq(1, 2, 3, null, null, 4, null, 5)) + checkEvaluation(Concat(Seq(ai4)), null) + checkEvaluation(Concat(Seq(ai0, ai4)), null) + checkEvaluation(Concat(Seq(ai4, ai0)), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val as1 = Literal.create(Seq.empty[String], ArrayType(StringType)) + val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType)) + val as3 = Literal.create(Seq(null, null), ArrayType(StringType)) + val as4 = Literal.create(null, ArrayType(StringType)) + + val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType))) + val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType))) + + checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as1, as0)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as0)), Seq("a", "b", "c", "a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as2)), Seq("a", "b", "c", "d", null, "e")) + checkEvaluation(Concat(Seq(as0, as3, as2)), Seq("a", "b", "c", null, null, "d", null, "e")) + checkEvaluation(Concat(Seq(as4)), null) + checkEvaluation(Concat(Seq(as0, as4)), null) + checkEvaluation(Concat(Seq(as4, as0)), null) + + checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9c8580378303e..bea8c0e445002 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2228,16 +2228,6 @@ object functions { */ def base64(e: Column): Column = withExpr { Base64(e.expr) } - /** - * Concatenates multiple input columns together into a single column. - * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - * - * @group string_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } - /** * Concatenates multiple input string columns together into a single string column, * using the given separator. @@ -3038,6 +3028,16 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + * + * @group collection_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } + /** * Locates the position of the first occurrence of the value in the given array as long. * Returns null if either of the arguments are null. diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql index 0beebec5702fd..db00a18f2e7e9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql @@ -91,3 +91,65 @@ FROM ( encode(string(id + 3), 'utf-8') col4 FROM range(10) ); + +CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES ( + array(true, false), array(true), + array(2Y, 1Y), array(3Y, 4Y), + array(2S, 1S), array(3S, 4S), + array(2, 1), array(3, 4), + array(2L, 1L), array(3L, 4L), + array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809), + array(2.0D, 1.0D), array(3.0D, 4.0D), + array(float(2.0), float(1.0)), array(float(3.0), float(4.0)), + array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + array(timestamp '2016-11-11 20:54:00.000'), + array('a', 'b'), array('c', 'd'), + array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')), + array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)), + array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4)) +) AS various_arrays( + boolean_array1, boolean_array2, + tinyint_array1, tinyint_array2, + smallint_array1, smallint_array2, + int_array1, int_array2, + bigint_array1, bigint_array2, + decimal_array1, decimal_array2, + double_array1, double_array2, + float_array1, float_array2, + date_array1, data_array2, + timestamp_array1, timestamp_array2, + string_array1, string_array2, + array_array1, array_array2, + struct_array1, struct_array2, + map_array1, map_array2 +); + +-- Concatenate arrays of the same type +SELECT + (boolean_array1 || boolean_array2) boolean_array, + (tinyint_array1 || tinyint_array2) tinyint_array, + (smallint_array1 || smallint_array2) smallint_array, + (int_array1 || int_array2) int_array, + (bigint_array1 || bigint_array2) bigint_array, + (decimal_array1 || decimal_array2) decimal_array, + (double_array1 || double_array2) double_array, + (float_array1 || float_array2) float_array, + (date_array1 || data_array2) data_array, + (timestamp_array1 || timestamp_array2) timestamp_array, + (string_array1 || string_array2) string_array, + (array_array1 || array_array2) array_array, + (struct_array1 || struct_array2) struct_array, + (map_array1 || map_array2) map_array +FROM various_arrays; + +-- Concatenate arrays of different types +SELECT + (tinyint_array1 || smallint_array2) ts_array, + (smallint_array1 || int_array2) si_array, + (int_array1 || bigint_array2) ib_array, + (double_array1 || float_array2) df_array, + (string_array1 || data_array2) std_array, + (timestamp_array1 || string_array2) tst_array, + (string_array1 || int_array2) sti_array +FROM various_arrays; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index 09729fdc2ec32..62befc5ca0f15 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -237,3 +237,81 @@ struct 78910 891011 9101112 + + +-- !query 11 +CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES ( + array(true, false), array(true), + array(2Y, 1Y), array(3Y, 4Y), + array(2S, 1S), array(3S, 4S), + array(2, 1), array(3, 4), + array(2L, 1L), array(3L, 4L), + array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809), + array(2.0D, 1.0D), array(3.0D, 4.0D), + array(float(2.0), float(1.0)), array(float(3.0), float(4.0)), + array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + array(timestamp '2016-11-11 20:54:00.000'), + array('a', 'b'), array('c', 'd'), + array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')), + array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)), + array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4)) +) AS various_arrays( + boolean_array1, boolean_array2, + tinyint_array1, tinyint_array2, + smallint_array1, smallint_array2, + int_array1, int_array2, + bigint_array1, bigint_array2, + decimal_array1, decimal_array2, + double_array1, double_array2, + float_array1, float_array2, + date_array1, data_array2, + timestamp_array1, timestamp_array2, + string_array1, string_array2, + array_array1, array_array2, + struct_array1, struct_array2, + map_array1, map_array2 +) +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +SELECT + (boolean_array1 || boolean_array2) boolean_array, + (tinyint_array1 || tinyint_array2) tinyint_array, + (smallint_array1 || smallint_array2) smallint_array, + (int_array1 || int_array2) int_array, + (bigint_array1 || bigint_array2) bigint_array, + (decimal_array1 || decimal_array2) decimal_array, + (double_array1 || double_array2) double_array, + (float_array1 || float_array2) float_array, + (date_array1 || data_array2) data_array, + (timestamp_array1 || timestamp_array2) timestamp_array, + (string_array1 || string_array2) string_array, + (array_array1 || array_array2) array_array, + (struct_array1 || struct_array2) struct_array, + (map_array1 || map_array2) map_array +FROM various_arrays +-- !query 12 schema +struct,tinyint_array:array,smallint_array:array,int_array:array,bigint_array:array,decimal_array:array,double_array:array,float_array:array,data_array:array,timestamp_array:array,string_array:array,array_array:array>,struct_array:array>,map_array:array>> +-- !query 12 output +[true,false,true] [2,1,3,4] [2,1,3,4] [2,1,3,4] [2,1,3,4] [9223372036854775809,9223372036854775808,9223372036854775808,9223372036854775809] [2.0,1.0,3.0,4.0] [2.0,1.0,3.0,4.0] [2016-03-14,2016-03-13,2016-03-12,2016-03-11] [2016-11-15 20:54:00.0,2016-11-12 20:54:00.0,2016-11-11 20:54:00.0] ["a","b","c","d"] [["a","b"],["c","d"],["e"],["f"]] [{"col1":"a","col2":1},{"col1":"b","col2":2},{"col1":"c","col2":3},{"col1":"d","col2":4}] [{"a":1},{"b":2},{"c":3},{"d":4}] + + +-- !query 13 +SELECT + (tinyint_array1 || smallint_array2) ts_array, + (smallint_array1 || int_array2) si_array, + (int_array1 || bigint_array2) ib_array, + (double_array1 || float_array2) df_array, + (string_array1 || data_array2) std_array, + (timestamp_array1 || string_array2) tst_array, + (string_array1 || int_array2) sti_array +FROM various_arrays +-- !query 13 schema +struct,si_array:array,ib_array:array,df_array:array,std_array:array,tst_array:array,sti_array:array> +-- !query 13 output +[2,1,3,4] [2,1,3,4] [2,1,3,4] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7c976c1b7f915..25e5cd60dd236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -617,6 +617,80 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("concat function - arrays") { + val nseqi : Seq[Int] = null + val nseqs : Seq[String] = null + val df = Seq( + + (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), + (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) + ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn") + + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on + + // Simple test cases + checkAnswer( + df.selectExpr("array(1, 2, 3L)"), + Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L))) + ) + + checkAnswer ( + df.select(concat($"i1", $"s1")), + Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) + ) + checkAnswer( + df.select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.selectExpr("concat(array(1, null), i2, i3)"), + Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) + ) + checkAnswer( + df.select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.selectExpr("concat(s1, s2, s3)"), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.filter(dummyFilter($"s1"))select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + + // Null test cases + checkAnswer( + df.select(concat($"i1", $"in")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"in", $"i1")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"s1", $"sn")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"sn", $"s1")), + Seq(Row(null), Row(null)) + ) + + // Type error test cases + intercept[AnalysisException] { + df.selectExpr("concat(i1, i2, null)") + } + + intercept[AnalysisException] { + df.selectExpr("concat(i1, array(i1, i2))") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index cbd7f9d6f67be..3998ceca38b30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1742,8 +1742,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("DESCRIBE FUNCTION 'concat'"), Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") :: Row("Function: concat") :: - Row("Usage: concat(str1, str2, ..., strN) - " + - "Returns the concatenation of str1, str2, ..., strN.") :: Nil + Row("Usage: concat(col1, col2, ..., colN) - " + + "Returns the concatenation of col1, col2, ..., colN.") :: Nil ) // extended mode checkAnswer( From 074a7f90536493b607e8e74bcebf3a27ea49a49d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 20 Apr 2018 14:43:47 +0200 Subject: [PATCH 659/774] [SPARK-23588][SQL][FOLLOW-UP] Resolve a map builder method per execution in CatalystToExternalMap ## What changes were proposed in this pull request? This pr is a follow-up pr of #20979 and fixes code to resolve a map builder method per execution instead of per row in `CatalystToExternalMap`. ## How was this patch tested? Existing tests. Author: Takeshi Yamamuro Closes #21112 from maropu/SPARK-23588-FOLLOWUP. --- .../sql/catalyst/expressions/objects/objects.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index bc17d1229420a..32c1f34ef97a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1040,11 +1040,13 @@ case class CatalystToExternalMap private( private lazy val valueConverter = CatalystTypeConverters.createToScalaConverter(inputMapType.valueType) - private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + private lazy val (newMapBuilderMethod, moduleField) = { val clazz = Utils.classForName(collClass.getCanonicalName + "$") - val module = clazz.getField("MODULE$").get(null) - val method = clazz.getMethod("newBuilder") - method.invoke(module).asInstanceOf[Builder[AnyRef, AnyRef]] + (clazz.getMethod("newBuilder"), clazz.getField("MODULE$").get(null)) + } + + private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + newMapBuilderMethod.invoke(moduleField).asInstanceOf[Builder[AnyRef, AnyRef]] } override def eval(input: InternalRow): Any = { From 0dd97f6ea4affde1531dec1bec004b7ab18c6965 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 20 Apr 2018 15:02:27 +0200 Subject: [PATCH 660/774] [SPARK-23595][SQL] ValidateExternalType should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `ValidateExternalType`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20757 from maropu/SPARK-23595. --- .../spark/sql/catalyst/ScalaReflection.scala | 13 +++++++ .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../expressions/objects/objects.scala | 34 ++++++++++++++++--- .../expressions/ObjectExpressionsSuite.scala | 33 ++++++++++++++++-- 4 files changed, 74 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 818cc2fb1e8a8..f9acc208b715e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -846,6 +846,19 @@ object ScalaReflection extends ScalaReflection { } } + def javaBoxedType(dt: DataType): Class[_] = dt match { + case _: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayType] + case _: MapType => classOf[MapType] + case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType) + case ObjectType(cls) => cls + case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { if (arguments != Nil) { arguments.map(e => dataTypeJavaClass(e.dataType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 789750fd408f2..3340789398f9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 32c1f34ef97a5..f1ffcaec8a484 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.Utils /** @@ -1672,13 +1673,36 @@ case class ValidateExternalType(child: Expression, expected: DataType) override def nullable: Boolean = child.nullable - override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected) - - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected) private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + private lazy val checkType: (Any) => Boolean = expected match { + case _: DecimalType => + (value: Any) => { + value.isInstanceOf[java.math.BigDecimal] || value.isInstanceOf[scala.math.BigDecimal] || + value.isInstanceOf[Decimal] + } + case _: ArrayType => + (value: Any) => { + value.getClass.isArray || value.isInstanceOf[Seq[_]] + } + case _ => + val dataTypeClazz = ScalaReflection.javaBoxedType(dataType) + (value: Any) => { + dataTypeClazz.isInstance(value) + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (checkType(result)) { + result + } else { + throw new RuntimeException(s"${result.getClass.getName}$errMsg") + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the type doesn't match. @@ -1691,7 +1715,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") case _: ArrayType => - s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" + s"$obj.getClass().isArray() || $obj instanceof ${classOf[Seq[_]].getName}" case _ => s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index bcd035c1eba0b..7136af8934486 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class InvokeTargetClass extends Serializable { def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 @@ -296,7 +296,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) => - checkEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) + checkObjectExprEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) } // If an input row or a field are null, a runtime exception will be thrown @@ -472,6 +472,35 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val deserializer = toMapExpr.copy(inputData = Literal.create(data)) checkObjectExprEvaluation(deserializer, expected = data) } + + test("SPARK-23595 ValidateExternalType should support interpreted execution") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + Seq( + (true, BooleanType), + (2.toByte, ByteType), + (5.toShort, ShortType), + (23, IntegerType), + (61L, LongType), + (1.0f, FloatType), + (10.0, DoubleType), + ("abcd".getBytes, BinaryType), + ("abcd", StringType), + (BigDecimal.valueOf(10), DecimalType.IntDecimal), + (CalendarInterval.fromString("interval 3 day"), CalendarIntervalType), + (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal), + (Array(3, 2, 1), ArrayType(IntegerType)) + ).foreach { case (input, dt) => + val validateType = ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt) + checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input)))) + } + + checkExceptionInExpression[RuntimeException]( + ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), DoubleType), + InternalRow.fromSeq(Seq(Row(1))), + "java.lang.Integer is not a valid external type for schema of double") + } } class TestBean extends Serializable { From 1d758dc73b54e802fdc92be204185fe7414e6553 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 20 Apr 2018 10:23:01 -0700 Subject: [PATCH 661/774] Revert "[SPARK-23775][TEST] Make DataFrameRangeSuite not flaky" This reverts commit 0c94e48bc50717e1627c0d2acd5382d9adc73c97. --- .../spark/sql/DataFrameRangeSuite.scala | 78 ++++++++----------- 1 file changed, 33 insertions(+), 45 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index a0fd74088ce8b..57a930dfaf320 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.sql -import java.util.concurrent.{CountDownLatch, TimeUnit} - import scala.concurrent.duration._ import scala.math.abs import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -154,53 +152,39 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } test("Cancelling stage in a query with Range.") { - // Save and restore the value because SparkContext is shared - val savedInterruptOnCancel = sparkContext - .getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) - - try { - sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") - - for (codegen <- Seq(true, false)) { - // This countdown latch used to make sure with all the stages cancelStage called in listener - val latch = new CountDownLatch(2) - - val listener = new SparkListener { - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { - sparkContext.cancelStage(taskStart.stageId) - latch.countDown() - } + val listener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + eventually(timeout(10.seconds), interval(1.millis)) { + assert(DataFrameRangeSuite.stageToKill > 0) } + sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) + } + } - sparkContext.addSparkListener(listener) - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - val ex = intercept[SparkException] { - sparkContext.range(0, 10000L, numSlices = 10).mapPartitions { x => - x.synchronized { - x.wait() - } - x - }.toDF("id").agg(sum("id")).collect() - } - ex.getCause() match { - case null => - assert(ex.getMessage().contains("cancelled")) - case cause: SparkException => - assert(cause.getMessage().contains("cancelled")) - case cause: Throwable => - fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") - } + sparkContext.addSparkListener(listener) + for (codegen <- Seq(true, false)) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { + DataFrameRangeSuite.stageToKill = -1 + val ex = intercept[SparkException] { + spark.range(0, 100000000000L, 1, 1).map { x => + DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() + x + }.toDF("id").agg(sum("id")).collect() } - latch.await(20, TimeUnit.SECONDS) - eventually(timeout(20.seconds)) { - assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + ex.getCause() match { + case null => + assert(ex.getMessage().contains("cancelled")) + case cause: SparkException => + assert(cause.getMessage().contains("cancelled")) + case cause: Throwable => + fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } - sparkContext.removeSparkListener(listener) } - } finally { - sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, - savedInterruptOnCancel) + eventually(timeout(20.seconds)) { + assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } } + sparkContext.removeSparkListener(listener) } test("SPARK-20430 Initialize Range parameters in a driver side") { @@ -220,3 +204,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } + +object DataFrameRangeSuite { + @volatile var stageToKill = -1 +} From 32b4bcd6d31b92b179a15f9886779fc5f96404b5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 21 Apr 2018 23:14:58 +0800 Subject: [PATCH 662/774] [SPARK-24029][CORE] Set SO_REUSEADDR on listen sockets. This allows sockets to be bound even if there are sockets from a previous application that are still pending closure. It avoids bind issues when, for example, re-starting the SHS. Don't enable the option on Windows though. The following page explains some odd behavior that this option can have there: https://msdn.microsoft.com/en-us/library/windows/desktop/ms740621%28v=vs.85%29.aspx I intentionally ignored server sockets that always bind to ephemeral ports, since those don't benefit from this option. Author: Marcelo Vanzin Closes #21110 from vanzin/SPARK-24029. --- .../java/org/apache/spark/network/server/TransportServer.java | 4 +++- .../org/apache/spark/deploy/rest/RestSubmissionServer.scala | 1 + core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 0719fa7647bcc..612750972c4bb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -32,6 +32,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; +import org.apache.commons.lang3.SystemUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -98,7 +99,8 @@ private void init(String hostToBind, int portToBind) { .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) .option(ChannelOption.ALLOCATOR, allocator) - .childOption(ChannelOption.ALLOCATOR, allocator); + .childOption(ChannelOption.ALLOCATOR, allocator) + .childOption(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS); this.metrics = new NettyMemoryMetrics( allocator, conf.getModuleName() + "-server", conf); diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index e88195d95f270..3d99d085408c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -94,6 +94,7 @@ private[spark] abstract class RestSubmissionServer( new HttpConnectionFactory()) connector.setHost(host) connector.setPort(startPort) + connector.setReuseAddress(!Utils.isWindows) server.addConnector(connector) val mainHandler = new ServletContextHandler diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 0e8a6307de6a8..d6a025a6f12da 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -344,6 +344,7 @@ private[spark] object JettyUtils extends Logging { connectionFactories: _*) connector.setPort(port) connector.setHost(hostName) + connector.setReuseAddress(!Utils.isWindows) // Currently we only use "SelectChannelConnector" // Limit the max acceptor number to 8 so that we don't waste a lot of threads From 7bc853d08973a6bd839ad2222911eb0a0f413677 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 21 Apr 2018 10:45:12 -0700 Subject: [PATCH 663/774] [SPARK-24033][SQL] Fix Mismatched of Window Frame specifiedwindowframe(RowFrame, -1, -1) ## What changes were proposed in this pull request? When the OffsetWindowFunction's frame is `UnaryMinus(Literal(1))` but the specified window frame has been simplified to `Literal(-1)` by some optimizer rules e.g., `ConstantFolding`. Thus, they do not match and cause the following error: ``` org.apache.spark.sql.AnalysisException: Window Frame specifiedwindowframe(RowFrame, -1, -1) must match the required frame specifiedwindowframe(RowFrame, -1, -1); at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.failAnalysis(CheckAnalysis.scala:41) at org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:91) at ``` ## How was this patch tested? Added a test Author: gatorsmile Closes #21115 from gatorsmile/fixLag. --- .../catalyst/expressions/windowExpressions.scala | 5 ++++- .../spark/sql/DataFrameWindowFramesSuite.scala | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 78895f1c2f6f5..9fe2fb2b95e4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -342,7 +342,10 @@ abstract class OffsetWindowFunction override lazy val frame: WindowFrame = { val boundary = direction match { case Ascending => offset - case Descending => UnaryMinus(offset) + case Descending => UnaryMinus(offset) match { + case e: Expression if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) + case o => o + } } SpecifiedWindowFrame(RowFrame, boundary, boundary) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index 0ee9b0edc02b2..2a0b2b85e10a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -402,4 +402,18 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSQLContext { Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: Row(10, 6000) :: Nil) } + + test("SPARK-24033: Analysis Failure of OffsetWindowFunction") { + val ds = Seq((1, 1), (1, 2), (1, 3), (2, 1), (2, 2)).toDF("n", "i") + val res = + Row(1, 1, null) :: Row (1, 2, 1) :: Row(1, 3, 2) :: Row(2, 1, null) :: Row(2, 2, 1) :: Nil + checkAnswer( + ds.withColumn("m", + lead("i", -1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), + res) + checkAnswer( + ds.withColumn("m", + lag("i", 1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), + res) + } } From c48085aa91c60615a4de3b391f019f46f3fcdbe3 Mon Sep 17 00:00:00 2001 From: Mykhailo Shtelma Date: Sat, 21 Apr 2018 23:33:57 -0700 Subject: [PATCH 664/774] [SPARK-23799][SQL] FilterEstimation.evaluateInSet produces devision by zero in a case of empty table with analyzed statistics >What changes were proposed in this pull request? During evaluation of IN conditions, if the source data frame, is represented by a plan, that uses hive table with columns, which were previously analysed, and the plan has conditions for these fields, that cannot be satisfied (which leads us to an empty data frame), FilterEstimation.evaluateInSet method produces NumberFormatException and ClassCastException. In order to fix this bug, method FilterEstimation.evaluateInSet at first checks, if distinct count is not zero, and also checks if colStat.min and colStat.max are defined, and only in this case proceeds with the calculation. If at least one of the conditions is not satisfied, zero is returned. >How was this patch tested? In order to test the PR two tests were implemented: one in FilterEstimationSuite, that tests the plan with the statistics that violates the conditions mentioned above, and another one in StatisticsCollectionSuite, that test the whole process of analysis/optimisation of the query, that leads to the problems, mentioned in the first section. Author: Mykhailo Shtelma Author: smikesh Closes #21052 from mshtelma/filter_estimation_evaluateInSet_Bugs. --- .../statsEstimation/FilterEstimation.scala | 4 +++ .../FilterEstimationSuite.scala | 11 ++++++++ .../spark/sql/StatisticsCollectionSuite.scala | 28 +++++++++++++++++++ 3 files changed, 43 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 0538c9d88584b..263c9ba60d145 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -392,6 +392,10 @@ case class FilterEstimation(plan: Filter) extends Logging { val dataType = attr.dataType var newNdv = ndv + if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { + return Some(0.0) + } + // use [min, max] to filter the original hSet dataType match { case _: NumericType | BooleanType | DateType | TimestampType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 43440d51dede6..16cb5d032cf57 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -357,6 +357,17 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 3) } + test("evaluateInSet with all zeros") { + validateEstimatedStats( + Filter(InSet(attrString, Set(3, 4, 5)), + StatsTestPlan(Seq(attrString), 0, + AttributeMap(Seq(attrString -> + ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(0), maxLen = Some(0)))))), + Seq(attrString -> ColumnStat(distinctCount = Some(0))), + expectedRowCount = 0) + } + test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 14a565863d66c..b91712f4cc25d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -382,4 +382,32 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } } + + test("Simple queries must be working, if CBO is turned on") { + withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { + withTable("TBL1", "TBL") { + import org.apache.spark.sql.functions._ + val df = spark.range(1000L).select('id, + 'id * 2 as "FLD1", + 'id * 12 as "FLD2", + lit("aaa") + 'id as "fld3") + df.write + .mode(SaveMode.Overwrite) + .bucketBy(10, "id", "FLD1", "FLD2") + .sortBy("id", "FLD1", "FLD2") + .saveAsTable("TBL") + sql("ANALYZE TABLE TBL COMPUTE STATISTICS ") + sql("ANALYZE TABLE TBL COMPUTE STATISTICS FOR COLUMNS ID, FLD1, FLD2, FLD3") + val df2 = spark.sql( + """ + |SELECT t1.id, t1.fld1, t1.fld2, t1.fld3 + |FROM tbl t1 + |JOIN tbl t2 on t1.id=t2.id + |WHERE t1.fld3 IN (-123.23,321.23) + """.stripMargin) + df2.createTempView("TBL2") + sql("SELECT * FROM tbl2 WHERE fld3 IN ('qqq', 'qwe') ").queryExecution.executedPlan + } + } + } } From c3a86faa53c9e49efd595802adc38a6d412ce681 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 23 Apr 2018 10:45:25 +0800 Subject: [PATCH 665/774] [SPARK-10399][SPARK-23879][FOLLOWUP][CORE] Free unused off-heap memory in MemoryBlockSuite ## What changes were proposed in this pull request? As viirya pointed out [here](https://github.com/apache/spark/pull/19222#discussion_r179910484), this PR explicitly frees unused off-heap memory in `MemoryBlockSuite` ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #21117 from kiszk/SPARK-10399-free-offheap. --- .../java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java index 5d5fdc1c55a75..ef5ff8ee70ec0 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -120,6 +120,8 @@ private void check(MemoryBlock memory, Object obj, long offset, int length) { } catch (Exception expected) { Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); } + + memory.setPageNumber(MemoryBlock.NO_PAGE_NUMBER); } @Test @@ -165,11 +167,13 @@ public void testOffHeapArrayMemoryBlock() { int length = 56; check(memory, obj, offset, length); + memoryAllocator.free(memory); long address = Platform.allocateMemory(112); memory = new OffHeapMemoryBlock(address, length); obj = memory.getBaseObject(); offset = memory.getBaseOffset(); check(memory, obj, offset, length); + Platform.freeMemory(address); } } From f70f46d1e5bc503e9071707d837df618b7696d32 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Apr 2018 20:18:50 +0800 Subject: [PATCH 666/774] [SPARK-23877][SQL][FOLLOWUP] use PhysicalOperation to simplify the handling of Project and Filter over partitioned relation ## What changes were proposed in this pull request? A followup of https://github.com/apache/spark/pull/20988 `PhysicalOperation` can collect Project and Filters over a certain plan and substitute the alias with the original attributes in the bottom plan. We can use it in `OptimizeMetadataOnlyQuery` rule to handle the Project and Filter over partitioned relation. ## How was this patch tested? existing test Author: Wenchen Fan Closes #21111 from cloud-fan/refactor. --- .../plans/logical/LocalRelation.scala | 6 ++ .../sql/execution/LocalTableScanExec.scala | 3 + .../execution/OptimizeMetadataOnlyQuery.scala | 58 ++++++------------- .../OptimizeHiveMetadataOnlyQuerySuite.scala | 16 ++++- 4 files changed, 39 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index b05508db786ad..720d42ab409a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -43,6 +43,12 @@ object LocalRelation { } } +/** + * Logical plan node for scanning data from a local collection. + * + * @param data The local collection holding the data. It doesn't need to be sent to executors + * and then doesn't need to be serializable. + */ case class LocalRelation( output: Seq[Attribute], data: Seq[InternalRow] = Nil, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 514ad7018d8c7..448eb703eacde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -25,6 +25,9 @@ import org.apache.spark.sql.execution.metric.SQLMetrics /** * Physical plan node for scanning data from a local collection. + * + * `Seq` may not be serializable and ideally we should not send `rows` and `unsafeRows` + * to the executors. Thus marking them as transient. */ case class LocalTableScanExec( output: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index acbd4becb8549..3ca03ab2939aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} @@ -49,9 +50,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic } plan.transform { - case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(_, attrs, filters, rel)) => + case a @ Aggregate(_, aggExprs, child @ PhysicalOperation( + projectList, filters, PartitionedRelation(partAttrs, rel))) => // We only apply this optimization when only partitioned attributes are scanned. - if (a.references.subsetOf(attrs)) { + if (AttributeSet((projectList ++ filters).flatMap(_.references)).subsetOf(partAttrs)) { + // The project list and filters all only refer to partition attributes, which means the + // the Aggregator operator can also only refer to partition attributes, and filters are + // all partition filters. This is a metadata only query we can optimize. val aggFunctions = aggExprs.flatMap(_.collect { case agg: AggregateExpression => agg }) @@ -102,7 +107,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic partFilters: Seq[Expression]): LogicalPlan = { // this logic comes from PruneFileSourcePartitions. it ensures that the filter names match the // relation's schema. PartitionedRelation ensures that the filters only reference partition cols - val relFilters = partFilters.map { e => + val normalizedFilters = partFilters.map { e => e transform { case a: AttributeReference => a.withName(relation.output.find(_.semanticEquals(a)).get.name) @@ -114,11 +119,8 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic relation match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - val partitionData = fsRelation.location.listFiles(relFilters, Nil) - // partition data may be a stream, which can cause serialization to hit stack level too - // deep exceptions because it is a recursive structure in memory. converting to array - // avoids the problem. - LocalRelation(partAttrs, partitionData.map(_.values).toArray, isStreaming) + val partitionData = fsRelation.location.listFiles(normalizedFilters, Nil) + LocalRelation(partAttrs, partitionData.map(_.values), isStreaming) case relation: HiveTableRelation => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) @@ -127,7 +129,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(SQLConf.get.sessionLocalTimeZone) val partitions = if (partFilters.nonEmpty) { - catalog.listPartitionsByFilter(relation.tableMeta.identifier, relFilters) + catalog.listPartitionsByFilter(relation.tableMeta.identifier, normalizedFilters) } else { catalog.listPartitions(relation.tableMeta.identifier) } @@ -137,10 +139,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() }) } - // partition data may be a stream, which can cause serialization to hit stack level too - // deep exceptions because it is a recursive structure in memory. converting to array - // avoids the problem. - LocalRelation(partAttrs, partitionData.toArray) + LocalRelation(partAttrs, partitionData) case _ => throw new IllegalStateException(s"unrecognized table scan node: $relation, " + @@ -151,44 +150,21 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic /** * A pattern that finds the partitioned table relation node inside the given plan, and returns a - * pair of the partition attributes, partition filters, and the table relation node. - * - * It keeps traversing down the given plan tree if there is a [[Project]] or [[Filter]] with - * deterministic expressions, and returns result after reaching the partitioned table relation - * node. + * pair of the partition attributes and the table relation node. */ object PartitionedRelation extends PredicateHelper { - def unapply( - plan: LogicalPlan): Option[(AttributeSet, AttributeSet, Seq[Expression], LogicalPlan)] = { + def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = { plan match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) - if fsRelation.partitionSchema.nonEmpty => + if fsRelation.partitionSchema.nonEmpty => val partAttrs = AttributeSet(getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)) - Some((partAttrs, partAttrs, Nil, l)) + Some((partAttrs, l)) case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => val partAttrs = AttributeSet( getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)) - Some((partAttrs, partAttrs, Nil, relation)) - - case p @ Project(projectList, child) if projectList.forall(_.deterministic) => - unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => - if (p.references.subsetOf(attrs)) { - Some((partAttrs, p.outputSet, filters, relation)) - } else { - None - } - } - - case f @ Filter(condition, child) if condition.deterministic => - unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => - if (f.references.subsetOf(partAttrs)) { - Some((partAttrs, attrs, splitConjunctivePredicates(condition) ++ filters, relation)) - } else { - None - } - } + Some((partAttrs, relation)) case _ => None } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala index 95f192f0e40e2..1e525c46a9cfb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -32,13 +33,22 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto import spark.implicits._ - before { + override def beforeAll(): Unit = { + super.beforeAll() sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)") (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)")) } + override protected def afterAll(): Unit = { + try { + sql("DROP TABLE IF EXISTS metadata_only") + } finally { + super.afterAll() + } + } + test("SPARK-23877: validate metadata-only query pushes filters to metastore") { - withTable("metadata_only") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the number of matching partitions @@ -50,7 +60,7 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto } test("SPARK-23877: filter on projected expression") { - withTable("metadata_only") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the matching partitions From d87d30e4fe9c9e91c462351e9f744a830db8d6fc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Apr 2018 20:21:01 +0800 Subject: [PATCH 667/774] [SPARK-23564][SQL] infer additional filters from constraints for join's children ## What changes were proposed in this pull request? The existing query constraints framework has 2 steps: 1. propagate constraints bottom up. 2. use constraints to infer additional filters for better data pruning. For step 2, it mostly helps with Join, because we can connect the constraints from children to the join condition and infer powerful filters to prune the data of the join sides. e.g., the left side has constraints `a = 1`, the join condition is `left.a = right.a`, then we can infer `right.a = 1` to the right side and prune the right side a lot. However, the current logic of inferring filters from constraints for Join is pretty weak. It infers the filters from Join's constraints. Some joins like left semi/anti exclude output from right side and the right side constraints will be lost here. This PR propose to check the left and right constraints individually, expand the constraints with join condition and add filters to children of join directly, instead of adding to the join condition. This reverts https://github.com/apache/spark/pull/20670 , covers https://github.com/apache/spark/pull/20717 and https://github.com/apache/spark/pull/20816 This is inspired by the original PRs and the tests are all from these PRs. Thanks to the authors mgaido91 maryannxue KaiXinXiaoLei ! ## How was this patch tested? new tests Author: Wenchen Fan Closes #21083 from cloud-fan/join. --- .../sql/catalyst/optimizer/Optimizer.scala | 97 +++++++++---------- .../plans/logical/QueryPlanConstraints.scala | 95 ++++++++---------- .../InferFiltersFromConstraintsSuite.scala | 53 +++++++--- 3 files changed, 124 insertions(+), 121 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 913354e4df0e6..f00d40d11f23f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -637,13 +637,11 @@ object CollapseWindow extends Rule[LogicalPlan] { * constraints. These filters are currently inserted to the existing conditions in the Filter * operators and on either side of Join operators. * - * In addition, for left/right outer joins, infer predicate from the preserved side of the Join - * operator and push the inferred filter over to the null-supplying side. For example, if the - * preserved side has constraints of the form 'a > 5' and the join condition is 'a = b', in - * which 'b' is an attribute from the null-supplying side, a [[Filter]] operator of 'b > 5' will - * be applied to the null-supplying side. + * Note: While this optimization is applicable to a lot of types of join, it primarily benefits + * Inner and LeftSemi joins. */ -object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { +object InferFiltersFromConstraints extends Rule[LogicalPlan] + with PredicateHelper with ConstraintHelper { def apply(plan: LogicalPlan): LogicalPlan = { if (SQLConf.get.constraintPropagationEnabled) { @@ -664,53 +662,52 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe } case join @ Join(left, right, joinType, conditionOpt) => - // Only consider constraints that can be pushed down completely to either the left or the - // right child - val constraints = join.allConstraints.filter { c => - c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet) - } - // Remove those constraints that are already enforced by either the left or the right child - val additionalConstraints = constraints -- (left.constraints ++ right.constraints) - val newConditionOpt = conditionOpt match { - case Some(condition) => - val newFilters = additionalConstraints -- splitConjunctivePredicates(condition) - if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else conditionOpt - case None => - additionalConstraints.reduceOption(And) - } - // Infer filter for left/right outer joins - val newLeftOpt = joinType match { - case RightOuter if newConditionOpt.isDefined => - val inferredConstraints = left.getRelevantConstraints( - left.constraints - .union(right.constraints) - .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) - val newFilters = inferredConstraints - .filterNot(left.constraints.contains) - .reduceLeftOption(And) - newFilters.map(Filter(_, left)) - case _ => None - } - val newRightOpt = joinType match { - case LeftOuter if newConditionOpt.isDefined => - val inferredConstraints = right.getRelevantConstraints( - right.constraints - .union(left.constraints) - .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) - val newFilters = inferredConstraints - .filterNot(right.constraints.contains) - .reduceLeftOption(And) - newFilters.map(Filter(_, right)) - case _ => None - } + joinType match { + // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an + // inner join, it just drops the right side in the final output. + case _: InnerLike | LeftSemi => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newLeft = inferNewFilter(left, allConstraints) + val newRight = inferNewFilter(right, allConstraints) + join.copy(left = newLeft, right = newRight) - if ((newConditionOpt.isDefined && (newConditionOpt ne conditionOpt)) - || newLeftOpt.isDefined || newRightOpt.isDefined) { - Join(newLeftOpt.getOrElse(left), newRightOpt.getOrElse(right), joinType, newConditionOpt) - } else { - join + // For right outer join, we can only infer additional filters for left side. + case RightOuter => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newLeft = inferNewFilter(left, allConstraints) + join.copy(left = newLeft) + + // For left join, we can only infer additional filters for right side. + case LeftOuter | LeftAnti => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newRight = inferNewFilter(right, allConstraints) + join.copy(right = newRight) + + case _ => join } } + + private def getAllConstraints( + left: LogicalPlan, + right: LogicalPlan, + conditionOpt: Option[Expression]): Set[Expression] = { + val baseConstraints = left.constraints.union(right.constraints) + .union(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet) + baseConstraints.union(inferAdditionalConstraints(baseConstraints)) + } + + private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = { + val newPredicates = constraints + .union(constructIsNotNullConstraints(constraints, plan.output)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic + } -- plan.constraints + if (newPredicates.isEmpty) { + plan + } else { + Filter(newPredicates.reduce(And), plan) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index a29f3d29236c7..cc352c59dff80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -20,29 +20,28 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -trait QueryPlanConstraints { self: LogicalPlan => +trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan => /** - * An [[ExpressionSet]] that contains an additional set of constraints, such as equality - * constraints and `isNotNull` constraints, etc. + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. */ - lazy val allConstraints: ExpressionSet = { + lazy val constraints: ExpressionSet = { if (conf.constraintPropagationEnabled) { - ExpressionSet(validConstraints - .union(inferAdditionalConstraints(validConstraints)) - .union(constructIsNotNullConstraints(validConstraints))) + ExpressionSet( + validConstraints + .union(inferAdditionalConstraints(validConstraints)) + .union(constructIsNotNullConstraints(validConstraints, output)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic + } + ) } else { ExpressionSet(Set.empty) } } - /** - * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For - * example, if this set contains the expression `a = 2` then that expression is guaranteed to - * evaluate to `true` for all rows produced. - */ - lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter(selfReferenceOnly)) - /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then @@ -52,30 +51,42 @@ trait QueryPlanConstraints { self: LogicalPlan => * See [[Canonicalize]] for more details. */ protected def validConstraints: Set[Expression] = Set.empty +} + +trait ConstraintHelper { /** - * Returns an [[ExpressionSet]] that contains an additional set of constraints, such as - * equality constraints and `isNotNull` constraints, etc., and that only contains references - * to this [[LogicalPlan]] node. + * Infers an additional set of constraints from a given set of equality constraints. + * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an + * additional constraint of the form `b = 5`. */ - def getRelevantConstraints(constraints: Set[Expression]): ExpressionSet = { - val allRelevantConstraints = - if (conf.constraintPropagationEnabled) { - constraints - .union(inferAdditionalConstraints(constraints)) - .union(constructIsNotNullConstraints(constraints)) - } else { - constraints - } - ExpressionSet(allRelevantConstraints.filter(selfReferenceOnly)) + def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + var inferredConstraints = Set.empty[Expression] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + val candidateConstraints = constraints - eq + inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) + inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) + case _ => // No inference + } + inferredConstraints -- constraints } + private def replaceConstraints( + constraints: Set[Expression], + source: Expression, + destination: Attribute): Set[Expression] = constraints.map(_ transform { + case e: Expression if e.semanticEquals(source) => destination + }) + /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this * returns a constraint of the form `isNotNull(a)` */ - private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + def constructIsNotNullConstraints( + constraints: Set[Expression], + output: Seq[Attribute]): Set[Expression] = { // First, we propagate constraints from the null intolerant expressions. var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints) @@ -111,32 +122,4 @@ trait QueryPlanConstraints { self: LogicalPlan => case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) case _ => Seq.empty[Attribute] } - - /** - * Infers an additional set of constraints from a given set of equality constraints. - * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an - * additional constraint of the form `b = 5`. - */ - private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - var inferredConstraints = Set.empty[Expression] - constraints.foreach { - case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = constraints - eq - inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) - inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) - case _ => // No inference - } - inferredConstraints -- constraints - } - - private def replaceConstraints( - constraints: Set[Expression], - source: Expression, - destination: Attribute): Set[Expression] = constraints.map(_ transform { - case e: Expression if e.semanticEquals(source) => destination - }) - - private def selfReferenceOnly(e: Expression): Boolean = { - e.references.nonEmpty && e.references.subsetOf(outputSet) && e.deterministic - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index e068f51044589..e4671f0d1cce6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -35,11 +35,25 @@ class InferFiltersFromConstraintsSuite extends PlanTest { InferFiltersFromConstraints, CombineFilters, SimplifyBinaryComparison, - BooleanSimplification) :: Nil + BooleanSimplification, + PruneFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + private def testConstraintsAfterJoin( + x: LogicalPlan, + y: LogicalPlan, + expectedLeft: LogicalPlan, + expectedRight: LogicalPlan, + joinType: JoinType) = { + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, joinType, condition).analyze + val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + test("filter: filter out constraints in condition") { val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze val correctAnswer = testRelation @@ -196,13 +210,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y, LeftSemi, condition).analyze - val left = x.where(IsNotNull('a)) - val right = y.where(IsNotNull('a)) - val correctAnswer = left.join(right, LeftSemi, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi) } test("SPARK-21479: Outer join after-join filters push down to null-supplying side") { @@ -232,12 +240,27 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("SPARK-21479: Outer join no filter push down to preserved side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y.where("y.a".attr === 1), LeftOuter, condition).analyze - val left = x - val right = y.where(IsNotNull('a) && 'a === 1) - val correctAnswer = left.join(right, LeftOuter, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin( + x, y.where("a".attr === 1), + x, y.where(IsNotNull('a) && 'a === 1), + LeftOuter) + } + + test("SPARK-23564: left anti join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti) + } + + test("SPARK-23564: left outer join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftOuter) + } + + test("SPARK-23564: right outer join should filter out null join keys on left side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) } } From afbdf427302aba858f95205ecef7667f412b2a6a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 23 Apr 2018 14:28:28 +0200 Subject: [PATCH 668/774] [SPARK-23589][SQL] ExternalMapToCatalyst should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `ExternalMapToCatalyst`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20980 from maropu/SPARK-23589. --- .../expressions/objects/objects.scala | 60 +++++++++- .../expressions/ObjectExpressionsSuite.scala | 108 +++++++++++++++++- 2 files changed, 165 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f1ffcaec8a484..9c7e76467d153 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1255,8 +1255,64 @@ case class ExternalMapToCatalyst private( override def dataType: MapType = MapType( keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = child.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[java.util.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + val iter = data.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + val (key, value) = (entry.getKey, entry.getValue) + keys(i) = if (key != null) { + keyConverter.eval(InternalRow.fromSeq(key :: Nil)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(InternalRow.fromSeq(value :: Nil)) + } else { + null + } + i += 1 + } + (keys, values) + } + + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[scala.collection.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + var i = 0 + for ((key, value) <- data) { + keys(i) = if (key != null) { + keyConverter.eval(InternalRow.fromSeq(key :: Nil)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(InternalRow.fromSeq(value :: Nil)) + } else { + null + } + i += 1 + } + (keys, values) + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (result != null) { + val (keys, values) = mapCatalystConverter(result) + new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) + } else { + null + } + } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputMap = child.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 7136af8934486..730b36c32333c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -21,12 +21,13 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.{RandomDataGenerator, Row} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.encoders._ @@ -501,6 +502,111 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow.fromSeq(Seq(Row(1))), "java.lang.Integer is not a valid external type for schema of double") } + + private def javaMapSerializerFor( + keyClazz: Class[_], + valueClazz: Class[_])(inputObject: Expression): Expression = { + + def kvSerializerFor(inputObject: Expression, clazz: Class[_]): Expression = clazz match { + case c if c == classOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case c if c == classOf[java.lang.String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + } + + ExternalMapToCatalyst( + inputObject, + ObjectType(keyClazz), + kvSerializerFor(_, keyClazz), + keyNullable = true, + ObjectType(valueClazz), + kvSerializerFor(_, valueClazz), + valueNullable = true + ) + } + + private def scalaMapSerializerFor[T: TypeTag, U: TypeTag](inputObject: Expression): Expression = { + import org.apache.spark.sql.catalyst.ScalaReflection._ + + val curId = new java.util.concurrent.atomic.AtomicInteger() + + def kvSerializerFor[V: TypeTag](inputObject: Expression): Expression = + localTypeOf[V].dealias match { + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + case _ => + inputObject + } + + ExternalMapToCatalyst( + inputObject, + dataTypeFor[T], + kvSerializerFor[T], + keyNullable = !localTypeOf[T].typeSymbol.asClass.isPrimitive, + dataTypeFor[U], + kvSerializerFor[U], + valueNullable = !localTypeOf[U].typeSymbol.asClass.isPrimitive + ) + } + + test("SPARK-23589 ExternalMapToCatalyst should support interpreted execution") { + // Simple test + val scalaMap = scala.collection.Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3") + val javaMap = new java.util.HashMap[java.lang.Integer, java.lang.String]() { + { + put(0, "v0") + put(1, "v1") + put(2, null) + put(3, "v3") + } + } + val expected = CatalystTypeConverters.convertToCatalyst(scalaMap) + + // Java Map + val serializer1 = javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])( + Literal.fromObject(javaMap)) + checkEvaluation(serializer1, expected) + + // Scala Map + val serializer2 = scalaMapSerializerFor[Int, String](Literal.fromObject(scalaMap)) + checkEvaluation(serializer2, expected) + + // NULL key test + val scalaMapHasNullKey = scala.collection.Map[java.lang.Integer, String]( + null.asInstanceOf[java.lang.Integer] -> "v0", new java.lang.Integer(1) -> "v1") + val javaMapHasNullKey = new java.util.HashMap[java.lang.Integer, java.lang.String]() { + { + put(null, "v0") + put(1, "v1") + } + } + + // Java Map + val serializer3 = + javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])( + Literal.fromObject(javaMapHasNullKey)) + checkExceptionInExpression[RuntimeException]( + serializer3, EmptyRow, "Cannot use null as map key!") + + // Scala Map + val serializer4 = scalaMapSerializerFor[java.lang.Integer, String]( + Literal.fromObject(scalaMapHasNullKey)) + + checkExceptionInExpression[RuntimeException]( + serializer4, EmptyRow, "Cannot use null as map key!") + } } class TestBean extends Serializable { From 293a0f29e314dc532cec2048a7c6bc00e31de472 Mon Sep 17 00:00:00 2001 From: Teng Peng Date: Mon, 23 Apr 2018 10:29:47 -0700 Subject: [PATCH 669/774] [Spark-24024][ML] Fix poisson deviance calculations in GLM to handle y = 0 ## What changes were proposed in this pull request? It is reported by Spark users that the deviance calculation for poisson regression does not handle y = 0. Thus, the correct model summary cannot be obtained. The user has confirmed the the issue is in ``` override def deviance(y: Double, mu: Double, weight: Double): Double = { 2.0 * weight * (y * math.log(y / mu) - (y - mu)) } when y = 0. ``` The user also mentioned there are many other places he believe we should check the same thing. However, no other changes are needed, including Gamma distribution. ## How was this patch tested? Add a comparison with R deviance calculation to the existing unit test. Author: Teng Peng Closes #21125 from tengpeng/Spark24024GLM. --- .../ml/regression/GeneralizedLinearRegression.scala | 10 +++++----- .../regression/GeneralizedLinearRegressionSuite.scala | 10 ++++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 9f1f2405c428e..4c3f1431d5077 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -471,6 +471,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[regression] val epsilon: Double = 1E-16 + private[regression] def ylogy(y: Double, mu: Double): Double = { + if (y == 0) 0.0 else y * math.log(y / mu) + } + /** * Wrapper of family and link combination used in the model. */ @@ -725,10 +729,6 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu * (1.0 - mu) - private def ylogy(y: Double, mu: Double): Double = { - if (y == 0) 0.0 else y * math.log(y / mu) - } - override def deviance(y: Double, mu: Double, weight: Double): Double = { 2.0 * weight * (ylogy(y, mu) + ylogy(1.0 - y, 1.0 - mu)) } @@ -783,7 +783,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu override def deviance(y: Double, mu: Double, weight: Double): Double = { - 2.0 * weight * (y * math.log(y / mu) - (y - mu)) + 2.0 * weight * (ylogy(y, mu) - (y - mu)) } override def aic( diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index d5bcbb221783e..997c50157dcda 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -493,11 +493,20 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest } [1] -0.0457441 -0.6833928 [1] 1.8121235 -0.1747493 -0.5815417 + + R code for deivance calculation: + data = cbind(y=c(0,1,0,0,0,1), x1=c(18, 12, 15, 13, 15, 16), x2=c(1,0,0,2,1,1)) + summary(glm(y~x1+x2, family=poisson, data=data.frame(data)))$deviance + [1] 3.70055 + summary(glm(y~x1+x2-1, family=poisson, data=data.frame(data)))$deviance + [1] 3.809296 */ val expected = Seq( Vectors.dense(0.0, -0.0457441, -0.6833928), Vectors.dense(1.8121235, -0.1747493, -0.5815417)) + val residualDeviancesR = Array(3.809296, 3.70055) + import GeneralizedLinearRegression._ var idx = 0 @@ -510,6 +519,7 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + s"$link link and fitIntercept = $fitIntercept (with zero values).") + assert(model.summary.deviance ~== residualDeviancesR(idx) absTol 1E-3) idx += 1 } } From 448d248f897fa39cfc82d71a3d6b67e6470f8a02 Mon Sep 17 00:00:00 2001 From: liuzhaokun Date: Mon, 23 Apr 2018 13:56:11 -0500 Subject: [PATCH 670/774] [SPARK-21168] KafkaRDD should always set kafka clientId. [https://issues.apache.org/jira/browse/SPARK-21168](https://issues.apache.org/jira/browse/SPARK-21168) There are no a number of other places that a client ID should be set,and I think we should use consumer.clientId in the clientId method,because the fetch request will be used by the same consumer behind. Author: liuzhaokun Closes #19887 from liu-zhaokun/master1205. --- .../main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 5ea52b6ad36a0..791cf0efaf888 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -191,6 +191,7 @@ class KafkaRDD[ private def fetchBatch: Iterator[MessageAndOffset] = { val req = new FetchRequestBuilder() + .clientId(consumer.clientId) .addFetch(part.topic, part.partition, requestOffset, kc.config.fetchMessageMaxBytes) .build() val resp = consumer.fetch(req) From 770add81c3474e754867d7105031a5eaf27159bd Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 23 Apr 2018 13:20:32 -0700 Subject: [PATCH 671/774] [SPARK-23004][SS] Ensure StateStore.commit is called only once in a streaming aggregation task MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? A structured streaming query with a streaming aggregation can throw the following error in rare cases.  ``` java.lang.IllegalStateException: Cannot commit after already committed or aborted at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider.org$apache$spark$sql$execution$streaming$state$HDFSBackedStateStoreProvider$$verify(HDFSBackedStateStoreProvider.scala:643) at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider$HDFSBackedStateStore.commit(HDFSBackedStateStoreProvider.scala:135) at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2$$anonfun$hasNext$2.apply$mcV$sp(statefulOperators.scala:359) at org.apache.spark.sql.execution.streaming.StateStoreWriter$class.timeTakenMs(statefulOperators.scala:102) at org.apache.spark.sql.execution.streaming.StateStoreSaveExec.timeTakenMs(statefulOperators.scala:251) at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2.hasNext(statefulOperators.scala:359) at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:188) at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.(ObjectAggregationIterator.scala:78) at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:114) at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:105) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:42) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:336) ``` This can happen when the following conditions are accidentally hit.  - Streaming aggregation with aggregation function that is a subset of [`TypedImperativeAggregation`](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala#L473) (for example, `collect_set`, `collect_list`, `percentile`, etc.).  - Query running in `update}` mode - After the shuffle, a partition has exactly 128 records.  This causes StateStore.commit to be called twice. See the [JIRA](https://issues.apache.org/jira/browse/SPARK-23004) for a more detailed explanation. The solution is to use `NextIterator` or `CompletionIterator`, each of which has a flag to prevent the "onCompletion" task from being called more than once. In this PR, I chose to implement using `NextIterator`. ## How was this patch tested? Added unit test that I have confirm will fail without the fix. Author: Tathagata Das Closes #21124 from tdas/SPARK-23004. --- .../streaming/statefulOperators.scala | 40 +++++++++---------- .../streaming/StreamingAggregationSuite.scala | 25 ++++++++++++ 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index b9b07a2e688f9..c9354ac0ec78a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -340,37 +340,35 @@ case class StateStoreSaveExec( // Update and output modified rows from the StateStore. case Some(Update) => - val updatesStartTimeNs = System.nanoTime - - new Iterator[InternalRow] { - + new NextIterator[InternalRow] { // Filter late date using watermark if specified private[this] val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) case None => iter } + private val updatesStartTimeNs = System.nanoTime - override def hasNext: Boolean = { - if (!baseIterator.hasNext) { - allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) - - // Remove old aggregates if watermark specified - allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } - commitTimeMs += timeTakenMs { store.commit() } - setStoreMetrics(store) - false + override protected def getNext(): InternalRow = { + if (baseIterator.hasNext) { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key, row) + numOutputRows += 1 + numUpdatedStateRows += 1 + row } else { - true + finished = true + null } } - override def next(): InternalRow = { - val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) - numOutputRows += 1 - numUpdatedStateRows += 1 - row + override protected def close(): Unit = { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + + // Remove old aggregates if watermark specified + allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 1cae8cb8d47f1..382da13430781 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -536,6 +536,31 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } + test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") { + // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error + // by ensuring the following. + // - A streaming query with a streaming aggregation. + // - Aggregation function 'collect_list' that is a subclass of TypedImperativeAggregate. + // - Post shuffle partition has exactly 128 records (i.e. the threshold at which + // ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a + // micro-batch with 128 records that shuffle to a single partition. + // This test throws the exact error reported in SPARK-23004 without the corresponding fix. + withSQLConf("spark.sql.shuffle.partitions" -> "1") { + val input = MemoryStream[Int] + val df = input.toDF().toDF("value") + .selectExpr("value as group", "value") + .groupBy("group") + .agg(collect_list("value")) + testStream(df, outputMode = OutputMode.Update)( + AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), + AssertOnQuery { q => + q.processAllAvailable() + true + } + ) + } + } + /** Add blocks of data to the `BlockRDDBackedSource`. */ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { From e82cb68349b785c1b35bcfb85bff3a8ec2c93fee Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 23 Apr 2018 13:23:02 -0700 Subject: [PATCH 672/774] [SPARK-11237][ML] Add pmml export for k-means in Spark ML ## What changes were proposed in this pull request? Adding PMML export to Spark ML's KMeans Model. ## How was this patch tested? New unit test for Spark ML PMML export based on the old Spark MLlib unit test. Author: Holden Karau Closes #20907 from holdenk/SPARK-11237-Add-PMML-Export-for-KMeans. --- .../org.apache.spark.ml.util.MLFormatRegister | 4 +- .../apache/spark/ml/clustering/KMeans.scala | 75 ++++++++++++------- .../ml/regression/LinearRegression.scala | 2 +- .../spark/ml/clustering/KMeansSuite.scala | 32 +++++++- 4 files changed, 83 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister index 5e5484fd8784d..f14431d50feec 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -1,2 +1,4 @@ org.apache.spark.ml.regression.InternalLinearRegressionModelWriter -org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter \ No newline at end of file +org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter +org.apache.spark.ml.clustering.InternalKMeansModelWriter +org.apache.spark.ml.clustering.PMMLKMeansModelWriter \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 987a4285ebad4..1ad157a695a7d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.clustering +import scala.collection.mutable + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.{Estimator, Model, PipelineStage} import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -30,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel @@ -103,8 +105,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) - extends Model[KMeansModel] with KMeansParams with MLWritable { + private[clustering] val parentModel: MLlibKMeansModel) + extends Model[KMeansModel] with KMeansParams with GeneralMLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -152,14 +154,14 @@ class KMeansModel private[ml] ( } /** - * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. * * For [[KMeansModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * */ @Since("1.6.0") - override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) + override def write: GeneralMLWriter = new GeneralMLWriter(this) private var trainingSummary: Option[KMeansSummary] = None @@ -185,6 +187,47 @@ class KMeansModel private[ml] ( } } +/** Helper class for storing model data */ +private case class ClusterData(clusterIdx: Int, clusterCenter: Vector) + + +/** A writer for KMeans that handles the "internal" (or default) format */ +private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "internal" + override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel" + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[KMeansModel] + val sc = sparkSession.sparkContext + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map { + case (center, idx) => + ClusterData(idx, center) + } + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) + } +} + +/** A writer for KMeans that handles the "pmml" format */ +private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "pmml" + override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel" + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[KMeansModel] + val sc = sparkSession.sparkContext + instance.parentModel.toPMML(sc, path) + } +} + + @Since("1.6.0") object KMeansModel extends MLReadable[KMeansModel] { @@ -194,30 +237,12 @@ object KMeansModel extends MLReadable[KMeansModel] { @Since("1.6.0") override def load(path: String): KMeansModel = super.load(path) - /** Helper class for storing model data */ - private case class Data(clusterIdx: Int, clusterCenter: Vector) - /** * We store all cluster centers in a single row and use this class to store model data by * Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility. */ private case class OldData(clusterCenters: Array[OldVector]) - /** [[MLWriter]] instance for [[KMeansModel]] */ - private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: cluster centers - val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) => - Data(idx, center) - } - val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) - } - } - private class KMeansModelReader extends MLReader[KMeansModel] { /** Checked against metadata when loading model */ @@ -232,7 +257,7 @@ object KMeansModel extends MLReadable[KMeansModel] { val dataPath = new Path(path, "data").toString val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) { - val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data] + val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData] data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) } else { // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier. diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f67d9d831f327..9cdd3a051e719 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -746,7 +746,7 @@ private class InternalLinearRegressionModelWriter /** A writer for LinearRegression that handles the "pmml" format */ private class PMMLLinearRegressionModelWriter - extends MLWriterFormat with MLFormatRegister { + extends MLWriterFormat with MLFormatRegister { override def format(): String = "pmml" diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 32830b39407ad..77c9d482d95b6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -19,17 +19,22 @@ package org.apache.spark.ml.clustering import scala.util.Random +import org.dmg.pmml.{ClusteringModel, PMML} + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, + KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} private[clustering] case class TestRow(features: Vector) -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest + with PMMLReadWriteTest { final val k = 5 @transient var dataset: Dataset[_] = _ @@ -202,6 +207,27 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, KMeansSuite.allParamSettings, checkModelData) } + + test("pmml export") { + val clusterCenters = Array( + MLlibVectors.dense(1.0, 2.0, 6.0), + MLlibVectors.dense(1.0, 3.0, 0.0), + MLlibVectors.dense(1.0, 4.0, 6.0)) + val oldKmeansModel = new MLlibKMeansModel(clusterCenters) + val kmeansModel = new KMeansModel("", oldKmeansModel) + def checkModel(pmml: PMML): Unit = { + // Check the header descripiton is what we expect + assert(pmml.getHeader.getDescription === "k-means clustering") + // check that the number of fields match the single vector size + assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) + // This verify that there is a model attached to the pmml object and the model is a clustering + // one. It also verifies that the pmml model has the same number of clusters of the spark + // model. + val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] + assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) + } + testPMMLWrite(sc, kmeansModel, checkModel) + } } object KMeansSuite { From c8f3ac69d176bd10b8de1c147b6903a247943d51 Mon Sep 17 00:00:00 2001 From: wuyi Date: Mon, 23 Apr 2018 15:35:45 -0500 Subject: [PATCH 673/774] [SPARK-23888][CORE] correct the comment of hasAttemptOnHost() TaskSetManager.hasAttemptOnHost had a misleading comment. The comment said that it only checked for running tasks, but really it checked for any tasks that might have run in the past as well. This updates to line up with the implementation. Author: wuyi Closes #20998 from Ngone51/SPARK-23888. --- .../main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d958658527f6d..8a96a7692f614 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -287,7 +287,7 @@ private[spark] class TaskSetManager( None } - /** Check whether a task is currently running an attempt on a given host */ + /** Check whether a task once ran an attempt on a given host */ private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { taskAttempts(taskIndex).exists(_.host == host) } From 428b903859c3d8873045fdcfffdebe24fc6e027f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 24 Apr 2018 09:10:29 +0800 Subject: [PATCH 674/774] [SPARK-24029][CORE] Follow up: set SO_REUSEADDR on the server socket. "childOption" is for the remote connections, not for the server socket that actually listens for incoming connections. Author: Marcelo Vanzin Closes #21132 from vanzin/SPARK-24029.2. --- .../java/org/apache/spark/network/server/TransportServer.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 612750972c4bb..60f51125c07fd 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -99,8 +99,8 @@ private void init(String hostToBind, int portToBind) { .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) .option(ChannelOption.ALLOCATOR, allocator) - .childOption(ChannelOption.ALLOCATOR, allocator) - .childOption(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS); + .option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS) + .childOption(ChannelOption.ALLOCATOR, allocator); this.metrics = new NettyMemoryMetrics( allocator, conf.getModuleName() + "-server", conf); From 281c1ca0dc96b0441a60c32df3d16fbb1c61e99f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 24 Apr 2018 10:11:09 +0800 Subject: [PATCH 675/774] [SPARK-23973][SQL] Remove consecutive Sorts ## What changes were proposed in this pull request? In SPARK-23375 we introduced the ability of removing `Sort` operation during query optimization if the data is already sorted. In this follow-up we remove also a `Sort` which is followed by another `Sort`: in this case the first sort is not needed and can be safely removed. The PR starts from henryr's comment: https://github.com/apache/spark/pull/20560#discussion_r180601594. So credit should be given to him. ## How was this patch tested? added UT Author: Marco Gaido Closes #21072 from mgaido91/SPARK-23973. --- .../sql/catalyst/optimizer/Optimizer.scala | 21 +++++++- .../optimizer/RemoveRedundantSortsSuite.scala | 51 ++++++++++++++++--- 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f00d40d11f23f..45f13956a0a85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -767,12 +767,29 @@ object EliminateSorts extends Rule[LogicalPlan] { } /** - * Removes Sort operation if the child is already sorted + * Removes redundant Sort operation. This can happen: + * 1) if the child is already sorted + * 2) if there is another Sort operator separated by 0...n Project/Filter operators */ object RemoveRedundantSorts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => child + case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child)) + } + + def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match { + case Sort(_, _, child) => recursiveRemoveSort(child) + case other if canEliminateSort(other) => + other.withNewChildren(other.children.map(recursiveRemoveSort)) + case _ => plan + } + + def canEliminateSort(plan: LogicalPlan): Boolean = plan match { + case p: Project => p.projectList.forall(_.deterministic) + case f: Filter => f.condition.deterministic + case _: ResolvedHint => true + case _ => false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala index 2319ab8046e56..dae5e6f3ee3dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -17,16 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} class RemoveRedundantSortsSuite extends PlanTest { @@ -42,15 +38,15 @@ class RemoveRedundantSortsSuite extends PlanTest { test("remove redundant order by") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst) val optimized = Optimize.execute(unnecessaryReordered.analyze) - val correctAnswer = orderedPlan.select('a).analyze + val correctAnswer = orderedPlan.limit(2).select('a).analyze comparePlans(Optimize.execute(optimized), correctAnswer) } test("do not remove sort if the order is different") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc) + val reorderedDifferently = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc) val optimized = Optimize.execute(reorderedDifferently.analyze) val correctAnswer = reorderedDifferently.analyze comparePlans(optimized, correctAnswer) @@ -72,6 +68,14 @@ class RemoveRedundantSortsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("different sorts are not simplified if limit is in between") { + val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10)) + .orderBy('a.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = orderedPlan.analyze + comparePlans(optimized, correctAnswer) + } + test("range is already sorted") { val inputPlan = Range(1L, 1000L, 1, 10) val orderedPlan = inputPlan.orderBy('id.asc) @@ -98,4 +102,37 @@ class RemoveRedundantSortsSuite extends PlanTest { val correctAnswer = groupedAndResorted.analyze comparePlans(optimized, correctAnswer) } + + test("remove two consecutive sorts") { + val orderedTwice = testRelation.orderBy('a.asc).orderBy('b.desc) + val optimized = Optimize.execute(orderedTwice.analyze) + val correctAnswer = testRelation.orderBy('b.desc).analyze + comparePlans(optimized, correctAnswer) + } + + test("remove sorts separated by Filter/Project operators") { + val orderedTwiceWithProject = testRelation.orderBy('a.asc).select('b).orderBy('b.desc) + val optimizedWithProject = Optimize.execute(orderedTwiceWithProject.analyze) + val correctAnswerWithProject = testRelation.select('b).orderBy('b.desc).analyze + comparePlans(optimizedWithProject, correctAnswerWithProject) + + val orderedTwiceWithFilter = + testRelation.orderBy('a.asc).where('b > Literal(0)).orderBy('b.desc) + val optimizedWithFilter = Optimize.execute(orderedTwiceWithFilter.analyze) + val correctAnswerWithFilter = testRelation.where('b > Literal(0)).orderBy('b.desc).analyze + comparePlans(optimizedWithFilter, correctAnswerWithFilter) + + val orderedTwiceWithBoth = + testRelation.orderBy('a.asc).select('b).where('b > Literal(0)).orderBy('b.desc) + val optimizedWithBoth = Optimize.execute(orderedTwiceWithBoth.analyze) + val correctAnswerWithBoth = + testRelation.select('b).where('b > Literal(0)).orderBy('b.desc).analyze + comparePlans(optimizedWithBoth, correctAnswerWithBoth) + + val orderedThrice = orderedTwiceWithBoth.select(('b + 1).as('c)).orderBy('c.asc) + val optimizedThrice = Optimize.execute(orderedThrice.analyze) + val correctAnswerThrice = testRelation.select('b).where('b > Literal(0)) + .select(('b + 1).as('c)).orderBy('c.asc).analyze + comparePlans(optimizedThrice, correctAnswerThrice) + } } From c303b1b6766a3dc5961713f98f62cd7d7ac7972a Mon Sep 17 00:00:00 2001 From: seancxmao Date: Tue, 24 Apr 2018 16:16:07 +0800 Subject: [PATCH 676/774] [MINOR][DOCS] Fix comments of SQLExecution#withExecutionId ## What changes were proposed in this pull request? Fix comment. Change `BroadcastHashJoin.broadcastFuture` to `BroadcastExchangeExec.relationFuture`: https://github.com/apache/spark/blob/d28d5732ae205771f1f443b15b10e64dcffb5ff0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala#L66 ## How was this patch tested? N/A Author: seancxmao Closes #21113 from seancxmao/SPARK-13136. --- .../scala/org/apache/spark/sql/execution/SQLExecution.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index e991da7df0bde..2c5102b1e5ee7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -88,7 +88,7 @@ object SQLExecution { /** * Wrap an action with a known executionId. When running a different action in a different * thread from the original one, this method can be used to connect the Spark jobs in this action - * with the known executionId, e.g., `BroadcastHashJoin.broadcastFuture`. + * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) From 87e8a572be14381da9081365d9aa2cbf3253a32c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 24 Apr 2018 16:18:20 +0800 Subject: [PATCH 677/774] [SPARK-24054][R] Add array_position function / element_at functions ## What changes were proposed in this pull request? This PR proposes to add array_position and element_at in R side too. array_position: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$gear, df$am, df$carb)) head(select(mutated, array_position(mutated$v1, 1))) ``` ``` array_position(v1, 1.0) 1 2 2 2 3 2 4 3 5 0 6 3 ``` element_at: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) head(select(mutated, element_at(mutated$v1, 1))) ``` ``` element_at(v1, 1.0) 1 21.0 2 21.0 3 22.8 4 21.4 5 18.7 6 18.1 ``` ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_map(df$model, df$cyl)) head(select(mutated, element_at(mutated$v1, "Valiant"))) ``` ``` element_at(v3, Valiant) 1 NA 2 NA 3 NA 4 NA 5 NA 6 6 ``` ## How was this patch tested? Unit tests were added in `R/pkg/tests/fulltests/test_sparkSQL.R` and manually tested. Documentation was manually built and verified. Author: hyukjinkwon Closes #21130 from HyukjinKwon/sparkr_array_position_element_at. --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/functions.R | 42 +++++++++++++++++++++++++-- R/pkg/R/generics.R | 8 +++++ R/pkg/tests/fulltests/test_sparkSQL.R | 13 +++++++-- 4 files changed, 61 insertions(+), 4 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 190c50ea10482..55dec177ea853 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,6 +201,7 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_position", "asc", "ascii", "asin", @@ -245,6 +246,7 @@ exportMethods("%<=>%", "decode", "dense_rank", "desc", + "element_at", "encode", "endsWith", "exp", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index a527426b19674..7b3aa05074563 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -189,6 +189,11 @@ NULL #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } +#' @param value A value to compute on. +#' \itemize{ +#' \item \code{array_contains}: a value to be checked if contained in the column. +#' \item \code{array_position}: a value to locate in the given array. +#' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same #' options as the JSON data source. @@ -201,6 +206,7 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' head(select(tmp, array_position(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -208,7 +214,8 @@ NULL #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) #' head(select(tmp3, map_keys(tmp3$v3))) -#' head(select(tmp3, map_values(tmp3$v3)))} +#' head(select(tmp3, map_values(tmp3$v3))) +#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))} NULL #' Window functions for Column operations @@ -2975,7 +2982,6 @@ setMethod("row_number", #' \code{array_contains}: Returns null if the array is null, true if the array contains #' the value, and false otherwise. #' -#' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method #' @note array_contains since 1.6.0 @@ -2986,6 +2992,22 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_position}: Locates the position of the first occurrence of the given value +#' in the given array. Returns NA if either of the arguments are NA. +#' Note: The position is not zero based, but 1 based index. Returns 0 if the given +#' value could not be found in the array. +#' +#' @rdname column_collection_functions +#' @aliases array_position array_position,Column-method +#' @note array_position since 2.4.0 +setMethod("array_position", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_position", x@jc, value) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' @@ -3012,6 +3034,22 @@ setMethod("map_values", column(jc) }) +#' @details +#' \code{element_at}: Returns element of array at given index in \code{extraction} if +#' \code{x} is array. Returns value for the given key in \code{extraction} if \code{x} is map. +#' Note: The position is not zero based, but 1 based index. +#' +#' @param extraction index to check for in array or key to check for in map +#' @rdname column_collection_functions +#' @aliases element_at element_at,Column-method +#' @note element_at since 2.4.0 +setMethod("element_at", + signature(x = "Column", extraction = "ANY"), + function(x, extraction) { + jc <- callJStatic("org.apache.spark.sql.functions", "element_at", x@jc, extraction) + column(jc) + }) + #' @details #' \code{explode}: Creates a new row for each element in the given array or map column. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 974beff1a3d76..f30ac9e4295e4 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -886,6 +890,10 @@ setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("element_at", function(x, extraction) { standardGeneric("element_at") }) + #' @rdname column_string_functions #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 7105469ffc242..a384997830276 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,17 +1479,23 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains() and sort_array() + # Test array_contains(), array_position(), element_at() and sort_array() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) + result <- collect(select(df, array_position(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 0)) + + result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 6)) + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) - # Test map_keys() and map_values() + # Test map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) result <- collect(select(df, map_keys(df$map)))[[1]] expect_equal(result, list(list("x", "y"))) @@ -1497,6 +1503,9 @@ test_that("column functions", { result <- collect(select(df, map_values(df$map)))[[1]] expect_equal(result, list(list(1, 2))) + result <- collect(select(df, element_at(df$map, "y")))[[1]] + expect_equal(result, 2) + # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) From 4926a7c2f0a47b562f99dbb4f1ca17adb3192061 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 24 Apr 2018 17:52:05 +0200 Subject: [PATCH 678/774] [SPARK-23589][SQL][FOLLOW-UP] Reuse InternalRow in ExternalMapToCatalyst eval ## What changes were proposed in this pull request? This pr is a follow-up of #20980 and fixes code to reuse `InternalRow` for converting input keys/values in `ExternalMapToCatalyst` eval. ## How was this patch tested? Existing tests. Author: Takeshi Yamamuro Closes #21137 from maropu/SPARK-23589-FOLLOWUP. --- .../expressions/objects/objects.scala | 92 ++++++++++--------- 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9c7e76467d153..f974fd81fc788 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1255,53 +1255,61 @@ case class ExternalMapToCatalyst private( override def dataType: MapType = MapType( keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) - private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = child.dataType match { - case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => - (input: Any) => { - val data = input.asInstanceOf[java.util.Map[Any, Any]] - val keys = new Array[Any](data.size) - val values = new Array[Any](data.size) - val iter = data.entrySet().iterator() - var i = 0 - while (iter.hasNext) { - val entry = iter.next() - val (key, value) = (entry.getKey, entry.getValue) - keys(i) = if (key != null) { - keyConverter.eval(InternalRow.fromSeq(key :: Nil)) - } else { - throw new RuntimeException("Cannot use null as map key!") - } - values(i) = if (value != null) { - valueConverter.eval(InternalRow.fromSeq(value :: Nil)) - } else { - null + private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = { + val rowBuffer = InternalRow.fromSeq(Array[Any](1)) + def rowWrapper(data: Any): InternalRow = { + rowBuffer.update(0, data) + rowBuffer + } + + child.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[java.util.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + val iter = data.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + val (key, value) = (entry.getKey, entry.getValue) + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 } - i += 1 + (keys, values) } - (keys, values) - } - case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => - (input: Any) => { - val data = input.asInstanceOf[scala.collection.Map[Any, Any]] - val keys = new Array[Any](data.size) - val values = new Array[Any](data.size) - var i = 0 - for ((key, value) <- data) { - keys(i) = if (key != null) { - keyConverter.eval(InternalRow.fromSeq(key :: Nil)) - } else { - throw new RuntimeException("Cannot use null as map key!") - } - values(i) = if (value != null) { - valueConverter.eval(InternalRow.fromSeq(value :: Nil)) - } else { - null + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[scala.collection.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + var i = 0 + for ((key, value) <- data) { + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 } - i += 1 + (keys, values) } - (keys, values) - } + } } override def eval(input: InternalRow): Any = { From 55c4ca88a3b093ee197a8689631be8d1fac1f10f Mon Sep 17 00:00:00 2001 From: Julien Cuquemelle Date: Tue, 24 Apr 2018 10:56:55 -0500 Subject: [PATCH 679/774] [SPARK-22683][CORE] Add a executorAllocationRatio parameter to throttle the parallelism of the dynamic allocation ## What changes were proposed in this pull request? By default, the dynamic allocation will request enough executors to maximize the parallelism according to the number of tasks to process. While this minimizes the latency of the job, with small tasks this setting can waste a lot of resources due to executor allocation overhead, as some executor might not even do any work. This setting allows to set a ratio that will be used to reduce the number of target executors w.r.t. full parallelism. The number of executors computed with this setting is still fenced by `spark.dynamicAllocation.maxExecutors` and `spark.dynamicAllocation.minExecutors` ## How was this patch tested? Units tests and runs on various actual workloads on a Yarn Cluster Author: Julien Cuquemelle Closes #19881 from jcuquemelle/AddTaskPerExecutorSlot. --- .../spark/ExecutorAllocationManager.scala | 24 +++++++++++--- .../spark/internal/config/package.scala | 4 +++ .../ExecutorAllocationManagerSuite.scala | 33 +++++++++++++++++++ docs/configuration.md | 18 ++++++++++ 4 files changed, 74 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 189d91333c045..aa363eeffffb8 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -26,7 +26,7 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS} +import org.apache.spark.internal.config._ import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMaster @@ -69,6 +69,10 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors * spark.dynamicAllocation.initialExecutors - Number of executors to start with * + * spark.dynamicAllocation.executorAllocationRatio - + * This is used to reduce the parallelism of the dynamic allocation that can waste + * resources when tasks are small + * * spark.dynamicAllocation.schedulerBacklogTimeout (M) - * If there are backlogged tasks for this duration, add new executors * @@ -116,9 +120,12 @@ private[spark] class ExecutorAllocationManager( // TODO: The default value of 1 for spark.executor.cores works right now because dynamic // allocation is only supported for YARN and the default number of cores per executor in YARN is // 1, but it might need to be attained differently for different cluster managers - private val tasksPerExecutor = + private val tasksPerExecutorForFullParallelism = conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1) + private val executorAllocationRatio = + conf.get(DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO) + validateSettings() // Number of executors to add in the next round @@ -209,8 +216,13 @@ private[spark] class ExecutorAllocationManager( throw new SparkException("Dynamic allocation of executors requires the external " + "shuffle service. You may enable this through spark.shuffle.service.enabled.") } - if (tasksPerExecutor == 0) { - throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.") + if (tasksPerExecutorForFullParallelism == 0) { + throw new SparkException("spark.executor.cores must not be < spark.task.cpus.") + } + + if (executorAllocationRatio > 1.0 || executorAllocationRatio <= 0.0) { + throw new SparkException( + "spark.dynamicAllocation.executorAllocationRatio must be > 0 and <= 1.0") } } @@ -273,7 +285,9 @@ private[spark] class ExecutorAllocationManager( */ private def maxNumExecutorsNeeded(): Int = { val numRunningOrPendingTasks = listener.totalPendingTasks + listener.totalRunningTasks - (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor + math.ceil(numRunningOrPendingTasks * executorAllocationRatio / + tasksPerExecutorForFullParallelism) + .toInt } private def totalRunningTasks(): Int = synchronized { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 99d779fb600e8..6bb98c37b4479 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -126,6 +126,10 @@ package object config { private[spark] val DYN_ALLOCATION_MAX_EXECUTORS = ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue) + private[spark] val DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO = + ConfigBuilder("spark.dynamicAllocation.executorAllocationRatio") + .doubleConf.createWithDefault(1.0) + private[spark] val LOCALITY_WAIT = ConfigBuilder("spark.locality.wait") .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("3s") diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 9807d1269e3d4..3cfb0a9feb32b 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -145,6 +145,39 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) } + def testAllocationRatio(cores: Int, divisor: Double, expected: Int): Unit = { + val conf = new SparkConf() + .setMaster("myDummyLocalExternalClusterManager") + .setAppName("test-executor-allocation-manager") + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + .set("spark.dynamicAllocation.maxExecutors", "15") + .set("spark.dynamicAllocation.minExecutors", "3") + .set("spark.dynamicAllocation.executorAllocationRatio", divisor.toString) + .set("spark.executor.cores", cores.toString) + val sc = new SparkContext(conf) + contexts += sc + var manager = sc.executorAllocationManager.get + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 20))) + for (i <- 0 to 5) { + addExecutors(manager) + } + assert(numExecutorsTarget(manager) === expected) + sc.stop() + } + + test("executionAllocationRatio is correctly handled") { + testAllocationRatio(1, 0.5, 10) + testAllocationRatio(1, 1.0/3.0, 7) + testAllocationRatio(2, 1.0/3.0, 4) + testAllocationRatio(1, 0.385, 8) + + // max/min executors capping + testAllocationRatio(1, 1.0, 15) // should be 20 but capped by max + testAllocationRatio(4, 1.0/3.0, 3) // should be 2 but elevated by min + } + + test("add executors capped by num pending tasks") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get diff --git a/docs/configuration.md b/docs/configuration.md index 4d4d0c58dd07d..fb02d7ea1d4ea 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1753,6 +1753,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.minExecutors, spark.dynamicAllocation.maxExecutors, and spark.dynamicAllocation.initialExecutors + spark.dynamicAllocation.executorAllocationRatio
    spark.dynamicAllocation.executorAllocationRatio1 + By default, the dynamic allocation will request enough executors to maximize the + parallelism according to the number of tasks to process. While this minimizes the + latency of the job, with small tasks this setting can waste a lot of resources due to + executor allocation overhead, as some executor might not even do any work. + This setting allows to set a ratio that will be used to reduce the number of + executors w.r.t. full parallelism. + Defaults to 1.0 to give maximum parallelism. + 0.5 will divide the target number of executors by 2 + The target number of executors computed by the dynamicAllocation can still be overriden + by the spark.dynamicAllocation.minExecutors and + spark.dynamicAllocation.maxExecutors settings +
    spark.dynamicAllocation.schedulerBacklogTimeout 1s